opensportslib 0.0.1.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- opensportslib/__init__.py +18 -0
- opensportslib/apis/__init__.py +21 -0
- opensportslib/apis/classification.py +361 -0
- opensportslib/apis/localization.py +228 -0
- opensportslib/config/classification.yaml +104 -0
- opensportslib/config/classification_tracking.yaml +103 -0
- opensportslib/config/graph_tracking_classification/avgpool.yaml +79 -0
- opensportslib/config/graph_tracking_classification/gin.yaml +79 -0
- opensportslib/config/graph_tracking_classification/graphconv.yaml +79 -0
- opensportslib/config/graph_tracking_classification/graphsage.yaml +79 -0
- opensportslib/config/graph_tracking_classification/maxpool.yaml +79 -0
- opensportslib/config/graph_tracking_classification/noedges.yaml +79 -0
- opensportslib/config/localization.yaml +132 -0
- opensportslib/config/sngar_frames.yaml +98 -0
- opensportslib/core/__init__.py +0 -0
- opensportslib/core/loss/__init__.py +0 -0
- opensportslib/core/loss/builder.py +40 -0
- opensportslib/core/loss/calf.py +258 -0
- opensportslib/core/loss/ce.py +23 -0
- opensportslib/core/loss/combine.py +42 -0
- opensportslib/core/loss/nll.py +25 -0
- opensportslib/core/optimizer/__init__.py +0 -0
- opensportslib/core/optimizer/builder.py +38 -0
- opensportslib/core/sampler/weighted_sampler.py +104 -0
- opensportslib/core/scheduler/__init__.py +0 -0
- opensportslib/core/scheduler/builder.py +77 -0
- opensportslib/core/trainer/__init__.py +0 -0
- opensportslib/core/trainer/classification_trainer.py +1131 -0
- opensportslib/core/trainer/localization_trainer.py +1009 -0
- opensportslib/core/utils/checkpoint.py +238 -0
- opensportslib/core/utils/config.py +199 -0
- opensportslib/core/utils/data.py +85 -0
- opensportslib/core/utils/ddp.py +77 -0
- opensportslib/core/utils/default_args.py +110 -0
- opensportslib/core/utils/load_annotations.py +485 -0
- opensportslib/core/utils/seed.py +26 -0
- opensportslib/core/utils/video_processing.py +389 -0
- opensportslib/core/utils/wandb.py +110 -0
- opensportslib/datasets/__init__.py +0 -0
- opensportslib/datasets/builder.py +42 -0
- opensportslib/datasets/classification_dataset.py +582 -0
- opensportslib/datasets/localization_dataset.py +813 -0
- opensportslib/datasets/utils/__init__.py +15 -0
- opensportslib/datasets/utils/tracking.py +615 -0
- opensportslib/metrics/classification_metric.py +176 -0
- opensportslib/metrics/localization_metric.py +1482 -0
- opensportslib/models/__init__.py +0 -0
- opensportslib/models/backbones/builder.py +590 -0
- opensportslib/models/base/e2e.py +252 -0
- opensportslib/models/base/tracking.py +73 -0
- opensportslib/models/base/vars.py +29 -0
- opensportslib/models/base/video.py +130 -0
- opensportslib/models/base/video_mae.py +60 -0
- opensportslib/models/builder.py +43 -0
- opensportslib/models/heads/builder.py +266 -0
- opensportslib/models/neck/builder.py +210 -0
- opensportslib/models/utils/common.py +176 -0
- opensportslib/models/utils/impl/__init__.py +0 -0
- opensportslib/models/utils/impl/asformer.py +390 -0
- opensportslib/models/utils/impl/calf.py +74 -0
- opensportslib/models/utils/impl/gsm.py +112 -0
- opensportslib/models/utils/impl/gtad.py +347 -0
- opensportslib/models/utils/impl/tsm.py +123 -0
- opensportslib/models/utils/litebase.py +59 -0
- opensportslib/models/utils/modules.py +120 -0
- opensportslib/models/utils/shift.py +135 -0
- opensportslib/models/utils/utils.py +276 -0
- opensportslib-0.0.1.dev2.dist-info/METADATA +566 -0
- opensportslib-0.0.1.dev2.dist-info/RECORD +73 -0
- opensportslib-0.0.1.dev2.dist-info/WHEEL +5 -0
- opensportslib-0.0.1.dev2.dist-info/licenses/LICENSE +661 -0
- opensportslib-0.0.1.dev2.dist-info/licenses/LICENSE-COMMERCIAL +5 -0
- opensportslib-0.0.1.dev2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,582 @@
|
|
|
1
|
+
# opensportslib/datasets/classification_dataset.py
|
|
2
|
+
|
|
3
|
+
"""classification dataset implementations for video and tracking modalities.
|
|
4
|
+
|
|
5
|
+
provides two concrete dataset classes:
|
|
6
|
+
|
|
7
|
+
- VideoDataset (MVFoul, SN-GAR video)
|
|
8
|
+
- TrackingDataset (SN-GAR tracking / parquet)
|
|
9
|
+
|
|
10
|
+
both inherit from ClassificationDataset, which handles annotation loading,
|
|
11
|
+
label mapping, and class-weight computation.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
import os
|
|
15
|
+
import random
|
|
16
|
+
|
|
17
|
+
import numpy as np
|
|
18
|
+
import torch
|
|
19
|
+
from torch.utils.data import Dataset
|
|
20
|
+
from tqdm import tqdm
|
|
21
|
+
|
|
22
|
+
from opensportslib.core.utils.load_annotations import load_annotations
|
|
23
|
+
from opensportslib.core.utils.video_processing import *
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
# -------------------------------------------------------------
|
|
27
|
+
# factory
|
|
28
|
+
# -------------------------------------------------------------
|
|
29
|
+
|
|
30
|
+
def build(config, annotations_path, processor=None, split="train"):
    """Create the dataset that matches ``config.DATA.data_modality``.

    The modality string is compared case-insensitively.

    Args:
        config: the loaded YAML configuration.
        annotations_path: path to the annotation JSON file.
        processor: HuggingFace image processor; only forwarded to the
            video dataset.
        split: one of "train", "valid", "test".

    Returns:
        a ``TrackingDataset`` for "tracking_parquet", or a ``VideoDataset``
        for "video" / "frames_npy".

    Raises:
        ValueError: if the data_modality is not recognized.
    """
    kind = config.DATA.data_modality.lower()

    if kind == "tracking_parquet":
        return TrackingDataset(config, annotations_path, split=split)

    if kind in ("video", "frames_npy"):
        return VideoDataset(config, annotations_path, processor, split=split)

    raise ValueError(f"Unknown data_modality: {kind}")
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
# -------------------------------------------------------------
|
|
56
|
+
# base class
|
|
57
|
+
# -------------------------------------------------------------
|
|
58
|
+
|
|
59
|
+
class ClassificationDataset(Dataset):
    """Shared base for all classification datasets.

    Loads annotations, builds an id -> name label map, and exposes helpers
    for computing sample-level and class-level weights (useful for
    balanced sampling and weighted loss respectively).

    Args:
        config: the loaded YAML configuration. ``config.DATA.classes`` and
            ``config.DATA.num_classes`` are overwritten from the loaded
            annotations.
        annotations_path: path to the annotation JSON file.
        processor: optional HuggingFace image processor, stored as
            ``self.processor``.
        split: one of "train", "valid", "test" (or "infer").
    """

    def __init__(self, config, annotations_path, processor=None, split="train"):
        self.config = config
        self.split = split
        self.exclude_labels = ["Unknown", "Dont know"]
        self.data_dir = config.DATA.data_dir
        # keep the processor that was passed in instead of discarding it
        self.processor = processor

        # view_type is optional; only MVFoul uses it as of now
        is_multiview = getattr(config.DATA, "view_type", None) == "multi"

        # evaluation-only splits are allowed to come without labels
        allow_missing_labels = split in ["test", "infer"]

        self.samples, name_to_id = load_annotations(
            annotations_path,
            exclude_labels=self.exclude_labels,
            multiview=is_multiview,
            input_type=config.DATA.data_modality,
            allow_missing_labels=allow_missing_labels,
        )

        max_samples = getattr(config.DATA, "max_samples", None)
        if max_samples:
            self.samples = self.samples[:max_samples]

        # invert to id -> name and propagate into the config so
        # downstream components (metrics, logging) can look it up.
        self.label_map = {v: k for k, v in name_to_id.items()}
        # order names by id so that classes[i] is always the name of label
        # id i, regardless of the insertion order of the loaded map.
        self.config.DATA.classes = [
            self.label_map[class_id] for class_id in sorted(self.label_map)
        ]
        self.config.DATA.num_classes = len(self.label_map)

        print(self.config.DATA.num_classes, "classes:", self.config.DATA.classes)
        print("Label Map : ", self.label_map)

        self.has_labels = len(self.samples) > 0 and "label" in self.samples[0]

    # -- Sampling / loss weights ------------------------------------------

    def _label_tensor(self):
        """Collect every sample's integer label into a 1-D long tensor."""
        # explicit dtype so an empty sample list still yields an integer
        # tensor that torch.bincount accepts.
        return torch.tensor(
            [item["label"] for item in self.samples], dtype=torch.long
        )

    def get_sample_weights(self):
        """Per-sample inverse-frequency weights for WeightedRandomSampler.

        Each sample gets ``1 / count(its class)``.

        Returns:
            float32 torch.Tensor of length len(self), one weight per sample.

        Raises:
            KeyError: if a sample has no "label" entry.
        """
        labels = self._label_tensor()
        class_weights = 1.0 / torch.bincount(labels).float()
        # vectorized gather of each sample's class weight
        return class_weights[labels].float()

    def get_class_weights(self, num_classes=None, normalize=True, sqrt=False):
        """Per-class inverse-frequency weights (e.g. for a weighted loss).

        Args:
            num_classes: if None, inferred as ``max(label) + 1`` (or the size
                of the label map when there are no samples).
            normalize: if True, weights are scaled so they sum to num_classes.
            sqrt: if True, use inverse square-root frequency instead of raw
                counts.

        Returns:
            torch.Tensor of shape (num_classes,).
        """
        labels = self._label_tensor()

        if num_classes is None:
            if labels.numel() > 0:
                num_classes = int(labels.max().item()) + 1
            else:
                num_classes = len(self.label_map)

        counts = torch.bincount(labels, minlength=num_classes).float()
        counts[counts == 0] = 1.0  # avoid division by zero for unseen classes

        weights = 1.0 / torch.sqrt(counts) if sqrt else 1.0 / counts

        if normalize:
            weights = weights / weights.sum() * num_classes

        return weights

    def __len__(self):
        return len(self.samples)

    def num_classes(self):
        """Number of classes in the label map."""
        return len(self.label_map)
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
# -------------------------------------------------------------
|
|
161
|
+
# video modality
|
|
162
|
+
# -------------------------------------------------------------
|
|
163
|
+
|
|
164
|
+
class VideoDataset(ClassificationDataset):
    """Frame-sampled video clips for classification.

    For MVFoul: supports single-view and multi-view modes. In multi-view
    training, two views are randomly sampled per clip; for other splits all
    available views are returned and stacked along a view dimension.

    Args:
        config: the loaded YAML configuration.
        annotations_path: path to the annotation JSON file.
        processor: HuggingFace image processor (used only for HuggingFace models).
        split: one of "train", "valid", "test".
    """

    def __init__(self, config, annotations_path, processor, split="train"):
        # pass by keyword: the base signature is
        # (config, annotations_path, processor, split), so a positional
        # third argument would be bound to `processor` and self.split would
        # silently stay "train" for every split.
        super().__init__(config, annotations_path, processor=processor, split=split)

        self.processor = processor
        self.view_type = getattr(config.DATA, "view_type", "single")
        self.num_frames = getattr(config.DATA, "num_frames", None)
        self.input_fps = getattr(config.DATA, "input_fps", None)
        self.transform = build_transform(config, mode=self.split)

    def _select_views(self, video_paths):
        """Choose which camera views to load for this sample.

        Args:
            video_paths: list of available view paths for the clip.

        Returns:
            a (possibly subsampled) list of paths: only the first view in
            single-view mode, up to two random views when training in
            multi-view mode, otherwise all views.
        """
        if self.view_type == "single":
            return [video_paths[0]]

        if self.split.lower() == "train" and self.view_type == "multi":
            return random.sample(video_paths, min(2, len(video_paths)))

        return video_paths

    def _load_and_sample_clip(self, path):
        """Read a clip, temporally sub-sample it, and apply transforms.

        ``.npy`` inputs are treated as pre-extracted frames: they are cast to
        float32, divided by 255, passed through ``self.transform`` and then
        normalized with the ImageNet mean/std. Other files are decoded with
        ``read_video`` and sub-sampled with ``process_frames``.

        Args:
            path: relative path (under ``config.DATA.data_dir``) to the clip.

        Returns:
            the processed clip, documented upstream as (T, H, W, C); the
            exact type after ``self.transform`` depends on ``build_transform``.
        """
        full_path = os.path.join(self.config.DATA.data_dir, path)

        if full_path.endswith(".npy"):
            frames = np.load(full_path).astype(np.float32) / 255.0
            if self.transform is not None:
                frames = self.transform(frames)
            mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
            std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
            return (frames - mean) / std

        clip = read_video(full_path)
        clip = process_frames(
            clip,
            self.num_frames,
            self.input_fps,
            self.config.DATA.target_fps,
            start_frame=self.config.DATA.start_frame,
            end_frame=self.config.DATA.end_frame,
        )

        if isinstance(clip, list):
            clip = np.stack(clip)  # (T, H, W, C)

        if self.transform is not None:
            clip = self.transform(clip)

        return clip  # (T, H, W, C)

    def __getitem__(self, idx):
        """Return one sample as a dict.

        Keys are "pixel_values" and "id", plus "labels" (a long tensor) when
        the annotation carries a label.

        Raises:
            ValueError: if the sample has no video paths.
        """
        item = self.samples[idx]
        label = item.get("label", None)
        if label is not None:
            label = torch.tensor(label, dtype=torch.long)
        video_paths = item["video_paths"]
        sample_id = item["id"]

        # --- Choose which clips to load ---
        if not video_paths:
            raise ValueError(f"No video paths found for item {idx}")

        selected_paths = self._select_views(video_paths)

        # --- Load and process frames for selected clips ---
        if self.config.MODEL.type == "huggingface":
            clip = self._load_and_sample_clip(selected_paths[0])
            # the processor is fed a list of per-frame (H, W, C) arrays
            inputs = self.processor(list(clip), return_tensors="pt")
            pixel_values = inputs["pixel_values"].float().squeeze(0)
            out = {"pixel_values": pixel_values, "id": sample_id}
        else:
            model_transform = None  # built lazily, once per item
            view_tensors = []
            for path in selected_paths:
                clip = self._load_and_sample_clip(path)

                if path.endswith(".npy"):
                    # frames_npy
                    view_tensors.append(torch.from_numpy(clip))  # (T, H, W, C)
                    continue

                # raw video path: apply torchvision model transforms
                if model_transform is None:
                    model_transform = get_transforms_model(
                        self.config.MODEL.pretrained_model
                    )
                clip = torch.from_numpy(clip).permute(0, 3, 1, 2)  # (T, C, H, W)
                view_tensors.append(model_transform(clip))  # (C, T, H, W)

            if selected_paths[0].endswith(".npy"):
                # frames_npy: single view, return (T, H, W, C) matching pixels_vs_positions
                out = {"pixel_values": view_tensors[0], "id": sample_id}
            else:
                # multi-view path: stack to (V, C, T, H, W)
                out = {"pixel_values": torch.stack(view_tensors, dim=0), "id": sample_id}

        if label is not None:
            out["labels"] = label
        return out
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
# -------------------------------------------------------------
|
|
303
|
+
# tracking modality
|
|
304
|
+
# -------------------------------------------------------------
|
|
305
|
+
|
|
306
|
+
class TrackingDataset(ClassificationDataset):
    """Graph-based classification dataset built from player tracking data.

    Each sample is a temporal sequence of per-frame graphs. Each graph has
    ``NUM_OBJECTS`` nodes (the ball and players), and its edges come from
    ``build_edge_index`` using ``config.MODEL.edge`` / ``k`` / ``r``.

    Supports optional preloading of all clips into memory and
    training-time augmentation (horizontal/vertical/team flip).

    Args:
        config: the loaded YAML configuration.
        annotations_path: path to the annotation JSON file.
        split: one of "train", "valid", "test".
    """

    def __init__(self, config, annotations_path, split="train"):
        # pass by keyword: the base signature is
        # (config, annotations_path, processor, split), so a positional
        # third argument would be bound to `processor` and self.split would
        # silently stay "train".
        super().__init__(config, annotations_path, processor=None, split=split)

        from opensportslib.datasets.utils.tracking import (
            FEATURE_DIM,
            NUM_OBJECTS,
            HorizontalFlip,
            TeamFlip,
            VerticalFlip,
            build_edge_index,
        )

        # keep references to the lazily imported helpers so the other
        # methods do not need to repeat the import.
        self._NUM_OBJECTS = NUM_OBJECTS
        self._FEATURE_DIM = FEATURE_DIM
        self._build_edge_index = build_edge_index

        self.num_frames = config.DATA.num_frames
        self.normalize = config.DATA.normalize
        self.edge_type = config.MODEL.edge
        self.k = config.MODEL.k
        self.r = config.MODEL.r
        self.preload_data = config.DATA.preload_data

        self.transforms = self._build_transforms(
            config, split, HorizontalFlip, VerticalFlip, TeamFlip
        )

        self.processed_samples = None
        if self.preload_data:
            self._preload_all_data()

    @staticmethod
    def _build_transforms(config, split, HorizontalFlip, VerticalFlip, TeamFlip):
        """Assemble the list of training-time augmentations.

        Args:
            config: the loaded YAML configuration.
            split: dataset split; augmentations are only applied during training.
            HorizontalFlip: augmentation class (passed to avoid re-importing).
            VerticalFlip: augmentation class.
            TeamFlip: augmentation class.

        Returns:
            list of callable augmentation transforms, each with probability
            0.5, in horizontal/vertical/team order (empty for non-training
            splits).
        """
        if split != "train":
            return []

        aug_config = config.DATA.augmentations

        # augmentation flags are optional in the config; default to off.
        candidates = (
            ("horizontal_flip", HorizontalFlip),
            ("vertical_flip", VerticalFlip),
            ("team_flip", TeamFlip),
        )
        return [
            transform_cls(probability=0.5)
            for flag, transform_cls in candidates
            if getattr(aug_config, flag, False)
        ]

    def _parse_clip(self, clip_path):
        """Parse one parquet clip into features, positions and edge indices.

        Edge indices are built on the un-augmented, un-normalized features
        (after ``compute_deltas``) so the graph topology is deterministic and
        independent of the transforms.

        Args:
            clip_path: relative path (under ``data_dir``) to the parquet file.

        Returns:
            ``(features, positions, edge_indices)``: ``features`` is the
            output of ``compute_deltas`` on a float32
            (num_frames, NUM_OBJECTS, FEATURE_DIM) array. ``positions`` and
            ``edge_indices`` are lists with one entry per frame.
        """
        from opensportslib.datasets.utils.tracking import compute_deltas, parse_frame

        df = self._load_tracking_clip(clip_path)
        num_frames = len(df)

        features = np.zeros(
            (num_frames, self._NUM_OBJECTS, self._FEATURE_DIM),
            dtype=np.float32,
        )
        positions = []
        for t, (_, row) in enumerate(df.iterrows()):
            frame_features, frame_positions = parse_frame(row)
            features[t] = frame_features
            positions.append(frame_positions)

        features = compute_deltas(features)

        edge_indices = [
            self._build_edge_index(
                features[t], positions[t], self.edge_type, self.k, self.r
            )
            for t in range(num_frames)
        ]
        return features, positions, edge_indices

    def _apply_transforms(self, features):
        """Apply the training augmentations, then optional normalization."""
        from opensportslib.datasets.utils.tracking import normalize_features

        for transform in self.transforms:
            features = transform(features)
        if self.normalize:
            features = normalize_features(features)
        return features

    @staticmethod
    def _to_graphs(features, edge_indices):
        """Build one PyG ``Data`` object per frame.

        The downstream collate function (tracking_collate) uses PyG
        Batch.from_data_list to merge these across the batch dimension.
        """
        from torch_geometric.data import Data

        return [
            Data(
                x=torch.tensor(frame_x, dtype=torch.float),
                edge_index=torch.tensor(frame_edges, dtype=torch.long),
            )
            for frame_x, frame_edges in zip(features, edge_indices)
        ]

    def _preload_all_data(self):
        """Parse every clip and cache features and edge indices in memory.

        Samples without any clip path are dropped. ``self.samples`` is then
        re-aligned with the cache so ``__len__``, indexing and sample
        weights stay consistent.
        """
        print(f"Preloading {len(self.samples)} {self.split} samples into memory...")

        self.processed_samples = []
        kept_samples = []

        for item in tqdm(self.samples, desc=f"Loading {self.split}"):
            clip_paths = item["video_paths"]
            if not clip_paths:
                continue

            features, positions, edge_indices = self._parse_clip(clip_paths[0])

            cached = {
                "features": features,
                "positions": positions,
                "edge_indices": edge_indices,
                "id": item["id"],
            }
            # unlabeled splits (test / infer) may carry no "label" key
            if "label" in item:
                cached["label"] = item["label"]

            self.processed_samples.append(cached)
            kept_samples.append(item)

        num_skipped = len(self.samples) - len(kept_samples)
        if num_skipped:
            print(f"Skipped {num_skipped} {self.split} samples without tracking paths")
            self.samples = kept_samples
            self.has_labels = len(self.samples) > 0 and "label" in self.samples[0]

        print(f"Loaded {len(self.processed_samples)} {self.split} samples")

    def _load_tracking_clip(self, path):
        """Read a single parquet tracking clip.

        Args:
            path: Relative path (under ``data_dir``) to the parquet
                file.

        Returns:
            ``pandas.DataFrame`` with one row per frame.
        """
        import pandas as pd

        full_path = os.path.join(self.data_dir, path)
        return pd.read_parquet(full_path)

    def __getitem__(self, idx):
        if self.preload_data:
            return self._getitem_preloaded(idx)
        return self._getitem_on_the_fly(idx)

    def _getitem_preloaded(self, idx):
        """Return a sample from the in-memory cache.

        The feature array is copied before augmentation and normalization
        so the cached data is never mutated.
        """
        sample = self.processed_samples[idx]
        features = self._apply_transforms(sample["features"].copy())
        graphs = self._to_graphs(features, sample["edge_indices"])

        out = {"graphs": graphs, "seq_len": len(graphs), "id": sample["id"]}
        if "label" in sample:
            out["label"] = sample["label"]
        return out

    def _getitem_on_the_fly(self, idx):
        """Load, parse, and process a single sample from disk.

        Raises:
            ValueError: if the sample has no tracking clip path.
        """
        item = self.samples[idx]

        clip_paths = item["video_paths"]
        if not clip_paths:
            raise ValueError(f"No tracking paths found for item {idx}")

        features, _, edge_indices = self._parse_clip(clip_paths[0])
        features = self._apply_transforms(features)
        graphs = self._to_graphs(features, edge_indices)

        out = {"graphs": graphs, "seq_len": len(graphs), "id": item["id"]}
        # the label is optional (test / infer splits); check the item, not
        # the label value itself.
        if "label" in item:
            out["label"] = item["label"]
        return out
|
|
582
|
+
|