opensportslib 0.0.1.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73) hide show
  1. opensportslib/__init__.py +18 -0
  2. opensportslib/apis/__init__.py +21 -0
  3. opensportslib/apis/classification.py +361 -0
  4. opensportslib/apis/localization.py +228 -0
  5. opensportslib/config/classification.yaml +104 -0
  6. opensportslib/config/classification_tracking.yaml +103 -0
  7. opensportslib/config/graph_tracking_classification/avgpool.yaml +79 -0
  8. opensportslib/config/graph_tracking_classification/gin.yaml +79 -0
  9. opensportslib/config/graph_tracking_classification/graphconv.yaml +79 -0
  10. opensportslib/config/graph_tracking_classification/graphsage.yaml +79 -0
  11. opensportslib/config/graph_tracking_classification/maxpool.yaml +79 -0
  12. opensportslib/config/graph_tracking_classification/noedges.yaml +79 -0
  13. opensportslib/config/localization.yaml +132 -0
  14. opensportslib/config/sngar_frames.yaml +98 -0
  15. opensportslib/core/__init__.py +0 -0
  16. opensportslib/core/loss/__init__.py +0 -0
  17. opensportslib/core/loss/builder.py +40 -0
  18. opensportslib/core/loss/calf.py +258 -0
  19. opensportslib/core/loss/ce.py +23 -0
  20. opensportslib/core/loss/combine.py +42 -0
  21. opensportslib/core/loss/nll.py +25 -0
  22. opensportslib/core/optimizer/__init__.py +0 -0
  23. opensportslib/core/optimizer/builder.py +38 -0
  24. opensportslib/core/sampler/weighted_sampler.py +104 -0
  25. opensportslib/core/scheduler/__init__.py +0 -0
  26. opensportslib/core/scheduler/builder.py +77 -0
  27. opensportslib/core/trainer/__init__.py +0 -0
  28. opensportslib/core/trainer/classification_trainer.py +1131 -0
  29. opensportslib/core/trainer/localization_trainer.py +1009 -0
  30. opensportslib/core/utils/checkpoint.py +238 -0
  31. opensportslib/core/utils/config.py +199 -0
  32. opensportslib/core/utils/data.py +85 -0
  33. opensportslib/core/utils/ddp.py +77 -0
  34. opensportslib/core/utils/default_args.py +110 -0
  35. opensportslib/core/utils/load_annotations.py +485 -0
  36. opensportslib/core/utils/seed.py +26 -0
  37. opensportslib/core/utils/video_processing.py +389 -0
  38. opensportslib/core/utils/wandb.py +110 -0
  39. opensportslib/datasets/__init__.py +0 -0
  40. opensportslib/datasets/builder.py +42 -0
  41. opensportslib/datasets/classification_dataset.py +582 -0
  42. opensportslib/datasets/localization_dataset.py +813 -0
  43. opensportslib/datasets/utils/__init__.py +15 -0
  44. opensportslib/datasets/utils/tracking.py +615 -0
  45. opensportslib/metrics/classification_metric.py +176 -0
  46. opensportslib/metrics/localization_metric.py +1482 -0
  47. opensportslib/models/__init__.py +0 -0
  48. opensportslib/models/backbones/builder.py +590 -0
  49. opensportslib/models/base/e2e.py +252 -0
  50. opensportslib/models/base/tracking.py +73 -0
  51. opensportslib/models/base/vars.py +29 -0
  52. opensportslib/models/base/video.py +130 -0
  53. opensportslib/models/base/video_mae.py +60 -0
  54. opensportslib/models/builder.py +43 -0
  55. opensportslib/models/heads/builder.py +266 -0
  56. opensportslib/models/neck/builder.py +210 -0
  57. opensportslib/models/utils/common.py +176 -0
  58. opensportslib/models/utils/impl/__init__.py +0 -0
  59. opensportslib/models/utils/impl/asformer.py +390 -0
  60. opensportslib/models/utils/impl/calf.py +74 -0
  61. opensportslib/models/utils/impl/gsm.py +112 -0
  62. opensportslib/models/utils/impl/gtad.py +347 -0
  63. opensportslib/models/utils/impl/tsm.py +123 -0
  64. opensportslib/models/utils/litebase.py +59 -0
  65. opensportslib/models/utils/modules.py +120 -0
  66. opensportslib/models/utils/shift.py +135 -0
  67. opensportslib/models/utils/utils.py +276 -0
  68. opensportslib-0.0.1.dev2.dist-info/METADATA +566 -0
  69. opensportslib-0.0.1.dev2.dist-info/RECORD +73 -0
  70. opensportslib-0.0.1.dev2.dist-info/WHEEL +5 -0
  71. opensportslib-0.0.1.dev2.dist-info/licenses/LICENSE +661 -0
  72. opensportslib-0.0.1.dev2.dist-info/licenses/LICENSE-COMMERCIAL +5 -0
  73. opensportslib-0.0.1.dev2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,389 @@
