opensportslib 0.0.1.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- opensportslib/__init__.py +18 -0
- opensportslib/apis/__init__.py +21 -0
- opensportslib/apis/classification.py +361 -0
- opensportslib/apis/localization.py +228 -0
- opensportslib/config/classification.yaml +104 -0
- opensportslib/config/classification_tracking.yaml +103 -0
- opensportslib/config/graph_tracking_classification/avgpool.yaml +79 -0
- opensportslib/config/graph_tracking_classification/gin.yaml +79 -0
- opensportslib/config/graph_tracking_classification/graphconv.yaml +79 -0
- opensportslib/config/graph_tracking_classification/graphsage.yaml +79 -0
- opensportslib/config/graph_tracking_classification/maxpool.yaml +79 -0
- opensportslib/config/graph_tracking_classification/noedges.yaml +79 -0
- opensportslib/config/localization.yaml +132 -0
- opensportslib/config/sngar_frames.yaml +98 -0
- opensportslib/core/__init__.py +0 -0
- opensportslib/core/loss/__init__.py +0 -0
- opensportslib/core/loss/builder.py +40 -0
- opensportslib/core/loss/calf.py +258 -0
- opensportslib/core/loss/ce.py +23 -0
- opensportslib/core/loss/combine.py +42 -0
- opensportslib/core/loss/nll.py +25 -0
- opensportslib/core/optimizer/__init__.py +0 -0
- opensportslib/core/optimizer/builder.py +38 -0
- opensportslib/core/sampler/weighted_sampler.py +104 -0
- opensportslib/core/scheduler/__init__.py +0 -0
- opensportslib/core/scheduler/builder.py +77 -0
- opensportslib/core/trainer/__init__.py +0 -0
- opensportslib/core/trainer/classification_trainer.py +1131 -0
- opensportslib/core/trainer/localization_trainer.py +1009 -0
- opensportslib/core/utils/checkpoint.py +238 -0
- opensportslib/core/utils/config.py +199 -0
- opensportslib/core/utils/data.py +85 -0
- opensportslib/core/utils/ddp.py +77 -0
- opensportslib/core/utils/default_args.py +110 -0
- opensportslib/core/utils/load_annotations.py +485 -0
- opensportslib/core/utils/seed.py +26 -0
- opensportslib/core/utils/video_processing.py +389 -0
- opensportslib/core/utils/wandb.py +110 -0
- opensportslib/datasets/__init__.py +0 -0
- opensportslib/datasets/builder.py +42 -0
- opensportslib/datasets/classification_dataset.py +582 -0
- opensportslib/datasets/localization_dataset.py +813 -0
- opensportslib/datasets/utils/__init__.py +15 -0
- opensportslib/datasets/utils/tracking.py +615 -0
- opensportslib/metrics/classification_metric.py +176 -0
- opensportslib/metrics/localization_metric.py +1482 -0
- opensportslib/models/__init__.py +0 -0
- opensportslib/models/backbones/builder.py +590 -0
- opensportslib/models/base/e2e.py +252 -0
- opensportslib/models/base/tracking.py +73 -0
- opensportslib/models/base/vars.py +29 -0
- opensportslib/models/base/video.py +130 -0
- opensportslib/models/base/video_mae.py +60 -0
- opensportslib/models/builder.py +43 -0
- opensportslib/models/heads/builder.py +266 -0
- opensportslib/models/neck/builder.py +210 -0
- opensportslib/models/utils/common.py +176 -0
- opensportslib/models/utils/impl/__init__.py +0 -0
- opensportslib/models/utils/impl/asformer.py +390 -0
- opensportslib/models/utils/impl/calf.py +74 -0
- opensportslib/models/utils/impl/gsm.py +112 -0
- opensportslib/models/utils/impl/gtad.py +347 -0
- opensportslib/models/utils/impl/tsm.py +123 -0
- opensportslib/models/utils/litebase.py +59 -0
- opensportslib/models/utils/modules.py +120 -0
- opensportslib/models/utils/shift.py +135 -0
- opensportslib/models/utils/utils.py +276 -0
- opensportslib-0.0.1.dev2.dist-info/METADATA +566 -0
- opensportslib-0.0.1.dev2.dist-info/RECORD +73 -0
- opensportslib-0.0.1.dev2.dist-info/WHEEL +5 -0
- opensportslib-0.0.1.dev2.dist-info/licenses/LICENSE +661 -0
- opensportslib-0.0.1.dev2.dist-info/licenses/LICENSE-COMMERCIAL +5 -0
- opensportslib-0.0.1.dev2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,389 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import numpy as np
|
|
3
|
+
import math
|
|
4
|
+
|
|
5
|
+
# Prefer decord for video decoding; fall back to PyAV when decord cannot be
# loaded. ``USE_DECORD`` records which backend is available.
try:
    import decord
    from decord import cpu
    USE_DECORD = True
