opensportslib 0.0.1.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- opensportslib/__init__.py +18 -0
- opensportslib/apis/__init__.py +21 -0
- opensportslib/apis/classification.py +361 -0
- opensportslib/apis/localization.py +228 -0
- opensportslib/config/classification.yaml +104 -0
- opensportslib/config/classification_tracking.yaml +103 -0
- opensportslib/config/graph_tracking_classification/avgpool.yaml +79 -0
- opensportslib/config/graph_tracking_classification/gin.yaml +79 -0
- opensportslib/config/graph_tracking_classification/graphconv.yaml +79 -0
- opensportslib/config/graph_tracking_classification/graphsage.yaml +79 -0
- opensportslib/config/graph_tracking_classification/maxpool.yaml +79 -0
- opensportslib/config/graph_tracking_classification/noedges.yaml +79 -0
- opensportslib/config/localization.yaml +132 -0
- opensportslib/config/sngar_frames.yaml +98 -0
- opensportslib/core/__init__.py +0 -0
- opensportslib/core/loss/__init__.py +0 -0
- opensportslib/core/loss/builder.py +40 -0
- opensportslib/core/loss/calf.py +258 -0
- opensportslib/core/loss/ce.py +23 -0
- opensportslib/core/loss/combine.py +42 -0
- opensportslib/core/loss/nll.py +25 -0
- opensportslib/core/optimizer/__init__.py +0 -0
- opensportslib/core/optimizer/builder.py +38 -0
- opensportslib/core/sampler/weighted_sampler.py +104 -0
- opensportslib/core/scheduler/__init__.py +0 -0
- opensportslib/core/scheduler/builder.py +77 -0
- opensportslib/core/trainer/__init__.py +0 -0
- opensportslib/core/trainer/classification_trainer.py +1131 -0
- opensportslib/core/trainer/localization_trainer.py +1009 -0
- opensportslib/core/utils/checkpoint.py +238 -0
- opensportslib/core/utils/config.py +199 -0
- opensportslib/core/utils/data.py +85 -0
- opensportslib/core/utils/ddp.py +77 -0
- opensportslib/core/utils/default_args.py +110 -0
- opensportslib/core/utils/load_annotations.py +485 -0
- opensportslib/core/utils/seed.py +26 -0
- opensportslib/core/utils/video_processing.py +389 -0
- opensportslib/core/utils/wandb.py +110 -0
- opensportslib/datasets/__init__.py +0 -0
- opensportslib/datasets/builder.py +42 -0
- opensportslib/datasets/classification_dataset.py +582 -0
- opensportslib/datasets/localization_dataset.py +813 -0
- opensportslib/datasets/utils/__init__.py +15 -0
- opensportslib/datasets/utils/tracking.py +615 -0
- opensportslib/metrics/classification_metric.py +176 -0
- opensportslib/metrics/localization_metric.py +1482 -0
- opensportslib/models/__init__.py +0 -0
- opensportslib/models/backbones/builder.py +590 -0
- opensportslib/models/base/e2e.py +252 -0
- opensportslib/models/base/tracking.py +73 -0
- opensportslib/models/base/vars.py +29 -0
- opensportslib/models/base/video.py +130 -0
- opensportslib/models/base/video_mae.py +60 -0
- opensportslib/models/builder.py +43 -0
- opensportslib/models/heads/builder.py +266 -0
- opensportslib/models/neck/builder.py +210 -0
- opensportslib/models/utils/common.py +176 -0
- opensportslib/models/utils/impl/__init__.py +0 -0
- opensportslib/models/utils/impl/asformer.py +390 -0
- opensportslib/models/utils/impl/calf.py +74 -0
- opensportslib/models/utils/impl/gsm.py +112 -0
- opensportslib/models/utils/impl/gtad.py +347 -0
- opensportslib/models/utils/impl/tsm.py +123 -0
- opensportslib/models/utils/litebase.py +59 -0
- opensportslib/models/utils/modules.py +120 -0
- opensportslib/models/utils/shift.py +135 -0
- opensportslib/models/utils/utils.py +276 -0
- opensportslib-0.0.1.dev2.dist-info/METADATA +566 -0
- opensportslib-0.0.1.dev2.dist-info/RECORD +73 -0
- opensportslib-0.0.1.dev2.dist-info/WHEEL +5 -0
- opensportslib-0.0.1.dev2.dist-info/licenses/LICENSE +661 -0
- opensportslib-0.0.1.dev2.dist-info/licenses/LICENSE-COMMERCIAL +5 -0
- opensportslib-0.0.1.dev2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,813 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import torch
|
|
3
|
+
import random
|
|
4
|
+
from torch.utils.data import Dataset
|
|
5
|
+
from nvidia.dali import pipeline_def, backend
|
|
6
|
+
import nvidia.dali.fn as fn
|
|
7
|
+
import nvidia.dali.types as types
|
|
8
|
+
from nvidia.dali.plugin.pytorch import DALIGenericIterator
|
|
9
|
+
import tempfile
|
|
10
|
+
import cupy
|
|
11
|
+
import copy
|
|
12
|
+
import math
|
|
13
|
+
import numpy as np
|
|
14
|
+
from opensportslib.core.utils.default_args import get_default_args_dataset
|
|
15
|
+
from opensportslib.core.utils.load_annotations import get_repartition_gpu
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class LocalizationDataset(Dataset):
    """Select the per-split configuration and build localization datasets.

    The constructor only resolves configuration; the actual dataset and
    dataloader objects are created by :meth:`building_dataset` and
    :meth:`building_dataloader`.

    Args:
        config: Project configuration object. This class reads
            ``config.SYSTEM.GPU``, ``config.DATA.<split>`` and several
            ``config.DATA`` fields, and writes ``config.TRAIN.repartitions``.
        annotations_path: Accepted for interface compatibility; not used here.
        processor: Accepted for interface compatibility; not used here.
        split (str): One of ``"train"``, ``"valid"``, ``"test"`` or
            ``"valid_data_frames"``. For any other value ``self.cfg`` and
            ``self.default_args`` are left unset.
    """

    # Splits for which a ``config.DATA.<split>`` section is looked up.
    _KNOWN_SPLITS = ("train", "valid", "test", "valid_data_frames")

    def __init__(self, config, annotations_path=None, processor=None, split="train"):
        self.config = config
        self.split = split
        # Side effect: the GPU repartition is stored on the shared config.
        self.config.TRAIN.repartitions = get_repartition_gpu(self.config.SYSTEM.GPU)
        if split in self._KNOWN_SPLITS:
            self.cfg = getattr(self.config.DATA, split)
            self.default_args = get_default_args_dataset(split, self.config)

    def building_dataset(self, cfg, gpu=None, default_args=None):
        """Build the DALI dataset described by ``cfg``.

        Args:
            cfg: Split configuration. ``cfg.type`` selects the dataset class;
                ``cfg.dataloader.batch_size``, ``cfg.output_map``, ``cfg.path``
                and ``cfg.video_path`` are forwarded (plus ``cfg.overlap_len``
                for ``"VideoGameWithDaliVideo"``).
            gpu: Unused.
            default_args (dict): Must provide ``"repartitions"`` and
                ``"classes"``; the ``"VideoGameWithDali"`` type additionally
                reads ``"acc_grad_iter"``, ``"num_epochs"`` and ``"train"``.

        Returns:
            A ``DaliDataSet`` or ``DaliDataSetVideo`` instance, or ``None``
            when ``cfg.type`` is not one of the two supported names.
        """
        print(cfg)
        if cfg.type == "VideoGameWithDali":
            data_cfg = self.config.DATA
            # The loader batch is the configured batch divided by the number
            # of gradient-accumulation iterations.
            loader_batch_size = cfg.dataloader.batch_size // default_args["acc_grad_iter"]
            dataset_len = data_cfg.epoch_num_frames // data_cfg.clip_len
            is_train = default_args["train"]
            repartitions = default_args["repartitions"]
            return DaliDataSet(
                epochs=default_args["num_epochs"],
                batch_size=loader_batch_size,
                output_map=cfg.output_map,
                # First repartition entry for training, second otherwise.
                devices=repartitions[0] if is_train else repartitions[1],
                classes=default_args["classes"],
                label_file=cfg.path,
                modality=data_cfg.modality,
                clip_len=data_cfg.clip_len,
                # Non-training splits use a quarter of the training length.
                dataset_len=dataset_len if is_train else dataset_len // 4,
                video_dir=cfg.video_path,
                input_fps=data_cfg.input_fps,
                extract_fps=data_cfg.extract_fps,
                IMAGENET_MEAN=data_cfg.imagenet_mean,
                IMAGENET_STD=data_cfg.imagenet_std,
                TARGET_HEIGHT=data_cfg.target_height,
                TARGET_WIDTH=data_cfg.target_width,
                is_eval=not is_train,
                crop_dim=data_cfg.crop_dim,
                dilate_len=data_cfg.dilate_len,
                mixup=data_cfg.mixup,
            )
        if cfg.type == "VideoGameWithDaliVideo":
            data_cfg = self.config.DATA
            return DaliDataSetVideo(
                batch_size=cfg.dataloader.batch_size,
                output_map=cfg.output_map,
                devices=default_args["repartitions"][1],
                classes=default_args["classes"],
                label_file=cfg.path,
                modality=data_cfg.modality,
                clip_len=data_cfg.clip_len,
                video_dir=cfg.video_path,
                input_fps=data_cfg.input_fps,
                extract_fps=data_cfg.extract_fps,
                IMAGENET_MEAN=data_cfg.imagenet_mean,
                IMAGENET_STD=data_cfg.imagenet_std,
                TARGET_HEIGHT=data_cfg.target_height,
                TARGET_WIDTH=data_cfg.target_width,
                overlap_len=cfg.overlap_len,
                crop_dim=data_cfg.crop_dim,
            )
        return None

    def building_dataloader(self, dataset, cfg, gpu, dali):
        """Wrap ``dataset`` in a dataloader unless it is already a DALI iterator.

        Args:
            dataset: Dataset to iterate. Returned unchanged when ``dali`` is
                truthy.
            cfg: Dataloader configuration exposing ``batch_size``, ``shuffle``,
                ``num_workers``, ``pin_memory`` and, optionally,
                ``prefetch_factor`` (looked up through ``cfg.keys()``).
            gpu (int): When negative, worker processes and pinned memory are
                disabled.
            dali (bool): Whether ``dataset`` is a DALI iterator.

        Returns:
            ``dataset`` itself when ``dali`` is truthy, otherwise a
            ``torch.utils.data.DataLoader``.
        """

        def worker_init_fn(worker_id):
            # Same seeding rule as before: worker id offset by 100 * 100.
            random.seed(worker_id + 100 * 100)

        if dali:
            return dataset

        use_gpu = gpu >= 0
        num_workers = cfg.num_workers if use_gpu else 0
        loader_kwargs = {
            "batch_size": cfg.batch_size,
            "shuffle": cfg.shuffle,
            "num_workers": num_workers,
            "pin_memory": cfg.pin_memory if use_gpu else False,
            "worker_init_fn": worker_init_fn,
        }
        # DataLoader rejects an explicit prefetch_factor when no worker
        # processes are used, so only forward it when workers are enabled.
        prefetch_factor = cfg.prefetch_factor if "prefetch_factor" in cfg.keys() else None
        if num_workers > 0 and prefetch_factor is not None:
            loader_kwargs["prefetch_factor"] = prefetch_factor
        return torch.utils.data.DataLoader(dataset, **loader_kwargs)
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
class DatasetVideoSharedMethods:
    """Label and metadata helpers shared by the video datasets.

    This is a mixin: it defines no ``__init__`` and relies on the host class
    to provide ``_labels`` (list of per-video dicts), ``_video_idxs``
    (video key -> index in ``_labels``), ``_stride``, ``_class_dict``,
    ``_flip``, ``_multi_crop`` and ``_src_file``.
    """

    def get_labels(self, video):
        """Return the per-frame class indices of one video.

        Args:
            video (str): Key of the video in ``self._video_idxs``.

        Returns:
            np.ndarray: ``int64`` array with one entry per strided frame
            (``ceil(num_frames / stride)`` entries). An entry holds the class
            index of the event falling on it and 0 elsewhere.
        """
        meta = self._labels[self._video_idxs[video]]
        num_frames = meta["num_frames"]
        stride = self._stride
        # Integer ceiling division: a trailing partial stride still gets a slot.
        num_labels = num_frames // stride
        if num_frames % stride != 0:
            num_labels += 1
        labels = np.zeros(num_labels, np.int64)
        for event in meta["events"]:
            frame = event["frame"]
            # Events located past the end of the video are ignored.
            if frame < num_frames:
                labels[frame // stride] = self._class_dict[event["label"]]
        return labels

    @property
    def augment(self):
        """Whether flipping or multi-cropping is enabled (``_flip or _multi_crop``)."""
        return self._flip or self._multi_crop

    @property
    def videos(self):
        """Sorted list of ``(path, strided frame count, strided fps)`` tuples."""
        stride = self._stride
        return sorted(
            (v["path"], v["num_frames"] // stride, v["fps"] / stride)
            for v in self._labels
        )

    @property
    def labels(self):
        """Video metadata expressed in strided-frame units.

        With a stride of 1 the stored list is returned as is (not copied).
        Otherwise deep copies are returned whose ``fps``, ``num_frames`` and
        event ``frame`` values are divided by the stride.
        """
        assert self._stride > 0
        if self._stride == 1:
            return self._labels
        rescaled = []
        for entry in self._labels:
            # Deep copy so the stored metadata is never modified.
            scaled = copy.deepcopy(entry)
            scaled["fps"] /= self._stride
            scaled["num_frames"] //= self._stride
            for event in scaled["events"]:
                event["frame"] //= self._stride
            rescaled.append(scaled)
        return rescaled

    def print_info(self):
        """Print the number of videos, frames and the share of event frames."""
        num_frames = sum(x["num_frames"] for x in self._labels)
        num_events = sum(len(x["events"]) for x in self._labels)
        # Guard against an empty label list (or zero-length videos), which
        # previously raised ZeroDivisionError.
        non_bg = num_events / num_frames * 100 if num_frames else 0.0
        print(
            "{} : {} videos, {} frames ({} stride), {:0.5f}% non-bg".format(
                self._src_file,
                len(self._labels),
                num_frames,
                self._stride,
                non_bg,
            )
        )
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
class DaliDataSet(DALIGenericIterator):
    """DALI iterator that prepares training/validation clips.

    Each batch returned by iteration is a dict with:
        "frame": The frames of the clips.
        "label": The labels associated with the frames.

    Internally (see :meth:`getitem`) a sample also carries "contains_event",
    which is 1 when at least one label of the clip is non-zero. When mixup is
    enabled, clips are combined pairwise and "label" becomes a per-frame
    distribution over ``len(classes) + 1`` classes; "mix_frame" and
    "mix_weight" are used to blend the frames.

    Args:
        epochs (int): Number of training epochs.
        batch_size (int): Batch size (doubled internally when mixup is on).
        output_map (list[str]): Names given to the consecutive outputs of the
            DALI pipelines. Each name should be distinct.
        devices (list[int]): Indexes of the GPUs to use.
        classes (dict): Mapping from class name to class index.
        label_file (list[str] | str): Path(s) to the label json file(s).
        modality: Accepted for interface compatibility; not used here.
        clip_len (int): Number of frames per clip.
        dataset_len (int): Number of clips.
        video_dir (list[str] | str): Folder(s) containing the videos.
        input_fps (int): Fps of the input videos.
        extract_fps (int): Fps at which frames are extracted.
        IMAGENET_MEAN: Per-channel mean used for normalization.
        IMAGENET_STD: Per-channel standard deviation used for normalization.
        TARGET_HEIGHT (int): Height passed to the video reader's resize.
        TARGET_WIDTH (int): Width passed to the video reader's resize.
        is_eval (bool): Disable random augmentation.
            Default: True.
        crop_dim (int): The dimension for cropping frames.
            Default: None.
        dilate_len (int): Dilate ground truth labels.
            Default: 0.
        mixup (bool): Whether to mix up clips of two videos or not.
            Default: False.
    """

    def __init__(
        self,
        epochs,
        batch_size,
        output_map,
        devices,
        classes,
        label_file,
        modality,
        clip_len,
        dataset_len,
        video_dir,
        input_fps,
        extract_fps,
        IMAGENET_MEAN,
        IMAGENET_STD,
        TARGET_HEIGHT,
        TARGET_WIDTH,
        is_eval=True,
        crop_dim=None,
        dilate_len=0,
        mixup=False,
    ):
        from opensportslib.core.utils.load_annotations import annotationstoe2eformat
        from opensportslib.core.utils.video_processing import (
            distribute_elements,
            _get_deferred_rgb_transform,
            get_stride,
        )

        self._src_file = label_file
        self._labels, self.task_name = annotationstoe2eformat(
            label_file, video_dir, input_fps, extract_fps, True
        )
        self._class_dict = classes
        self.original_batch_size = batch_size

        # With mixup every output sample consumes two clips, so the pipelines
        # are sized for twice the requested batch.
        pipeline_batch_size = 2 * batch_size if mixup else batch_size
        self.batch_size_per_pipe = distribute_elements(pipeline_batch_size, len(devices))

        self.batch_size = batch_size
        self.nb_videos = dataset_len * 2 if mixup else dataset_len
        self.mixup = mixup
        self.output_map = output_map
        self.devices = devices
        self.is_eval = is_eval
        self.crop_dim = crop_dim
        self.dilate_len = dilate_len
        self.clip_len = clip_len
        self.IMAGENET_MEAN = IMAGENET_MEAN
        self.IMAGENET_STD = IMAGENET_STD
        self.TARGET_HEIGHT = TARGET_HEIGHT
        self.TARGET_WIDTH = TARGET_WIDTH

        self._stride = get_stride(input_fps, extract_fps)

        # Clips are pre-sampled for all epochs at once (same count for train
        # and eval).
        nb_clips_per_video = math.ceil(dataset_len / len(self._labels)) * epochs
        if mixup:
            nb_clips_per_video = nb_clips_per_video * 2

        # One "<path> <video index> <start frame> <end frame>" line per clip.
        entries = []
        for index, video in enumerate(self._labels):
            video_path = video["video"]
            last_start = video["num_frames"] - (clip_len + 1)
            if nb_clips_per_video > 0 and last_start < 1:
                # random.randint would fail with an opaque "empty range"
                # ValueError; raise the same exception type with context.
                raise ValueError(
                    "Video {!r} has {} frames, which is too short to sample a "
                    "clip of {} frames".format(video_path, video["num_frames"], clip_len)
                )
            for _ in range(nb_clips_per_video):
                random_start = random.randint(1, last_start)
                entries.append(
                    f"{video_path} {index} {random_start * self._stride} {(random_start+clip_len) * self._stride}\n"
                )

        # The file list only has to exist while the pipelines are created and
        # built; the context manager removes it deterministically afterwards.
        with tempfile.NamedTemporaryFile() as tf:
            tf.write("".join(entries).encode())
            tf.flush()

            self.pipes = [
                self.video_pipe(
                    batch_size=self.batch_size_per_pipe[shard],
                    sequence_length=self.clip_len,
                    stride_dali=self._stride,
                    step=-1,
                    num_threads=8,
                    device_id=device_id,
                    file_list=tf.name,
                    shard_id=shard,
                    num_shards=len(devices),
                )
                for shard, device_id in enumerate(devices)
            ]

            for pipe in self.pipes:
                pipe.build()

            super().__init__(self.pipes, output_map, size=self.nb_videos)

        # Batches are gathered on the second device when several are used,
        # otherwise on the only one.
        gather_device = self.devices[1 if len(self.devices) > 1 else 0]
        self.device = torch.device("cuda:{}".format(gather_device))

        # In eval mode normalization is done inside the DALI pipeline; in
        # training mode it is deferred to this transform.
        self.gpu_transform = None
        if not self.is_eval:
            self.gpu_transform = _get_deferred_rgb_transform(
                self.IMAGENET_MEAN, self.IMAGENET_STD
            )

    def __next__(self):
        """Return the next batch as ``{"frame": ..., "label": ...}``."""
        out = super().__next__()
        ret = self.getitem(out)
        if self.is_eval:
            frame = ret["frame"]
        else:
            frame = self.load_frame_deferred(self.gpu_transform, ret)
        return {"frame": frame, "label": ret["label"]}

    def delete(self):
        """Useful method to free memory used by gpu when the dataset is no longer needed."""
        for pipe in self.pipes:
            pipe.__del__()
            # NOTE(review): this only unbinds the loop variable; self.pipes
            # still references the pipeline objects.
            del pipe
        backend.ReleaseUnusedMemory()

    def get_attr(self, batch):
        """Return a dictionary containing the attributes of one pipeline output.

        Args:
            batch (dict): Pipeline output with the keys "label" and "data".

        Returns:
            dict: ``{"frame", "contains_event", "label"}`` where
            "contains_event" is 1 when the summed labels are positive. The sum
            runs over dim 1 for 2-D labels and over dim 0 otherwise.
        """
        batch_labels = batch["label"]
        batch_images = batch["data"]
        sum_dim = 1 if len(batch_labels.shape) == 2 else 0
        contains_event = (torch.sum(batch_labels, dim=sum_dim) > 0).int()
        return {
            "frame": batch_images,
            "contains_event": contains_event,
            "label": batch_labels,
        }

    def move_to_device(self, batch):
        """Move, in place, every tensor of ``batch`` to ``self.device``.

        Args:
            batch (dict): Batch whose tensors may be located on different gpus.
        """
        for key, tensor in batch.items():
            batch[key] = tensor.to(self.device)

    def getitem(self, data):
        """Construct and return a batch. Mix up clips of two videos if mixup is true.

        Args:
            data: List of per-pipeline outputs (one per device).

        Returns:
            dict: Batch with "frame", "contains_event" and "label"; with mixup
            and a deferred gpu transform also "mix_frame" and "mix_weight".
        """
        nb_devices = len(self.devices)
        # NOTE(review): only the outputs of pipelines 0-1 (and 2-3 when at
        # least four devices are used) are consumed; confirm that other device
        # counts are not expected.
        if nb_devices >= 1:
            ret = self.get_attr(data[0])
        if nb_devices >= 2:
            mix = self.get_attr(data[1])
            self.move_to_device(ret)
            self.move_to_device(mix)
        if nb_devices >= 4:
            ret2 = self.get_attr(data[2])
            mix2 = self.get_attr(data[3])
            self.move_to_device(ret2)
            self.move_to_device(mix2)

        if not self.mixup:
            # Without mixup all pipeline outputs are simply concatenated.
            if nb_devices >= 4:
                for key, tensor in ret.items():
                    ret[key] = torch.cat((tensor, mix[key], ret2[key], mix2[key]))
            elif nb_devices >= 2:
                for key, tensor in ret.items():
                    ret[key] = torch.cat((tensor, mix[key]))
            return ret

        if nb_devices == 1:
            # A single pipeline produced the doubled batch: split it in two.
            mix = {}
            for key, tensor in ret.items():
                ret[key], mix[key] = torch.chunk(tensor, 2, dim=0)
        if nb_devices >= 4:
            for key, tensor in ret.items():
                ret[key] = torch.cat((tensor, ret2[key]))
            for key, tensor in mix.items():
                mix[key] = torch.cat((tensor, mix2[key]))

        nb_samples = ret["frame"].shape[0]
        weights = torch.tensor(
            [random.betavariate(0.2, 0.2) for _ in range(nb_samples)]
        )
        # Soft labels: weight w on the first clip's class, 1 - w on the second.
        label_dist = torch.zeros(
            (nb_samples, self.clip_len, len(self._class_dict) + 1),
            device=self.device,
        )
        for i in range(nb_samples):
            weight = weights[i].item()
            label_dist[i, range(self.clip_len), ret["label"][i]] = weight
            label_dist[i, range(self.clip_len), mix["label"][i]] += 1.0 - weight

        if self.gpu_transform is None:
            # Frames are already normalized: blend them right away.
            for i in range(nb_samples):
                weight = weights[i].item()
                ret["frame"][i] = (
                    weight * ret["frame"][i] + (1.0 - weight) * mix["frame"][i]
                )
        else:
            # Blending is deferred to load_frame_deferred, after the transform.
            ret["mix_frame"] = mix["frame"]
            ret["mix_weight"] = weights

        ret["contains_event"] = torch.max(
            ret["contains_event"], mix["contains_event"]
        )
        ret["label"] = label_dist
        return ret

    def load_frame_deferred(self, gpu_transform, batch):
        """Apply the deferred transform to the frames, blending mixup frames if any.

        Args:
            gpu_transform: Transform applied to each sample's frames.
            batch (dict): Batch containing "frame" and, when mixup was applied,
                "mix_weight" and "mix_frame".

        Returns:
            torch.Tensor: The transformed frames (``batch["frame"]`` is
            modified in place).
        """
        frame = batch["frame"]
        with torch.no_grad():
            for i in range(frame.shape[0]):
                frame[i] = gpu_transform(frame[i])

            if "mix_weight" in batch:
                weight = batch["mix_weight"].to(self.device)
                # Broadcast the per-sample weight over the four trailing dims.
                frame *= weight[:, None, None, None, None]

                frame_mix = batch["mix_frame"]
                for i in range(frame.shape[0]):
                    frame[i] += (1.0 - weight[i]) * gpu_transform(frame_mix[i])

        return frame

    @pipeline_def
    def video_pipe(
        self, file_list, sequence_length, stride_dali, step, shard_id, num_shards
    ):
        """Construct the DALI pipeline that reads and preprocesses clips.

        The reader returns the frames of a clip, its label (the video index
        written in the file list) and the index of its first frame. Frames are
        then cropped and normalized (eval) or cropped, scaled by 1/255 and
        randomly mirrored (training). Finally the video index is replaced by
        per-frame event labels through :meth:`edit_labels`.

        Args:
            file_list (string): Path to the file with a list of <file label [start_frame [end_frame]]> values.
            sequence_length (int): Frames to load per sequence.
            stride_dali (int): Distance between consecutive frames in the sequence.
            step (int): Frame interval between each sequence.
            shard_id (int): Index of the shard to read.
            num_shards (int): Partitions the data into the specified number of parts.

        Returns:
            video: The processed frames.
            label: The per-frame labels matching the extracted frames.
        """
        video, label, frame_num = fn.readers.video_resize(
            device="gpu",
            size=(self.TARGET_HEIGHT, self.TARGET_WIDTH),
            file_list=file_list,
            sequence_length=sequence_length,
            random_shuffle=True,
            shard_id=shard_id,
            num_shards=num_shards,
            image_type=types.RGB,
            file_list_include_preceding_frame=True,
            file_list_frame_num=True,
            enable_frame_num=True,
            stride=stride_dali,
            step=step,
            pad_sequences=True,
            skip_vfr_check=True,
        )
        crop = (self.crop_dim, self.crop_dim) if self.crop_dim is not None else None
        if self.is_eval:
            video = fn.crop_mirror_normalize(
                video,
                dtype=types.FLOAT,
                crop=crop,
                out_of_bounds_policy="trim_to_shape",
                output_layout="FCHW",
                # Mean/std are given in [0, 1]; pixel values are in [0, 255].
                mean=[m * 255.0 for m in self.IMAGENET_MEAN],
                std=[s * 255.0 for s in self.IMAGENET_STD],
            )
        else:
            video = fn.crop_mirror_normalize(
                video,
                dtype=types.FLOAT,
                output_layout="FCHW",
                crop=crop,
                out_of_bounds_policy="trim_to_shape",
                # Only rescale here; the deferred gpu transform takes the
                # ImageNet statistics.
                std=[255, 255, 255],
                mirror=fn.random.coin_flip(),
            )
        label = fn.python_function(
            label, frame_num, function=self.edit_labels, device="gpu"
        )
        return video, label

    def edit_labels(self, label, frame_num):
        """Build the per-frame label array of one clip.

        Args:
            label: Index of the video, used to fetch its metadata.
            frame_num: Index of the first frame of the clip.

        Returns:
            cupy.ndarray: ``int64`` array of length ``clip_len`` holding the
            class index on frames within ``dilate_len`` of an event and 0
            elsewhere.
        """
        video_meta = self._labels[label.item()]
        base_idx = frame_num.item() // self._stride
        labels = cupy.zeros(self.clip_len, np.int64)

        for event in video_meta["events"]:
            # Position of the event relative to the start of the clip.
            label_idx = event["frame"] - base_idx
            # An event up to dilate_len frames outside the clip still dilates
            # into it, hence the negative lower bound. The previous positive
            # bound dropped events located near the start of the clip.
            if -self.dilate_len <= label_idx < self.clip_len + self.dilate_len:
                class_idx = self._class_dict[event["label"]]
                first = max(0, label_idx - self.dilate_len)
                last = min(self.clip_len, label_idx + self.dilate_len + 1)
                labels[first:last] = class_idx
        return labels

    def print_info(self):
        """Print a summary of the label file through ``_print_info_helper``."""
        # Use the package-qualified path like every other import of this
        # module; a bare ``core`` package is not importable once installed.
        from opensportslib.core.utils.config import _print_info_helper

        _print_info_helper(self._src_file, self._labels)
|
|
587
|
+
|
|
588
|
+
|
|
589
|
+
class DaliDataSetVideo(DALIGenericIterator, DatasetVideoSharedMethods):
|
|
590
|
+
"""Class that overrides DALIGenericIterator class. This class is to prepare testing data using nvidia dali.
|
|
591
|
+
Testing data consists of frames, the name of the video and index of the first frame in the video.
|
|
592
|
+
This class can process as input a json file containing metadatas of video or just a video.
|
|
593
|
+
|
|
594
|
+
Args:
|
|
595
|
+
batch_size (int).
|
|
596
|
+
output_map (List[string]): List of strings which maps consecutive outputs of DALI pipelines to user specified name. Outputs will be returned from iterator as dictionary of those names. Each name should be distinct.
|
|
597
|
+
devices (list[int]): List of indexes of gpu to use.
|
|
598
|
+
classes (dict): dict of class names to idx.
|
|
599
|
+
label_file (string): Can be path to label json or path of a video.
|
|
600
|
+
clip_len (int): Length of a clip of frames.
|
|
601
|
+
video_dir (string): path to folder where videos are located.
|
|
602
|
+
input_fps (int): Fps of the input videos.
|
|
603
|
+
extract_fps (int): The fps at which we extract frames. This variable is used if dataset is a single video.
|
|
604
|
+
overlap_len (int): The number of overlapping frames between consecutive clips.
|
|
605
|
+
Default: 0.
|
|
606
|
+
crop_dim (int): The dimension for cropping frames.
|
|
607
|
+
Default: None.
|
|
608
|
+
flip (bool): Whether to flip or not the frames.
|
|
609
|
+
Default: False.
|
|
610
|
+
multi_crop (bool): Whether multi croping or not
|
|
611
|
+
Default: False.
|
|
612
|
+
"""
|
|
613
|
+
|
|
614
|
+
def __init__(
    self,
    batch_size,
    output_map,
    devices,
    classes,
    label_file,
    modality,
    clip_len,
    video_dir,
    input_fps,
    extract_fps,
    IMAGENET_MEAN,
    IMAGENET_STD,
    TARGET_HEIGHT,
    TARGET_WIDTH,
    overlap_len=0,
    crop_dim=None,
    flip=False,
    multi_crop=False,
):
    """Index every clip, write the DALI file list and build one pipeline per device.

    Steps performed:

    1. Load the per-video label entries. A ``label_file`` ending in
       ``.json`` goes through ``annotationstoe2eformat`` (which also yields
       ``task_name``) and ``get_stride``; anything else goes through
       ``construct_labels``, which returns the entries and the DALI stride.
    2. For each entry, enumerate clip start indices
       ``range(1, num_frames, clip_len - overlap_len)`` and record one
       file-list line ``"<video> <clip index> <start frame> <end frame>"``
       per clip.
    3. Append ``get_remaining(len(clips), batch_size)`` copies of the last
       clip (presumably so the clip count is a multiple of the batch size --
       confirm against ``get_remaining``).
    4. Write the file list to a temporary file, create and build one
       ``video_pipe`` per device, and initialise the parent iterator.

    Args:
        batch_size (int): Global batch size; each device gets
            ``batch_size // len(devices)``.
        output_map (list[str]): Output names forwarded to the parent iterator.
        devices (list[int]): GPU indices; one pipeline (and shard) per entry.
        classes (dict): Mapping of class names to indices.
        label_file (str): Path to a label JSON file or to a video.
        modality: Accepted for interface compatibility; not used here.
        clip_len (int): Number of frames per clip.
        video_dir (str): Folder containing the videos (JSON branch only).
        input_fps (int): Frame rate of the input videos.
        extract_fps (int): Frame rate at which frames are extracted.
        IMAGENET_MEAN: Per-channel mean used by ``video_pipe``.
        IMAGENET_STD: Per-channel std used by ``video_pipe``.
        TARGET_HEIGHT (int): Resize height used by ``video_pipe``.
        TARGET_WIDTH (int): Resize width used by ``video_pipe``.
        overlap_len (int): Frames shared by consecutive clips. Default: 0.
        crop_dim (int | None): Square crop size, or None. Default: None.
        flip (bool): Stored as ``self._flip``. Default: False.
        multi_crop (bool): Stored as ``self._multi_crop``. Default: False.
    """
    from opensportslib.core.utils.load_annotations import annotationstoe2eformat, construct_labels
    from opensportslib.core.utils.video_processing import get_stride, get_remaining

    self._src_file = label_file
    if label_file.endswith(".json"):
        self._labels, self.task_name = annotationstoe2eformat(
            label_file, video_dir, input_fps, extract_fps, True
        )
        stride_dali = get_stride(input_fps, extract_fps)
    else:
        # Note: ``task_name`` is not set on this branch.
        self._labels, stride_dali = construct_labels(label_file, extract_fps)

    self._class_dict = classes
    self._video_idxs = {x["path"]: i for i, x in enumerate(self._labels)}
    self._clip_len = clip_len
    self.crop_dim = crop_dim
    self._stride = 1
    self._flip = flip
    self._multi_crop = multi_crop
    # Per-device batch size.
    self.batch_size = batch_size // len(devices)
    self.devices = devices
    self._clips = []
    self.IMAGENET_MEAN = IMAGENET_MEAN
    self.IMAGENET_STD = IMAGENET_STD
    self.TARGET_HEIGHT = TARGET_HEIGHT
    self.TARGET_WIDTH = TARGET_WIDTH

    # Collect file-list lines and join once instead of growing a string.
    file_list_lines = []
    clip_id = 0
    clip_step = (clip_len - overlap_len) * self._stride
    for entry in self._labels:
        has_clip = False
        # Start indices begin at 1, not 0.
        for start in range(1, entry["num_frames"], clip_step):
            if start + clip_len > entry["num_frames"]:
                # Clip runs past the end: stop at ``num_frames_base``.
                end = entry["num_frames_base"]
            else:
                end = (start + clip_len) * stride_dali
            has_clip = True
            self._clips.append((entry["path"], entry["video"], start))
            file_list_lines.append(
                f"{entry['video']} {clip_id} {start * stride_dali} {end}\n"
            )
            clip_id += 1
        last_video = entry["video"]
        last_path = entry["path"]
        # Every video must contribute at least one clip.
        assert has_clip, entry

    # Pad with copies of the last clip; each copy gets its own clip index.
    padding = get_remaining(len(self._clips), batch_size)
    for _ in range(padding):
        self._clips.append((last_path, last_video, start))
        file_list_lines.append(
            f"{last_video} {clip_id} {start * stride_dali} {end}\n"
        )
        clip_id += 1

    # The file list only has to exist while the pipelines are created, built
    # and handed to the parent iterator; the context manager removes it
    # deterministically afterwards (previously this relied on garbage
    # collection of the local variable at the end of __init__).
    with tempfile.NamedTemporaryFile() as tf:
        tf.write("".join(file_list_lines).encode())
        tf.flush()

        self.pipes = [
            self.video_pipe(
                batch_size=self.batch_size,
                sequence_length=self._clip_len,
                stride_dali=stride_dali,
                step=-1,
                num_threads=8,
                device_id=device_id,
                file_list=tf.name,
                shard_id=shard_id,
                num_shards=len(devices),
            )
            for shard_id, device_id in enumerate(devices)
        ]

        for pipe in self.pipes:
            pipe.build()

        super().__init__(self.pipes, output_map, size=len(self._clips))
|
|
732
|
+
|
|
733
|
+
def __next__(self):
    """Return the next batch as a dict with ``video``, ``start`` and ``frame``.

    Each per-pipeline output carries a ``label`` entry that is used to index
    ``self._clips`` (presumably the clip index written to the file list in
    ``__init__`` -- confirm). For every label the clip's path and start index
    are collected; the ``data`` entries of all pipelines are moved to the
    CUDA device and concatenated.

    Returns:
        dict: ``"video"`` is the list of clip paths, ``"start"`` a tensor of
        start indices sized ``len(self.devices) * self.batch_size``, and
        ``"frame"`` the concatenated frame data.
    """
    batches = super().__next__()
    clip_paths = []
    start_indices = cupy.zeros(len(self.devices) * self.batch_size, np.int64)
    slot = 0
    for batch in batches:
        for row in range(batch["label"].shape[0]):
            clip_path, _video_name, first_frame = self._clips[batch["label"][row]]
            clip_paths.append(clip_path)
            start_indices[slot] = first_frame
            slot += 1

    result = {}
    result["video"] = clip_paths
    result["start"] = torch.as_tensor(start_indices)
    result["frame"] = torch.cat(
        ([batch["data"].to(torch.device("cuda")) for batch in batches])
    )
    return result
|
|
751
|
+
|
|
752
|
+
def delete(self):
    """Useful method to free memory used by gpu when the dataset is no longer needed.

    Explicitly invokes ``__del__`` on every pipeline in ``self.pipes`` and
    then calls ``backend.ReleaseUnusedMemory()``.
    """
    for pipe in self.pipes:
        # Run the pipeline's finalizer explicitly rather than waiting for GC.
        pipe.__del__()
        # NOTE(review): this only unbinds the loop variable; ``self.pipes``
        # is not cleared here, so it still references every pipeline object.
        del pipe
    backend.ReleaseUnusedMemory()
|
|
758
|
+
|
|
759
|
+
@pipeline_def
def video_pipe(
    self, file_list, sequence_length, stride_dali, step, shard_id, num_shards
):
    """Define the DALI pipeline that reads, resizes, crops and normalizes clips.

    The reader loads each clip listed in ``file_list`` on the GPU, resized to
    ``(self.TARGET_HEIGHT, self.TARGET_WIDTH)``, without shuffling, and
    returns it together with its label (the clip index written in the file
    list). The frames are then optionally cropped to a ``self.crop_dim``
    square and normalized with ``self.IMAGENET_MEAN`` / ``self.IMAGENET_STD``
    scaled by 255, with output layout ``"FCHW"``. No ``mirror`` argument is
    passed to ``crop_mirror_normalize``.

    Args:
        file_list (string): Path to the file with a list of
            <file label [start_frame [end_frame]]> values.
        sequence_length (int): Frames to load per sequence.
        stride_dali (int): Distance between consecutive frames in the sequence.
        step (int): Frame interval between each sequence.
        shard_id (int): Index of the shard to read.
        num_shards (int): Partitions the data into the specified number of parts.

    Returns:
        tuple: ``(video, label)`` -- the processed frames and the index of
        the clip in the list of clips.
    """
    video, label = fn.readers.video_resize(
        device="gpu",
        size=(self.TARGET_HEIGHT, self.TARGET_WIDTH),
        file_list=file_list,
        sequence_length=sequence_length,
        random_shuffle=False,
        shard_id=shard_id,
        num_shards=num_shards,
        image_type=types.RGB,
        file_list_include_preceding_frame=True,
        file_list_frame_num=True,
        stride=stride_dali,
        step=step,
        pad_sequences=True,
        skip_vfr_check=True,
    )

    # Statistics are scaled by 255 before being handed to the normalizer.
    mean = [channel * 255.0 for channel in self.IMAGENET_MEAN]
    std = [channel * 255.0 for channel in self.IMAGENET_STD]

    video = fn.crop_mirror_normalize(
        video,
        dtype=types.FLOAT,
        output_layout="FCHW",
        crop=(self.crop_dim, self.crop_dim) if self.crop_dim is not None else None,
        out_of_bounds_policy="trim_to_shape",
        mean=mean,
        std=std,
    )

    return video, label
|
|
807
|
+
|
|
808
|
+
def get_dims(video):
    """Print the ``shape`` attribute of ``video`` (debugging helper).

    Args:
        video: Any object exposing a ``shape`` attribute.
    """
    dims = video.shape
    print(dims)
|
|
810
|
+
|
|
811
|
+
|
|
812
|
+
if __name__ == "__main__":
    import sys

    # Manual smoke test: build a LocalizationDataset from a config file.
    # The config path may be given as the first command-line argument; the
    # original developer path remains the default for backward compatibility.
    config_path = (
        sys.argv[1]
        if len(sys.argv) > 1
        else "/home/vorajv/opensportslib-ml/opensportslib/opensportslib/config/localization.yaml"
    )
    LocalizationDataset(config=config_path)
|