1
+ import torch
2
+ import numpy as np
3
+ import math
4
+
5
# Prefer decord for whole-clip decoding and fall back to PyAV when it cannot be
# imported. OSError is included because a package with a native library may fail
# to load it at import time (NOTE(review): confirm this is the failure mode seen
# with decord); unlike the previous bare `except:`, this no longer swallows
# KeyboardInterrupt/SystemExit.
try:
    import decord
    from decord import cpu

    USE_DECORD = True
except (ImportError, OSError):
    import av

    USE_DECORD = False
12
+
13
def read_video(video_path):
    """Decode every frame of a video into a list of ``(H, W, C)`` uint8 arrays.

    Uses decord when it was importable at module load (``USE_DECORD``),
    otherwise PyAV (frames converted with ``format="rgb24"``).

    Args:
        video_path: Path to the video file.

    Returns:
        list[np.ndarray]: One uint8 ``(H, W, C)`` array per decoded frame, in
        decode order (the per-frame format VideoMAE-style processors consume).
    """
    if USE_DECORD:
        reader = decord.VideoReader(video_path, ctx=cpu(0))
        batch = reader.get_batch(range(len(reader)))  # (T, H, W, C)
        # Iterating an ndarray along axis 0 yields one (H, W, C) frame each.
        return list(batch.asnumpy().astype(np.uint8))

    # A PyAV container holds an open file handle and decoder state; the context
    # manager releases them even if decoding raises part-way through.
    with av.open(video_path) as container:
        return [
            frame.to_ndarray(format="rgb24").astype(np.uint8)
            for frame in container.decode(video=0)
        ]
27
+
28
+
29
def resample_video_idx(num_frames, original_fps, new_fps):
    """Compute which source frames to take when converting ``original_fps`` to ``new_fps``.

    Args:
        num_frames: Number of indices to generate (only used when the fps ratio
            is not a whole number).
        original_fps: Frame rate of the source video.
        new_fps: Desired frame rate.

    Returns:
        slice | torch.Tensor: A ``slice(None, None, ratio)`` when
        ``original_fps / new_fps`` is a whole number; otherwise an int64 tensor
        holding ``floor(i * ratio)`` for ``i`` in ``range(num_frames)``.
    """
    ratio = float(original_fps) / new_fps
    if not ratio.is_integer():
        positions = torch.arange(num_frames, dtype=torch.float32) * ratio
        return positions.floor().to(torch.int64)
    return slice(None, None, int(ratio))