except Exception:
    # ``Exception`` instead of a bare ``except`` keeps the fallback for any
    # decord load failure while letting KeyboardInterrupt / SystemExit
    # propagate.
    import av
    USE_DECORD = False
|
|
12
|
+
|
|
13
|
+
def read_video(video_path):
    """Decode every frame of a video into a list of uint8 arrays.

    decord is used when ``USE_DECORD`` is true, PyAV otherwise.

    Args:
        video_path: Path of the video file to decode.

    Returns:
        list: One ``np.uint8`` array per decoded frame (H x W x C; the PyAV
        path decodes as ``rgb24``).
    """
    if USE_DECORD:
        vr = decord.VideoReader(video_path, ctx=cpu(0))
        batch = vr.get_batch(range(len(vr)))  # (T, H, W, C)
        frames = batch.asnumpy().astype(np.uint8)  # ensure uint8 for VideoMAE
        return list(frames)  # list of T frames

    # The context manager closes the container even when decoding raises;
    # previously the container was never closed.
    with av.open(video_path) as container:
        return [
            frame.to_ndarray(format="rgb24").astype(np.uint8)
            for frame in container.decode(video=0)
        ]
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def resample_video_idx(num_frames, original_fps, new_fps):
    """Compute the frame selection that converts ``original_fps`` to ``new_fps``.

    Args:
        num_frames: Number of indices to generate (ignored when a slice is
            returned).
        original_fps: Frame rate of the source frames.
        new_fps: Frame rate to resample to.

    Returns:
        ``slice(None, None, step)`` when ``original_fps / new_fps`` is a whole
        number, otherwise an int64 tensor holding ``floor(i * ratio)`` for
        ``i`` in ``range(num_frames)``.
    """
    ratio = float(original_fps) / new_fps
    if not ratio.is_integer():
        positions = torch.arange(num_frames, dtype=torch.float32) * ratio
        return positions.floor().to(torch.int64)
    return slice(None, None, int(ratio))