38
+
39
def process_frames(frames, target_num_frames, input_fps, target_fps, start_frame=0, end_frame=-1, uniform_sample=False):
    """Select exactly ``target_num_frames`` frames from a decoded clip.

    Two strategies are supported:

    * ``uniform_sample=True``: spread the samples over the whole clip. A clip
      shorter than ``target_num_frames`` is temporally upsampled (frames are
      repeated) via ``resample_video_idx``; a longer clip is subsampled at
      evenly spaced indices from the first to the last frame.
    * ``uniform_sample=False``: take ``frames[start_frame:end_frame]`` and step
      through it every ``input_fps / target_fps`` source frames, repeating the
      window's last frame if the requested span runs past its end.

    Args:
        frames: Sequence of per-frame arrays (e.g. ``(H, W, C)`` numpy arrays).
            It is never modified in place.
        target_num_frames (int): Number of frames to return.
        input_fps: Frame rate of ``frames``.
        target_fps: Desired sampling rate for windowed mode; a falsy value
            (``None``/``0``) means "same as ``input_fps``".
        start_frame (int): Window start (windowed mode only).
        end_frame (int): Window end, exclusive; ``-1`` means end of the clip.
        uniform_sample (bool): Selects the strategy described above.

    Returns:
        The selected frames. In uniform mode, when ``frames`` already has
        exactly ``target_num_frames`` items, ``frames`` itself is returned.

    Raises:
        ValueError: If ``frames`` is empty but frames are requested in uniform mode.
        AssertionError: If the window selected in windowed mode is empty.
    """
    target_fps = target_fps or input_fps
    num_frames = len(frames)
    end_frame = end_frame if end_frame != -1 else num_frames

    if uniform_sample:
        if num_frames < target_num_frames:
            if num_frames == 0:
                # Otherwise this fails below as a ZeroDivisionError on `duration`.
                raise ValueError("Cannot sample frames from an empty video")
            # Too short: choose an fps high enough that `target_num_frames`
            # samples span the clip, then clamp the indices into range.
            duration = num_frames / input_fps
            new_fps = np.ceil(target_num_frames / duration)
            idxs = resample_video_idx(target_num_frames, input_fps, new_fps)
            idxs = np.clip(idxs, 0, num_frames - 1)
            frames = [frames[i] for i in idxs]
        elif num_frames > target_num_frames:
            # Too long: evenly spaced indices from the first to the last frame.
            idxs = np.linspace(0, num_frames - 1, target_num_frames).astype(int)
            frames = [frames[i] for i in idxs]

        # Defensive padding with the last frame. A new list is built so the
        # caller's sequence is never mutated.
        if len(frames) < target_num_frames:
            pad = target_num_frames - len(frames)
            frames = list(frames) + [frames[-1]] * pad
    else:
        window = frames[start_frame:end_frame]
        assert len(window) > 0, "Empty temporal window"
        factor = input_fps / target_fps
        # Clamp so that a window shorter than the requested span repeats its last frame.
        last = len(window) - 1
        idxs = [min(int(i * factor), last) for i in range(target_num_frames)]
        frames = [window[i] for i in idxs]

    assert len(frames) == target_num_frames, f"Expected {target_num_frames} frames, got {len(frames)}"
    return frames
80
+
81
+
82
def get_stride(src_fps, sample_fps):
    """Get the integer frame stride that turns ``src_fps`` into roughly ``sample_fps``.

    Args:
        src_fps (int | float): The input fps of the video.
        sample_fps (int | float): The wanted output fps; values <= 0 mean
            "keep every frame".

    Returns:
        int: The stride to apply, always >= 1. It is ``int(src_fps / sample_fps)``
        clamped to 1. Without the clamp, asking for a rate above the source rate
        would give 0, which makes ``frames[::stride]`` invalid and makes
        ``read_fps`` / ``get_num_frames`` divide by zero.
    """
    if sample_fps <= 0:
        return 1
    return max(1, int(src_fps / sample_fps))
96
+
97
+
98
def read_fps(fps, sample_fps):
    """Return the frame rate actually obtained after integer striding.

    Because ``get_stride`` yields an integer stride, the effective output rate
    is ``fps / stride`` rather than ``sample_fps`` itself. For example,
    ``fps=25`` and ``sample_fps=2`` give a stride of 12, so the exact output
    rate is ``25 / 12 == 2.0833333333333335``.

    Args:
        fps (int): Source frame rate.
        sample_fps (int): Requested output frame rate.

    Returns:
        float: The effective output frame rate.
    """
    return fps / get_stride(fps, sample_fps)
113
+
114
def get_num_frames(num_frames, fps, sample_fps):
    """Return how many frames remain after striding a video down to ``sample_fps``.

    Args:
        num_frames (int): Frame count of the source video.
        fps (int): Source frame rate.
        sample_fps (int): Requested output frame rate.

    Returns:
        int: ``ceil(num_frames / stride)``, where ``stride`` comes from
        ``get_stride(fps, sample_fps)``.
    """
    stride = get_stride(fps, sample_fps)
    return math.ceil(num_frames / stride)
126
+
127
+
128
def distribute_elements(batch_size, len_devices):
    """Return the per-device batch sizes used to split a batch across devices.

    Every device receives the same share: ``batch_size // len_devices``, plus
    one when the batch does not divide evenly. With 8 samples on 4 GPUs the
    result is ``[2, 2, 2, 2]``; with 10 samples on 4 GPUs it is ``[3, 3, 3, 3]``.

    NOTE(review): in the uneven case the shares add up to more than
    ``batch_size`` (12 in the 10/4 example). This may be deliberate, so that
    every per-device pipeline has the same size — confirm with callers.

    Args:
        batch_size (int): Total number of samples.
        len_devices (int): Number of devices.

    Returns:
        list[int]: ``len_devices`` identical per-device sizes.
    """
    share, leftover = divmod(batch_size, len_devices)
    if leftover > 0:
        share += 1
    return [share] * len_devices
145
+
146
def _get_deferred_rgb_transform(IMAGENET_MEAN, IMAGENET_STD):
    """Build a TorchScript-compiled RGB augmentation and normalization pipeline.

    The pipeline applies four colour jitters in order (hue, saturation,
    brightness, contrast), each independently with probability 0.25.
    Jittering one property at a time is faster and lower-variance than a single
    combined ``ColorJitter``. Next comes a 5x5 Gaussian blur with probability
    0.25, and finally ``Normalize`` with the given statistics.

    Args:
        IMAGENET_MEAN: Per-channel mean passed to ``Normalize``.
        IMAGENET_STD: Per-channel std passed to ``Normalize``.

    Returns:
        The ``torch.jit.script``-compiled ``nn.Sequential`` of the transforms.
    """
    import torch.nn as nn
    import torchvision.transforms as T

    def _maybe(transform):
        # RandomApply must wrap an nn.ModuleList (not a plain list) to be scriptable.
        return T.RandomApply(nn.ModuleList([transform]), p=0.25)

    jitter_specs = [
        {"hue": 0.2},
        {"saturation": (0.7, 1.2)},
        {"brightness": (0.7, 1.2)},
        {"contrast": (0.7, 1.2)},
    ]
    pipeline = [_maybe(T.ColorJitter(**spec)) for spec in jitter_specs]
    pipeline.append(_maybe(T.GaussianBlur(5)))
    pipeline.append(T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD))
    return torch.jit.script(nn.Sequential(*pipeline))
174
+
175
def get_remaining(data_len, batch_size):
    """Return how many padding items make ``data_len`` a multiple of ``batch_size``.

    The pipeline needs every batch to be full so that all clips get processed.
    This is the number of extra items to append to reach the next full batch.

    Args:
        data_len (int): The length of the dataset.
        batch_size (int): The batch size.

    Returns:
        int: Number of elements to add, in ``[0, batch_size)`` for non-negative
        ``data_len`` and positive ``batch_size``.
    """
    full_batches = math.ceil(data_len / batch_size)
    return full_batches * batch_size - data_len