|
|
38
|
+
|
|
39
|
+
def process_frames(frames, target_num_frames, input_fps, target_fps, start_frame=0, end_frame=-1, uniform_sample=False):
    """Select exactly ``target_num_frames`` frames from a decoded clip.

    Args:
        frames: list of np arrays (H, W, C).
        target_num_frames (int): Number of frames to return.
        input_fps: Frame rate of ``frames``.
        target_fps: Desired frame rate; a falsy value means ``input_fps``.
        start_frame (int): First index of the temporal window (windowed mode).
        end_frame (int): Exclusive end of the window; ``-1`` means
            ``len(frames)``.
        uniform_sample (bool): When true, sample across the whole clip and
            ignore the window; otherwise step through the window at
            ``input_fps / target_fps`` and repeat its last frame if it runs out.

    Returns:
        list of np arrays (H, W, C) ready for the processor.

    Raises:
        AssertionError: If the window is empty in windowed mode.
    """
    if not target_fps:
        target_fps = input_fps
    total = len(frames)
    if end_frame == -1:
        end_frame = total
    duration = total / input_fps

    if not uniform_sample:
        # Windowed mode: walk the window at the fps ratio, clamping to its end
        clip = frames[start_frame:end_frame]
        assert len(clip) > 0, "Empty temporal window"
        ratio = input_fps / target_fps
        picks = [int(k * ratio) for k in range(target_num_frames)]
        frames = [clip[min(pos, len(clip) - 1)] for pos in picks]
        assert len(frames) == target_num_frames, f"Expected {target_num_frames} frames, got {len(frames)}"
        return frames

    # Uniform sampling throughout the video
    if total < target_num_frames:
        # Too short: resample at a higher fps, clamping to the last frame
        upsampled_fps = np.ceil(target_num_frames / duration)
        picks = resample_video_idx(target_num_frames, input_fps, upsampled_fps)
        picks = np.clip(picks, 0, total - 1)
        frames = [frames[pos] for pos in picks]
    elif total > target_num_frames:
        # Too long: evenly spaced indices from first to last frame
        picks = np.linspace(0, total - 1, target_num_frames).astype(int)
        frames = [frames[pos] for pos in picks]

    # Pad with the last frame if still short
    if len(frames) < target_num_frames:
        missing = target_num_frames - len(frames)
        frames.extend([frames[-1]] * missing)

    return frames
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def get_stride(src_fps, sample_fps):
    """Get stride to apply based on the input and output fps.

    Args:
        src_fps (int): The input fps of the video.
        sample_fps (int): The output fps. A value <= 0 means "keep every
            frame".

    Returns:
        stride (int): The stride to apply, always >= 1.
    """
    if sample_fps <= 0:
        return 1
    # Clamp to 1: when sample_fps exceeds src_fps the truncated ratio is 0,
    # which would make read_fps / get_num_frames divide by zero.
    return max(1, int(src_fps / sample_fps))
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def read_fps(fps, sample_fps):
    """Compute the exact output fps obtained when striding the input video.

    Example: with an input fps of 25 and a wanted output fps of 2, the exact
    output fps is 2.0833333333333335.

    Args:
        fps (int): The input fps.
        sample_fps (int): The wanted output fps.

    Returns:
        est_out_fps (float): The input fps divided by the stride from
        ``get_stride``.
    """
    return fps / get_stride(fps, sample_fps)
|
|
113
|
+
|
|
114
|
+
def get_num_frames(num_frames, fps, sample_fps):
    """Compute how many frames remain after striding a video to a new fps.

    Args:
        num_frames (int): Number of frames in the base video.
        fps (int): The input fps.
        sample_fps (int): The output fps.

    Returns:
        (int): ``num_frames`` divided by the stride from ``get_stride``,
        rounded up.
    """
    stride = get_stride(fps, sample_fps)
    return math.ceil(num_frames / stride)
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def distribute_elements(batch_size, len_devices):
    """Return the per-device sample counts for a batch split across devices.

    Every device receives the same count: ``batch_size // len_devices``, plus
    one on every device whenever the division leaves a remainder.

    NOTE(review): with a remainder the counts sum to more than ``batch_size``
    (e.g. 9 over 4 devices gives [3, 3, 3, 3]); presumably the surplus is
    covered by padding elsewhere - confirm against the caller.

    Args:
        batch_size (int).
        len_devices (int).

    Returns:
        distribution (list): For example, with a batch size of 8 and 4 gpus
        the distribution is [2, 2, 2, 2], meaning each gpu processes 2 samples.
    """
    base, leftover = divmod(batch_size, len_devices)
    per_device = base + 1 if leftover > 0 else base
    return [per_device] * len_devices