185
+
186
+
187
+ import random
188
+ import numpy as np
189
+ import torch
190
+ import torchvision.transforms as T
191
+ import torchvision.transforms.functional as F
192
+
193
+
194
class VideoTransform:
    """Training-time clip augmentation for ``(T, H, W, C)`` numpy videos.

    Each enabled augmentation samples its random parameters once per call and
    applies them to every frame, so the clip is transformed consistently over
    time. Which augmentations run, and with what parameters, comes from
    ``config.DATA.augmentations``. Attributes are read with ``getattr`` and
    defaults, so a missing flag means "off". When ``mode != "train"`` the input
    is returned unchanged.
    """

    def __init__(self, config, mode="train"):
        """
        Args:
            config: Config exposing ``DATA.frame_size`` as ``(height, width)``
                (the random-crop output size) and ``DATA.augmentations``.
            mode (str): ``"train"`` enables augmentation; any other value makes
                ``__call__`` a pass-through.
        """
        self.mode = mode
        self.config = config

        self.frame_height, self.frame_width = config.DATA.frame_size
        self.augmentations = config.DATA.augmentations

    def __call__(self, frames: np.ndarray):
        """Augment one clip.

        Args:
            frames: ``(T, H, W, C)`` array.

        Returns:
            np.ndarray: ``(T, H', W', C)`` clip. ``(H', W')`` is
            ``DATA.frame_size`` when random cropping ran, otherwise the input
            size. float32/float64 input comes back as float32 and any other dtype
            as uint8. Outside train mode the input object itself is returned.
        """
        if self.mode != "train":
            return frames

        is_float = frames.dtype in (np.float32, np.float64)

        # (T, H, W, C) -> (T, C, H, W): the layout torchvision tensor ops expect.
        frames_t = torch.from_numpy(frames).permute(0, 3, 1, 2)
        aug = self.augmentations

        # ---------------- Random crop ----------------
        if getattr(aug, "random_crop", False):
            scale = getattr(aug, "scale", (0.8, 1.0))
            ratio = getattr(aug, "ratio", (3 / 4, 4 / 3))

            # One crop box for the whole clip.
            i, j, h, w = T.RandomResizedCrop.get_params(
                frames_t[0], scale=scale, ratio=ratio
            )

            frames_t = torch.stack([
                F.resized_crop(
                    f, i, j, h, w,
                    size=(self.frame_height, self.frame_width),
                    interpolation=F.InterpolationMode.BILINEAR,
                )
                for f in frames_t
            ])

        # ---------------- Affine ----------------
        if getattr(aug, "random_affine", False):
            max_translate = getattr(aug, "translate", (0.1, 0.1))
            scale_range = getattr(aug, "affine_scale", (0.9, 1.0))

            # Translation is relative to the *current* frame size, which the
            # crop above may have changed to DATA.frame_size.
            H, W = frames_t.shape[-2:]
            tx = int(random.uniform(-max_translate[0], max_translate[0]) * W)
            ty = int(random.uniform(-max_translate[1], max_translate[1]) * H)
            scale_factor = random.uniform(scale_range[0], scale_range[1])

            frames_t = torch.stack([
                F.affine(
                    f,
                    angle=0.0,
                    translate=[tx, ty],
                    scale=scale_factor,
                    shear=[0.0, 0.0],
                    interpolation=F.InterpolationMode.BILINEAR,
                )
                for f in frames_t
            ])

        # ---------------- Perspective ----------------
        if getattr(aug, "random_perspective", False):
            distortion_scale = getattr(aug, "distortion_scale", 0.3)
            p = getattr(aug, "perspective_prob", 0.5)

            if random.random() < p:
                # Corner points must lie within the current (possibly cropped) frame.
                H, W = frames_t.shape[-2:]
                startpoints, endpoints = T.RandomPerspective.get_params(
                    width=W, height=H, distortion_scale=distortion_scale
                )

                frames_t = torch.stack([
                    F.perspective(
                        f,
                        startpoints=startpoints,
                        endpoints=endpoints,
                        interpolation=F.InterpolationMode.BILINEAR,
                    )
                    for f in frames_t
                ])

        # ---------------- Rotation ----------------
        if getattr(aug, "random_rotation", False):
            degrees = getattr(aug, "rotation_degrees", 5)
            angle = random.uniform(-degrees, degrees)

            frames_t = torch.stack([
                F.rotate(f, angle=angle,
                         interpolation=F.InterpolationMode.BILINEAR)
                for f in frames_t
            ])

        # ---------------- Color jitter ----------------
        if getattr(aug, "color_jitter", False):
            jitter_prob = getattr(aug, "jitter_prob", 1.0)  # defaults to always
            if random.random() < jitter_prob:
                brightness, contrast, saturation, hue = getattr(
                    aug, "jitter_params", (0.2, 0.2, 0.2, 0.05)
                )
                jitter = T.ColorJitter(brightness, contrast, saturation, hue)
                # One call on the whole (T, C, H, W) stack samples a single set
                # of factors for all frames; per-frame calls would make the
                # colours flicker from frame to frame.
                frames_t = jitter(frames_t)

        # ---------------- Flip ----------------
        if getattr(aug, "random_horizontal_flip", False):
            if random.random() < getattr(aug, "flip_prob", 0.5):
                frames_t = torch.flip(frames_t, dims=[3])  # width axis

        result = frames_t.permute(0, 2, 3, 1)  # back to (T, H, W, C)
        return result.numpy().astype(np.float32 if is_float else np.uint8)
304
+
305
+
306
def build_transform(config, mode="train"):
    """Create the clip augmentation callable.

    Args:
        config: Config forwarded to ``VideoTransform``.
        mode (str): ``"train"`` enables augmentation; other values produce a
            pass-through transform.

    Returns:
        VideoTransform: The configured transform.
    """
    transform = VideoTransform(config, mode=mode)
    return transform
308
+
309
def get_transforms_model(pre_model):
    """Return the torchvision Kinetics-400 preprocessing preset for a video backbone.

    Args:
        pre_model (str): Backbone name: ``"r3d_18"``, ``"s3d"``, ``"mc3_18"``,
            ``"r2plus1d_18"`` or ``"mvit_v2_s"``.

    Returns:
        The ``transforms()`` preset of the matching
        ``<Model>_Weights.KINETICS400_V1`` entry. Unrecognised names fall back
        to the R(2+1)D-18 preset.
    """
    from torchvision.models.video import (
        MC3_18_Weights,
        MViT_V2_S_Weights,
        R2Plus1D_18_Weights,
        R3D_18_Weights,
        S3D_Weights,
    )

    weights_by_name = {
        "r3d_18": R3D_18_Weights,
        "s3d": S3D_Weights,
        "mc3_18": MC3_18_Weights,
        "r2plus1d_18": R2Plus1D_18_Weights,
        "mvit_v2_s": MViT_V2_S_Weights,
    }
    # NOTE(review): "mvit_v1_b" has no entry, so it silently gets the
    # R(2+1)D-18 preset — confirm this is intended.
    weights = weights_by_name.get(pre_model, R2Plus1D_18_Weights)
    return weights.KINETICS400_V1.transforms()
329
+
330
+
331
+ # import torch
332
+ # import numpy as np
333
+ # import decord
334
+ # from decord import cpu
335
+ # import torchvision.transforms as T
336
+
337
+ # def read_video(video_path):
338
+ # """Read video frames into tensor (T, C, H, W)"""
339
+ # vr = decord.VideoReader(video_path, ctx=cpu(0))
340
+ # frames = vr.get_batch(range(len(vr))) # (T, H, W, C)
341
+ # frames = torch.from_numpy(frames.asnumpy()).permute(0, 3, 1, 2) # (T, C, H, W)
342
+ # return frames
343
+
344
+ # def resample_video_idx(num_frames, original_fps, new_fps):
345
+ # """Return frame indices to match new fps"""
346
+ # step = float(original_fps) / new_fps
347
+ # if step.is_integer():
348
+ # step = int(step)
349
+ # return slice(None, None, step)
350
+ # idxs = torch.arange(num_frames, dtype=torch.float32) * step
351
+ # idxs = idxs.floor().to(torch.int64)
352
+ # return idxs
353
+
354
+ # def process_frames(video_tensor, target_num_frames, input_fps, target_fps=None):
355
+ # """Uniform sampling, padding, FPS adjustment"""
356
+ # target_fps = target_fps or input_fps
357
+ # num_frames = video_tensor.size(0)
358
+ # duration = num_frames / input_fps
359
+
360
+ # # Too short → resample
361
+ # if num_frames < target_num_frames:
362
+ # new_fps = np.ceil(target_num_frames / duration)
363
+ # idxs = resample_video_idx(target_num_frames, input_fps, new_fps)
364
+ # idxs = np.clip(idxs, 0, num_frames - 1)
365
+ # video_tensor = video_tensor[idxs]
366
+
367
+ # # Too long → uniform sample
368
+ # elif num_frames > target_num_frames:
369
+ # idxs = np.linspace(0, num_frames - 1, target_num_frames)
370
+ # video_tensor = video_tensor[idxs.astype(int)]
371
+
372
+ # # Pad if still short
373
+ # if video_tensor.size(0) < target_num_frames:
374
+ # pad = target_num_frames - video_tensor.size(0)
375
+ # last_frame = video_tensor[-1:].repeat(pad, 1, 1, 1)
376
+ # video_tensor = torch.cat([video_tensor, last_frame], dim=0)
377
+
378
+ # return video_tensor
379
+
380
+ # def build_transform(frame_size=(224, 224), mean=None, std=None):
381
+ # """Return default frame transform"""
382
+ # mean = mean or [0.485, 0.456, 0.406]
383
+ # std = std or [0.229, 0.224, 0.225]
384
+ # return T.Compose([
385
+ # T.ConvertImageDtype(torch.float32),
386
+ # T.Resize(frame_size),
387
+ # T.Normalize(mean, std),
388
+ # ])
389
+
@@ -0,0 +1,110 @@
1
+ import wandb
2
+ import matplotlib.pyplot as plt
3
+ import numpy as np
4
+ import logging
5
+
6
def init_wandb(cfg, run_id, use_wandb=False):
    """Initialize a Weights & Biases run if enabled.

    The run is created in project ``cfg.TASK`` with id ``run_id``
    (``resume="allow"``). It is named ``"<backbone>_<modality>"``, using
    ``cfg.MODEL.backbone.type`` and ``cfg.DATA.data_modality``, or just the
    backbone type when the modality is missing or empty.

    Args:
        cfg: Config object exposing ``TASK``, ``MODEL.backbone.type`` and
            optionally ``DATA.data_modality``. ``vars(cfg)`` (or ``cfg`` itself
            if it has no ``__dict__``) is logged as the run config.
        run_id: W&B run id to create or resume.
        use_wandb (bool): When False, nothing is initialized.

    Returns:
        The ``wandb`` module when a run was started, otherwise ``None``
        (W&B disabled, or the package is not installed).
    """
    if not use_wandb:
        logging.info("W&B disabled.")
        return None

    try:
        import wandb
    except ImportError:
        logging.warning("wandb not installed. Install with `pip install wandb`.")
        return None

    modality = getattr(cfg.DATA, "data_modality", None)
    backbone = cfg.MODEL.backbone.type
    run_name = f"{backbone}_{modality}" if modality else f"{backbone}"

    wandb.init(
        project=cfg.TASK,
        name=run_name,
        id=run_id,
        resume="allow",
        config=vars(cfg) if hasattr(cfg, "__dict__") else cfg,
    )

    logging.info("Wandb initialised")
    return wandb