|
|
145
|
+
|
|
146
|
+
def _get_deferred_rgb_transform(IMAGENET_MEAN, IMAGENET_STD):
    """Build the TorchScript-compiled RGB augmentation + normalisation module.

    The pipeline applies, each with probability 0.25 and in this order: hue,
    saturation, brightness and contrast jitter, then a Gaussian blur with
    kernel size 5; it always ends with ``Normalize``.

    Args:
        IMAGENET_MEAN: Per-channel mean passed to ``Normalize``.
        IMAGENET_STD: Per-channel std passed to ``Normalize``.

    Returns:
        The ``nn.Sequential`` pipeline compiled with ``torch.jit.script``.
    """
    import torch.nn as nn
    import torchvision.transforms as T

    def sometimes(transform):
        # Wrap one transform so that it fires with probability 0.25
        return T.RandomApply(nn.ModuleList([transform]), p=0.25)

    # Original author's note: jittering each property separately is faster
    # (low variance) than one combined ColorJitter (slower, high variance).
    pipeline = nn.Sequential(
        sometimes(T.ColorJitter(hue=0.2)),
        sometimes(T.ColorJitter(saturation=(0.7, 1.2))),
        sometimes(T.ColorJitter(brightness=(0.7, 1.2))),
        sometimes(T.ColorJitter(contrast=(0.7, 1.2))),
        sometimes(T.GaussianBlur(5)),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    )
    return torch.jit.script(pipeline)
|
|
174
|
+
|
|
175
|
+
def get_remaining(data_len, batch_size):
    """Return the padding needed so that every batch holds ``batch_size`` items.

    Equal-sized batches are required by the pipeline to make sure that all
    clips are processed.

    Args:
        data_len (int): The length of the dataset.
        batch_size (int).

    Returns:
        (int): The number of elements to add.
    """
    num_batches = math.ceil(data_len / batch_size)
    padded_len = num_batches * batch_size
    return padded_len - data_len
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
import random
|
|
188
|
+
import numpy as np
|
|
189
|
+
import torch
|
|
190
|
+
import torchvision.transforms as T
|
|
191
|
+
import torchvision.transforms.functional as F
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
class VideoTransform:
    """Clip-level augmentation for a (T, H, W, C) frame array.

    Crop, affine, perspective, rotation and flip parameters are sampled once
    per call and shared by every frame of the clip. Each augmentation is
    enabled by a boolean attribute on ``config.DATA.augmentations``; missing
    attributes count as disabled. In any mode other than ``"train"`` the input
    is returned untouched.
    """

    def __init__(self, config, mode="train"):
        """Store the mode and read the frame size and augmentation settings.

        Args:
            config: Config object providing ``DATA.frame_size`` as
                ``(height, width)`` and ``DATA.augmentations``.
            mode (str): Augmentations are only applied when this is "train".
        """
        self.mode = mode
        self.config = config

        self.frame_height, self.frame_width = config.DATA.frame_size
        self.augmentations = config.DATA.augmentations

    def __call__(self, frames: np.ndarray):
        """Augment a clip.

        Args:
            frames: np.ndarray (T, H, W, C).

        Returns:
            ``frames`` itself outside train mode; otherwise a new
            (T, H, W, C) array, float32 if the input was float32/float64 and
            uint8 otherwise.
        """
        if self.mode != "train":
            return frames

        is_float = frames.dtype in (np.float32, np.float64)

        # Work channels-first: (T, C, H, W)
        frames_t = torch.from_numpy(frames).permute(0, 3, 1, 2)

        aug = self.augmentations

        # ---------------- Random crop ----------------
        if getattr(aug, "random_crop", False):
            scale = getattr(aug, "scale", (0.8, 1.0))
            ratio = getattr(aug, "ratio", (3/4, 4/3))

            # One crop box, sampled from the first frame, for the whole clip
            i, j, h, w = T.RandomResizedCrop.get_params(
                frames_t[0], scale=scale, ratio=ratio
            )

            frames_t = torch.stack([
                F.resized_crop(
                    f, i, j, h, w,
                    size=(self.frame_height, self.frame_width),
                    interpolation=F.InterpolationMode.BILINEAR,
                )
                for f in frames_t
            ])

        # ---------------- Affine ----------------
        if getattr(aug, "random_affine", False):
            max_translate = getattr(aug, "translate", (0.1, 0.1))
            scale_range = getattr(aug, "affine_scale", (0.9, 1.0))

            # Use the current size: the crop above may have resized the clip,
            # so the input's H/W can be stale here.
            cur_h, cur_w = frames_t.shape[-2:]
            tx = int(random.uniform(-max_translate[0], max_translate[0]) * cur_w)
            ty = int(random.uniform(-max_translate[1], max_translate[1]) * cur_h)
            scale_factor = random.uniform(scale_range[0], scale_range[1])

            frames_t = torch.stack([
                F.affine(
                    f,
                    angle=0.0,
                    translate=[tx, ty],
                    scale=scale_factor,
                    shear=[0.0, 0.0],
                    interpolation=F.InterpolationMode.BILINEAR,
                )
                for f in frames_t
            ])

        # ---------------- Perspective ----------------
        if getattr(aug, "random_perspective", False):
            distortion_scale = getattr(aug, "distortion_scale", 0.3)
            p = getattr(aug, "perspective_prob", 0.5)

            if random.random() < p:
                # Corner points must match the current (possibly cropped) size
                cur_h, cur_w = frames_t.shape[-2:]
                startpoints, endpoints = T.RandomPerspective.get_params(
                    width=cur_w, height=cur_h, distortion_scale=distortion_scale
                )

                frames_t = torch.stack([
                    F.perspective(
                        f,
                        startpoints=startpoints,
                        endpoints=endpoints,
                        interpolation=F.InterpolationMode.BILINEAR,
                    )
                    for f in frames_t
                ])

        # ---------------- Rotation ----------------
        if getattr(aug, "random_rotation", False):
            degrees = getattr(aug, "rotation_degrees", 5)
            angle = random.uniform(-degrees, degrees)

            frames_t = torch.stack([
                F.rotate(f, angle=angle,
                         interpolation=F.InterpolationMode.BILINEAR)
                for f in frames_t
            ])

        # ---------------- Color jitter ----------------
        if getattr(aug, "color_jitter", False):
            jitter_prob = getattr(aug, "jitter_prob", 1.0)  # defaults to 1.0
            brightness, contrast, saturation, hue = getattr(
                aug, "jitter_params", (0.2, 0.2, 0.2, 0.05)
            )

            # Honour jitter_prob (it used to be read but ignored). The RNG
            # draw is skipped at the default 1.0 so default configs consume
            # the same random stream as before.
            if jitter_prob >= 1.0 or random.random() < jitter_prob:
                jitter = T.ColorJitter(brightness, contrast, saturation, hue)
                # NOTE(review): the jitter is invoked once per frame, so the
                # frames are presumably jittered independently - confirm that
                # this is intended.
                frames_t = torch.stack([jitter(f) for f in frames_t])

        # ---------------- Flip ----------------
        if getattr(aug, "random_horizontal_flip", False):
            if random.random() < getattr(aug, "flip_prob", 0.5):
                frames_t = torch.flip(frames_t, dims=[3])  # width axis

        # Back to channels-last
        result = frames_t.permute(0, 2, 3, 1)
        return result.numpy().astype(np.float32 if is_float else np.uint8)
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
def build_transform(config, mode="train"):
    """Create the ``VideoTransform`` for ``config`` in the given ``mode``."""
    transform = VideoTransform(config, mode)
    return transform
|
|
308
|
+
|
|
309
|
+
def get_transforms_model(pre_model):
    """Return the preprocessing transform shipped with a torchvision video model.

    The transform comes from the model's ``KINETICS400_V1`` weights.

    Args:
        pre_model (str): One of "r3d_18", "s3d", "mc3_18", "r2plus1d_18" or
            "mvit_v2_s". Any other value falls back to the R(2+1)D-18
            transform.

    Returns:
        The object returned by ``<Weights>.KINETICS400_V1.transforms()``.
    """
    # Imported lazily so torchvision is only needed when this is called.
    # The duplicated and unused names of the previous version were dropped.
    from torchvision.models.video import (
        MC3_18_Weights,
        MViT_V2_S_Weights,
        R2Plus1D_18_Weights,
        R3D_18_Weights,
        S3D_Weights,
    )

    weights_by_model = {
        "r3d_18": R3D_18_Weights,
        "s3d": S3D_Weights,
        "mc3_18": MC3_18_Weights,
        "r2plus1d_18": R2Plus1D_18_Weights,
        "mvit_v2_s": MViT_V2_S_Weights,
    }
    # Unknown names keep the historical default (R(2+1)D-18)
    weights = weights_by_model.get(pre_model, R2Plus1D_18_Weights)
    return weights.KINETICS400_V1.transforms()