42
+
43
def log_table_wandb(name, rows, headers):
    """Log tabular data to the active W&B run as a ``wandb.Table``.

    Does nothing when no run has been initialized.

    Args:
        name (str): Key under which the table is logged.
        rows (list[list]): Table rows; each row is unpacked into ``add_data``.
        headers (list[str]): Column names.
    """
    if wandb.run is not None:
        table = wandb.Table(columns=headers)
        for values in rows:
            table.add_data(*values)
        wandb.log({name: table})
61
+
62
def log_attention_wandb(attention, split_name):
    """Log an attention matrix as a heatmap image to the active W&B run.

    Does nothing when no run is active, like ``log_table_wandb``.

    Args:
        attention: Tensor plotted with ``imshow`` after ``detach().cpu()``.
            Presumably ``(batch, views/time)``, judging by the axis labels —
            TODO confirm.
        split_name (str): Prefix for the logged key and the plot title.
    """
    if wandb.run is None:
        return

    attn = attention.detach().cpu().numpy()

    fig, ax = plt.subplots(figsize=(6, 3))
    try:
        ax.imshow(attn, aspect="auto", cmap="viridis")
        ax.set_title(f"{split_name} Attention Map")
        ax.set_xlabel("Views / Time")
        ax.set_ylabel("Batch")

        wandb.log({
            f"{split_name}/attention_map": wandb.Image(fig)
        })
    finally:
        # Always release the figure, even if plotting/logging fails; otherwise
        # pyplot keeps every figure alive.
        plt.close(fig)