|
|
329
|
+
|
|
330
|
+
|
|
331
|
+
# import torch
|
|
332
|
+
# import numpy as np
|
|
333
|
+
# import decord
|
|
334
|
+
# from decord import cpu
|
|
335
|
+
# import torchvision.transforms as T
|
|
336
|
+
|
|
337
|
+
# def read_video(video_path):
|
|
338
|
+
# """Read video frames into tensor (T, C, H, W)"""
|
|
339
|
+
# vr = decord.VideoReader(video_path, ctx=cpu(0))
|
|
340
|
+
# frames = vr.get_batch(range(len(vr))) # (T, H, W, C)
|
|
341
|
+
# frames = torch.from_numpy(frames.asnumpy()).permute(0, 3, 1, 2) # (T, C, H, W)
|
|
342
|
+
# return frames
|
|
343
|
+
|
|
344
|
+
# def resample_video_idx(num_frames, original_fps, new_fps):
|
|
345
|
+
# """Return frame indices to match new fps"""
|
|
346
|
+
# step = float(original_fps) / new_fps
|
|
347
|
+
# if step.is_integer():
|
|
348
|
+
# step = int(step)
|
|
349
|
+
# return slice(None, None, step)
|
|
350
|
+
# idxs = torch.arange(num_frames, dtype=torch.float32) * step
|
|
351
|
+
# idxs = idxs.floor().to(torch.int64)
|
|
352
|
+
# return idxs
|
|
353
|
+
|
|
354
|
+
# def process_frames(video_tensor, target_num_frames, input_fps, target_fps=None):
|
|
355
|
+
# """Uniform sampling, padding, FPS adjustment"""
|
|
356
|
+
# target_fps = target_fps or input_fps
|
|
357
|
+
# num_frames = video_tensor.size(0)
|
|
358
|
+
# duration = num_frames / input_fps
|
|
359
|
+
|
|
360
|
+
# # Too short → resample
|
|
361
|
+
# if num_frames < target_num_frames:
|
|
362
|
+
# new_fps = np.ceil(target_num_frames / duration)
|
|
363
|
+
# idxs = resample_video_idx(target_num_frames, input_fps, new_fps)
|
|
364
|
+
# idxs = np.clip(idxs, 0, num_frames - 1)
|
|
365
|
+
# video_tensor = video_tensor[idxs]
|
|
366
|
+
|
|
367
|
+
# # Too long → uniform sample
|
|
368
|
+
# elif num_frames > target_num_frames:
|
|
369
|
+
# idxs = np.linspace(0, num_frames - 1, target_num_frames)
|
|
370
|
+
# video_tensor = video_tensor[idxs.astype(int)]
|
|
371
|
+
|
|
372
|
+
# # Pad if still short
|
|
373
|
+
# if video_tensor.size(0) < target_num_frames:
|
|
374
|
+
# pad = target_num_frames - video_tensor.size(0)
|
|
375
|
+
# last_frame = video_tensor[-1:].repeat(pad, 1, 1, 1)
|
|
376
|
+
# video_tensor = torch.cat([video_tensor, last_frame], dim=0)
|
|
377
|
+
|
|
378
|
+
# return video_tensor
|
|
379
|
+
|
|
380
|
+
# def build_transform(frame_size=(224, 224), mean=None, std=None):
|
|
381
|
+
# """Return default frame transform"""
|
|
382
|
+
# mean = mean or [0.485, 0.456, 0.406]
|
|
383
|
+
# std = std or [0.229, 0.224, 0.225]
|
|
384
|
+
# return T.Compose([
|
|
385
|
+
# T.ConvertImageDtype(torch.float32),
|
|
386
|
+
# T.Resize(frame_size),
|
|
387
|
+
# T.Normalize(mean, std),
|
|
388
|
+
# ])
|
|
389
|
+
|
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
import wandb
|
|
2
|
+
import matplotlib.pyplot as plt
|
|
3
|
+
import numpy as np
|
|
4
|
+
import logging
|
|
5
|
+
|
|
6
|
+
def init_wandb(cfg, run_id, use_wandb=False):
    """Start (or resume) a Weights & Biases run when logging is enabled.

    Args:
        cfg: Config object. Reads ``cfg.TASK`` (used as the project name),
            ``cfg.MODEL.backbone.type`` and, when set,
            ``cfg.DATA.data_modality`` (joined into the run name).
        run_id: Run id passed to ``wandb.init`` with ``resume="allow"``.
        use_wandb (bool): When false, nothing is initialised.

    Returns:
        The ``wandb`` module after ``wandb.init``; ``None`` when W&B is
        disabled or cannot be imported.
    """
    if not use_wandb:
        logging.info("W&B disabled.")
        return None

    try:
        import wandb
    except ImportError:
        logging.warning("wandb not installed. Install with `pip install wandb`.")
        return None

    # Run name: "<backbone>" or "<backbone>_<modality>" when a modality is set
    modality = getattr(cfg.DATA, "data_modality", None)
    run_name = f"{cfg.MODEL.backbone.type}"
    if modality:
        run_name = f"{run_name}_{modality}"

    run_config = vars(cfg) if hasattr(cfg, "__dict__") else cfg
    wandb.init(
        project=cfg.TASK,
        name=run_name,
        id=run_id,
        resume="allow",
        config=run_config,
    )

    logging.info("Wandb initialised")
    return wandb
|
|
42
|
+
|
|
43
|
+
def log_table_wandb(name, rows, headers):
    """Log a table to Weights & Biases; does nothing when no run is active.

    Args:
        name (str): Name of the table in wandb.
        rows (list[list]): Table rows.
        headers (list[str]): Column headers.
    """
    if wandb.run is not None:
        tbl = wandb.Table(columns=headers)
        for entry in rows:
            tbl.add_data(*entry)
        wandb.log({name: tbl})
|
|
61
|
+
|
|
62
|
+
def log_attention_wandb(attention, split_name):
    """Render an attention tensor as a heat map and log it to W&B.

    Args:
        attention: Tensor to plot; it is detached, moved to CPU and drawn
            with ``imshow`` (presumably 2-D, batch x views/time - confirm
            against the caller).
        split_name (str): Prefix for the plot title and the logged key
            ``"<split_name>/attention_map"``.
    """
    attn = attention.detach().cpu().numpy()

    fig, ax = plt.subplots(figsize=(6, 3))
    try:
        ax.imshow(attn, aspect="auto", cmap="viridis")
        ax.set_title(f"{split_name} Attention Map")
        ax.set_xlabel("Views / Time")
        ax.set_ylabel("Batch")

        wandb.log({
            f"{split_name}/attention_map": wandb.Image(fig)
        })
    finally:
        # Always release the figure, even when drawing or logging raises
        plt.close(fig)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def log_sample_videos_wandb(mvclips, preds, labels, split_name, max_samples=2, fps=5):
    """Log the first few multi-view clips of a batch to W&B as videos.

    Args:
        mvclips: Tensor of clips, (B, V, C, T, H, W).
        preds: Per-sample predictions, indexed by sample for the caption.
        labels: Per-sample ground truth, indexed by sample for the caption.
        split_name (str): Prefix of the logged keys
            ``"<split_name>/sample_<i>_view_<v>"``.
        max_samples (int): Maximum number of samples to log.
        fps (int): Frame rate given to ``wandb.Video``.
    """
    # mvclips: (B, V, C, T, H, W)
    clips = mvclips.detach().cpu().numpy()

    for i in range(min(len(clips), max_samples)):
        views = clips[i]  # (V, C, T, H, W)

        # Log each view separately
        for v in range(views.shape[0]):
            # wandb.Video takes numpy input as (time, channel, height, width),
            # so only the first two axes are swapped (the previous
            # channels-last layout was misread by wandb).
            video = views[v].transpose(1, 0, 2, 3)  # (T, C, H, W)
            if video.max() <= 1.0:
                video = video * 255
            # Clip before casting so out-of-range values saturate, not wrap
            video = np.clip(video, 0, 255).astype(np.uint8)

            wandb.log({
                f"{split_name}/sample_{i}_view_{v}": wandb.Video(
                    video,
                    fps=fps,
                    caption=f"Pred: {preds[i]}, GT: {labels[i]}"
                )
            })
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def log_confusion_matrix_wandb(y_true, y_pred, class_names, split_name):
    """Log a confusion-matrix chart to W&B under ``<split_name>/confusion_matrix``.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels.
        class_names: Class names forwarded to the chart.
        split_name (str): Prefix of the logged key.
    """
    chart = wandb.plot.confusion_matrix(
        probs=None,
        y_true=y_true,
        preds=y_pred,
        class_names=class_names,
    )
    wandb.log({f"{split_name}/confusion_matrix": chart})