77
+
78
+
79
def log_sample_videos_wandb(mvclips, preds, labels, split_name, max_samples=2, fps=5):
    """Log the first few multi-view clips of a batch as W&B videos.

    Each view of each sample is logged under
    ``"{split_name}/sample_{i}_view_{v}"``, with the prediction and ground truth
    in the caption. Does nothing when no W&B run is active.

    Args:
        mvclips: Tensor of shape ``(B, V, C, T, H, W)``.
        preds: Indexable predictions, one per sample.
        labels: Indexable ground-truth labels, one per sample.
        split_name (str): Key prefix.
        max_samples (int): Maximum number of samples to log.
        fps (int): Playback frame rate of the logged videos.
    """
    if wandb.run is None:
        return

    mvclips = mvclips.detach().cpu().numpy()

    for i in range(min(len(mvclips), max_samples)):
        views = mvclips[i]  # (V, C, T, H, W)
        caption = f"Pred: {preds[i]}, GT: {labels[i]}"

        # Log each view separately.
        for v, view in enumerate(views):
            # wandb.Video reads a 4-D numpy array as (time, channel, height, width).
            video = view.transpose(1, 0, 2, 3)  # (C, T, H, W) -> (T, C, H, W)

            if video.dtype != np.uint8:
                # Float clips in [0, 1] are rescaled to [0, 255]; everything is
                # then clipped so the uint8 cast cannot wrap around.
                if video.max() <= 1.0:
                    video = video * 255
                video = np.clip(video, 0, 255).astype(np.uint8)

            wandb.log({
                f"{split_name}/sample_{i}_view_{v}": wandb.Video(
                    video,
                    fps=fps,
                    caption=caption,
                )
            })
100
+
101
+
102
def log_confusion_matrix_wandb(y_true, y_pred, class_names, split_name):
    """Log a confusion-matrix plot to the active W&B run.

    Does nothing when no run is active, like ``log_table_wandb``, instead of
    letting ``wandb.log`` fail before ``wandb.init``.

    Args:
        y_true: Ground-truth class indices.
        y_pred: Predicted class indices (passed as ``preds``; no probabilities).
        class_names: Display names, one per class.
        split_name (str): Prefix for the logged key.
    """
    if wandb.run is None:
        return

    wandb.log({
        f"{split_name}/confusion_matrix": wandb.plot.confusion_matrix(
            probs=None,
            y_true=y_true,
            preds=y_pred,
            class_names=class_names,
        )
    })
File without changes
@@ -0,0 +1,42 @@
1
+ # opensportslib/datasets/builder.py
2
+ # from .spotting_dataset import SpottingDataset
3
+
4
def build_dataset(config, annotation_file=None, processor=None, split="train"):
    """Create the dataset for ``config.TASK``.

    Matching is case-insensitive and by substring. A task containing
    "classification" is delegated to ``classification_dataset.build``. A task
    containing "localization" gets a ``LocalizationDataset``. The task modules
    are imported lazily, so only the one in use is loaded.

    Args:
        config: Config exposing ``TASK``.
        annotation_file: Annotation path forwarded to the dataset.
        processor: Optional processor forwarded to the dataset.
        split (str): Dataset split name.

    Returns:
        The constructed dataset.

    Raises:
        ValueError: If the task matches neither family.
    """
    task_name = config.TASK.lower()

    if "classification" in task_name:
        from opensportslib.datasets import classification_dataset

        return classification_dataset.build(config, annotation_file, processor, split)

    if "localization" in task_name:
        from opensportslib.datasets.localization_dataset import LocalizationDataset

        return LocalizationDataset(config, annotation_file, processor, split=split)

    raise ValueError(f"No dataset found for task: {task_name}")
18
+
19
+
20
+ ##### --------- ####
21
+ # def build_dataset(config, annotation_file=None, processor=None, split="train"):
22
+ # """Return a dataset instance based on task and modality"""
23
+ # task = config.TASK.lower()
24
+
25
+ # if "classification" in task:
26
+ # modality = config.DATA.data_modality.lower()
27
+
28
+ # if modality == "tracking_parquet":
29
+ # from opensportslib.datasets.classification_dataset import TrackingDataset
30
+ # return TrackingDataset(config, annotation_file, split)
31
+ # elif modality == "video":
32
+ # from opensportslib.datasets.classification_dataset import VideoDataset
33
+ # return VideoDataset(config, annotation_file, processor, split)
34
+ # else:
35
+ # raise ValueError(f"Unknown data_modality: {modality}")
36
+
37
+ # elif "localization" in task:
38
+ # from opensportslib.datasets.localization_dataset import LocalizationDataset
39
+ # return LocalizationDataset(config, annotation_file, processor, split=split)
40
+
41
+ # else:
42
+ # raise ValueError(f"No dataset found for task: {task}")