|
|
File without changes
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
# opensportslib/datasets/builder.py
|
|
2
|
+
# from .spotting_dataset import SpottingDataset
|
|
3
|
+
|
|
4
|
+
def build_dataset(config, annotation_file=None, processor=None, split="train"):
    """Return the dataset instance matching ``config.TASK``.

    The lower-cased task name is matched by substring: "classification" is
    checked first, then "localization". Dataset modules are imported lazily.

    Args:
        config: Config object exposing ``TASK``.
        annotation_file: Forwarded to the dataset.
        processor: Forwarded to the dataset.
        split (str): Dataset split, forwarded to the dataset.

    Raises:
        ValueError: If the task name matches no known dataset.
    """
    task_name = config.TASK.lower()

    if "classification" in task_name:
        from opensportslib.datasets import classification_dataset
        return classification_dataset.build(config, annotation_file, processor, split)

    if "localization" in task_name:
        from opensportslib.datasets.localization_dataset import LocalizationDataset
        return LocalizationDataset(config, annotation_file, processor, split=split)

    raise ValueError(f"No dataset found for task: {task_name}")
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
##### --------- ####
|
|
21
|
+
# def build_dataset(config, annotation_file=None, processor=None, split="train"):
|
|
22
|
+
# """Return a dataset instance based on task and modality"""
|
|
23
|
+
# task = config.TASK.lower()
|
|
24
|
+
|
|
25
|
+
# if "classification" in task:
|
|
26
|
+
# modality = config.DATA.data_modality.lower()
|
|
27
|
+
|
|
28
|
+
# if modality == "tracking_parquet":
|
|
29
|
+
# from opensportslib.datasets.classification_dataset import TrackingDataset
|
|
30
|
+
# return TrackingDataset(config, annotation_file, split)
|
|
31
|
+
# elif modality == "video":
|
|
32
|
+
# from opensportslib.datasets.classification_dataset import VideoDataset
|
|
33
|
+
# return VideoDataset(config, annotation_file, processor, split)
|
|
34
|
+
# else:
|
|
35
|
+
# raise ValueError(f"Unknown data_modality: {modality}")
|
|
36
|
+
|
|
37
|
+
# elif "localization" in task:
|
|
38
|
+
# from opensportslib.datasets.localization_dataset import LocalizationDataset
|
|
39
|
+
# return LocalizationDataset(config, annotation_file, processor, split=split)
|
|
40
|
+
|
|
41
|
+
# else:
|
|
42
|
+
# raise ValueError(f"No dataset found for task: {task}")
|