opensportslib 0.0.1.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73) hide show
  1. opensportslib/__init__.py +18 -0
  2. opensportslib/apis/__init__.py +21 -0
  3. opensportslib/apis/classification.py +361 -0
  4. opensportslib/apis/localization.py +228 -0
  5. opensportslib/config/classification.yaml +104 -0
  6. opensportslib/config/classification_tracking.yaml +103 -0
  7. opensportslib/config/graph_tracking_classification/avgpool.yaml +79 -0
  8. opensportslib/config/graph_tracking_classification/gin.yaml +79 -0
  9. opensportslib/config/graph_tracking_classification/graphconv.yaml +79 -0
  10. opensportslib/config/graph_tracking_classification/graphsage.yaml +79 -0
  11. opensportslib/config/graph_tracking_classification/maxpool.yaml +79 -0
  12. opensportslib/config/graph_tracking_classification/noedges.yaml +79 -0
  13. opensportslib/config/localization.yaml +132 -0
  14. opensportslib/config/sngar_frames.yaml +98 -0
  15. opensportslib/core/__init__.py +0 -0
  16. opensportslib/core/loss/__init__.py +0 -0
  17. opensportslib/core/loss/builder.py +40 -0
  18. opensportslib/core/loss/calf.py +258 -0
  19. opensportslib/core/loss/ce.py +23 -0
  20. opensportslib/core/loss/combine.py +42 -0
  21. opensportslib/core/loss/nll.py +25 -0
  22. opensportslib/core/optimizer/__init__.py +0 -0
  23. opensportslib/core/optimizer/builder.py +38 -0
  24. opensportslib/core/sampler/weighted_sampler.py +104 -0
  25. opensportslib/core/scheduler/__init__.py +0 -0
  26. opensportslib/core/scheduler/builder.py +77 -0
  27. opensportslib/core/trainer/__init__.py +0 -0
  28. opensportslib/core/trainer/classification_trainer.py +1131 -0
  29. opensportslib/core/trainer/localization_trainer.py +1009 -0
  30. opensportslib/core/utils/checkpoint.py +238 -0
  31. opensportslib/core/utils/config.py +199 -0
  32. opensportslib/core/utils/data.py +85 -0
  33. opensportslib/core/utils/ddp.py +77 -0
  34. opensportslib/core/utils/default_args.py +110 -0
  35. opensportslib/core/utils/load_annotations.py +485 -0
  36. opensportslib/core/utils/seed.py +26 -0
  37. opensportslib/core/utils/video_processing.py +389 -0
  38. opensportslib/core/utils/wandb.py +110 -0
  39. opensportslib/datasets/__init__.py +0 -0
  40. opensportslib/datasets/builder.py +42 -0
  41. opensportslib/datasets/classification_dataset.py +582 -0
  42. opensportslib/datasets/localization_dataset.py +813 -0
  43. opensportslib/datasets/utils/__init__.py +15 -0
  44. opensportslib/datasets/utils/tracking.py +615 -0
  45. opensportslib/metrics/classification_metric.py +176 -0
  46. opensportslib/metrics/localization_metric.py +1482 -0
  47. opensportslib/models/__init__.py +0 -0
  48. opensportslib/models/backbones/builder.py +590 -0
  49. opensportslib/models/base/e2e.py +252 -0
  50. opensportslib/models/base/tracking.py +73 -0
  51. opensportslib/models/base/vars.py +29 -0
  52. opensportslib/models/base/video.py +130 -0
  53. opensportslib/models/base/video_mae.py +60 -0
  54. opensportslib/models/builder.py +43 -0
  55. opensportslib/models/heads/builder.py +266 -0
  56. opensportslib/models/neck/builder.py +210 -0
  57. opensportslib/models/utils/common.py +176 -0
  58. opensportslib/models/utils/impl/__init__.py +0 -0
  59. opensportslib/models/utils/impl/asformer.py +390 -0
  60. opensportslib/models/utils/impl/calf.py +74 -0
  61. opensportslib/models/utils/impl/gsm.py +112 -0
  62. opensportslib/models/utils/impl/gtad.py +347 -0
  63. opensportslib/models/utils/impl/tsm.py +123 -0
  64. opensportslib/models/utils/litebase.py +59 -0
  65. opensportslib/models/utils/modules.py +120 -0
  66. opensportslib/models/utils/shift.py +135 -0
  67. opensportslib/models/utils/utils.py +276 -0
  68. opensportslib-0.0.1.dev2.dist-info/METADATA +566 -0
  69. opensportslib-0.0.1.dev2.dist-info/RECORD +73 -0
  70. opensportslib-0.0.1.dev2.dist-info/WHEEL +5 -0
  71. opensportslib-0.0.1.dev2.dist-info/licenses/LICENSE +661 -0
  72. opensportslib-0.0.1.dev2.dist-info/licenses/LICENSE-COMMERCIAL +5 -0
  73. opensportslib-0.0.1.dev2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,485 @@
1
+ import json
2
+ import os
3
+ import tqdm
4
+ import logging
5
+ import cv2
6
+ import math
7
+ import torch
8
+ from opensportslib.core.utils.video_processing import get_stride, read_fps, get_num_frames
9
+ from opensportslib.core.utils.config import load_json
10
+ from collections import defaultdict
11
+
12
def load_annotations(annotations_path, task_key="action", exclude_labels=None, multiview=False, input_type="video", allow_missing_labels=False):
    """Load the samples of one classification task from an annotation JSON.

    The file must provide ``["labels"][task_key]["labels"]`` (the list of
    label names) and a ``["data"]`` list of items. Each kept item contributes
    the paths of its matching inputs to a sample; with ``multiview`` several
    items can be merged into the same sample.

    Args:
        annotations_path: Path of the annotation JSON file.
        task_key: Name of the task whose labels are read.
        exclude_labels: Label names removed from the label map; items carrying
            one of them are dropped. ``None`` or an empty collection falls
            back to ``[""]``.
        multiview: When true, an item id containing ``"_view"`` is grouped
            under the part of the id that precedes the last ``"_view"``.
        input_type: Only inputs whose ``"type"`` equals this value are used.
        allow_missing_labels: When false, items without a label entry for
            ``task_key`` are skipped.

    Returns:
        tuple: ``(samples, label_map)``. ``samples`` is a list of dicts with
        the keys ``"video_paths"`` (list of paths), ``"label"`` (index in
        ``label_map`` or ``None``) and ``"id"`` (group id). ``label_map`` maps
        each kept label name to its index.
    """
    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with open(annotations_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # The default is None rather than a shared list object; an empty value
    # keeps the historical fallback of excluding the empty label.
    excluded = set(exclude_labels or [""])

    # Label list for the selected task, in file order.
    label_list = [
        lbl for lbl in data["labels"][task_key]["labels"]
        if lbl not in excluded
    ]
    label_map = {name: idx for idx, name in enumerate(label_list)}

    # One entry per group id (the item id without its view suffix when
    # multiview grouping applies).
    grouped = defaultdict(lambda: {"video_paths": [], "label": None})

    for item in data["data"]:
        label_idx = None

        if "labels" in item and task_key in item["labels"]:
            action_label = item["labels"][task_key].get("label")
            if action_label in excluded:
                continue
            # A label that is neither excluded nor known leaves label_idx None.
            if action_label in label_map:
                label_idx = label_map[action_label]
        elif not allow_missing_labels:
            # training mode requires labels
            continue

        item_id = item["id"]
        if multiview and "_view" in item_id:
            group_id = item_id.rsplit("_view", 1)[0]
        else:
            group_id = item_id

        clips = [
            inp["path"]
            for inp in item.get("inputs", [])
            if inp.get("type") == input_type and "path" in inp
        ]
        if not clips:
            continue

        group = grouped[group_id]
        group["video_paths"].extend(clips)
        # Only a resolved label overwrites the label already stored for the group.
        if label_idx is not None:
            group["label"] = label_idx
        group["id"] = group_id

    return list(grouped.values()), label_map
69
+
70
+
71
+ def load_annotations_(annotations_path, exclude_labels=None):
72
+ with open(annotations_path, "r") as f:
73
+ data = json.load(f)
74
+
75
+ exclude_labels = exclude_labels or ["", "Challenge"]
76
+ # Filter labels
77
+ label_list = [
78
+ name for name in data["labels"]["foul_type"]["labels"]
79
+ if name not in exclude_labels
80
+ ]
81
+
82
+ label_map = {name: idx for idx, name in enumerate(label_list)}
83
+ samples = []
84
+
85
+ for item in data["data"]:
86
+ foul_label = item["labels"]["foul_type"]["label"]
87
+
88
+ # Skip unwanted labels
89
+ if foul_label in exclude_labels:
90
+ continue
91
+
92
+ label_idx = label_map.get(foul_label, -1)
93
+ if label_idx == -1:
94
+ continue
95
+
96
+ # collect *all* video clips
97
+ all_clips = [
98
+ c.get("path", "") for c in item.get("inputs", [])
99
+ if c.get("type") == "video"
100
+ ]
101
+ all_clips = [p.replace("Dataset/Train", "train")
102
+ .replace("Dataset/Test", "test")
103
+ .replace("Dataset/Valid", "valid") + ".mp4"
104
+ for p in all_clips if p]
105
+
106
+ if not all_clips:
107
+ continue
108
+
109
+ samples.append({
110
+ "video_paths": all_clips,
111
+ "label": label_idx
112
+ })
113
+ print(label_map)
114
+ return samples, label_map
115
+
116
+
117
def has_localization_events(annotation_path):
    """Tell whether an annotation JSON contains at least one event.

    Args:
        annotation_path: Path of the annotation JSON file, or ``None``.

    Returns:
        bool: ``False`` when ``annotation_path`` is ``None``; otherwise
        ``True`` as soon as one item of the top-level ``"data"`` list has a
        non-empty ``"events"`` entry, ``False`` if none has.
    """
    if annotation_path is None:
        return False

    with open(annotation_path) as handle:
        payload = json.load(handle)

    return any(entry.get("events") for entry in payload.get("data", []))
131
+
132
def annotationstoe2eformat(
    label_files,
    video_dirs,
    input_fps,
    extract_fps,
    dali
):
    """
    Adapt SN Ball Action Spotting annotations to E2E format.

    Supports JSON with:
    - top-level "data"
    - video path in inputs[0]["path"]
    - events with "label" and "position_ms"

    Args:
        label_files (str | list[str]): Annotation JSON files
        video_dirs (str | list[str]): Root video directories, one per label file
        input_fps (int): FPS expected by the model
        extract_fps (int): FPS for frame extraction
        dali (bool): Whether using DALI or OpenCV

    Returns:
        tuple: ``(labels_e2e, task_name)``. ``labels_e2e`` is a list with one
        dict per video (keys ``events``, ``fps``, ``num_frames``,
        ``num_frames_base``, ``num_events``, ``width``, ``height``, ``video``,
        ``path``) sorted by the ``video`` path. ``task_name`` is the first
        task name found in the ``"labels"`` section of the first file.

    Raises:
        AssertionError: If the number of label files and video directories
            differ, a video file does not exist, or the label files do not
            share the same class list.
    """

    if not isinstance(label_files, list):
        label_files = [label_files]
    if not isinstance(video_dirs, list):
        video_dirs = [video_dirs]

    assert len(label_files) == len(video_dirs)

    labels_e2e = []
    classes_by_label_dir = []
    task_name_list = []

    for label_path, video_dir in zip(label_files, video_dirs):
        logging.info(f"Processing {label_path} to e2e format")

        annotations = load_json(label_path)

        # ---- Extract class list (ball_action) ----
        # Every task name is recorded; the class list kept for the file is
        # the one of the last task listed.
        for task_name, task_data in annotations["labels"].items():
            labels = task_data.get("labels", {})
            task_name_list.append(task_name)

        classes_by_label_dir.append(labels)

        # ---- Iterate videos ----
        videos = annotations["data"]

        for video in tqdm.tqdm(videos):
            # ---- Video path & metadata ----
            video_path = video["inputs"][0]["path"].replace(" ", "_")
            full_video_path = os.path.join(video_dir, video_path)
            assert os.path.isfile(full_video_path), full_video_path

            vc = cv2.VideoCapture(full_video_path)
            try:
                width = int(vc.get(cv2.CAP_PROP_FRAME_WIDTH))
                height = int(vc.get(cv2.CAP_PROP_FRAME_HEIGHT))
                fps = vc.get(cv2.CAP_PROP_FPS)
                num_frames = int(vc.get(cv2.CAP_PROP_FRAME_COUNT))
            finally:
                # Only metadata is needed; free the decoder handle right away
                # instead of keeping one open per video.
                vc.release()

            # ---- FPS handling ----
            # Never extract at a higher rate than the video itself.
            target_fps = extract_fps if extract_fps < fps else fps
            sample_fps = read_fps(fps, target_fps)

            num_frames_after = get_num_frames(
                num_frames, fps, target_fps
            )

            if dali:
                # The stride derived from the configured rates is computed
                # once per video rather than once per event.
                model_stride = get_stride(input_fps, extract_fps)
                if get_stride(fps, target_fps) != model_stride:
                    # This video's stride differs from the configured one:
                    # use the configured stride for rate and frame count.
                    sample_fps = fps / model_stride
                    num_frames_dali = math.ceil(num_frames / model_stride)
                else:
                    num_frames_dali = num_frames_after

            # ---- Events ----
            events = []
            for ann in video.get("events", []):
                position_ms = float(ann["position_ms"])

                # sample_fps already holds the rate that applies to this
                # video (including the DALI stride override above), so one
                # formula covers every case.
                adj_frame = position_ms / 1000 * sample_fps

                # Events falling on frame 0 are moved to frame 1.
                if int(adj_frame) == 0:
                    adj_frame = 1

                events.append({
                    "frame": int(adj_frame),
                    "label": ann["label"],
                })

            events.sort(key=lambda x: x["frame"])

            labels_e2e.append({
                "events": events,
                "fps": sample_fps,
                "num_frames": num_frames_dali if dali else num_frames_after,
                "num_frames_base": num_frames,
                "num_events": len(events),
                "width": width,
                "height": height,
                "video": full_video_path,
                "path": video_path,
            })

    # ---- Sanity checks ----
    # All label files must describe the same classes.
    base_classes = classes_by_label_dir[0]
    for c in classes_by_label_dir:
        assert c == base_classes

    labels_e2e.sort(key=lambda x: x["video"])

    return labels_e2e, task_name_list[0]
258
+
259
+ # def annotationstoe2eformat(label_files, video_dirs, input_fps, extract_fps, dali):
260
+ # """Adapt annotations jsons to e2e format.
261
+
262
+ # Args:
263
+ # label_files (string,list[string]): Json files of annotations.
264
+ # label_dirs (string,list[string]): Data root folder of videos. Must match number of label files.
265
+ # input_fps (int): Fps of input videos.
266
+ # extract_fps (int): Fps at which we extract frames.
267
+ # dali (bool): Whether processing with dali or opencv.
268
+ # """
269
+
270
+ # if not isinstance(label_files, list):
271
+ # label_files = [label_files]
272
+ # if not isinstance(video_dirs, list):
273
+ # video_dirs = [video_dirs]
274
+ # assert len(label_files) == len(video_dirs)
275
+
276
+ # labels_e2e = list()
277
+ # classes_by_label_dir = []
278
+ # for label_dir, video_dir in zip(label_files, video_dirs):
279
+ # logging.info("Processing " + label_dir + " to e2e format.")
280
+ # videos = []
281
+ # annotations = load_json(label_dir)
282
+ # labels = annotations["labels"]
283
+ # classes_by_label_dir.append(labels)
284
+ # videos = annotations["videos"]
285
+ # for video in tqdm.tqdm(videos):
286
+ # if "annotations" in video.keys():
287
+ # video_annotations = video["annotations"]
288
+ # else:
289
+ # video_annotations = []
290
+
291
+ # num_events = 0
292
+
293
+ # vc = cv2.VideoCapture(os.path.join(video_dir, video["path"]))
294
+ # width = int(vc.get(cv2.CAP_PROP_FRAME_WIDTH))
295
+ # height = int(vc.get(cv2.CAP_PROP_FRAME_HEIGHT))
296
+ # fps = vc.get(cv2.CAP_PROP_FPS)
297
+ # num_frames = int(vc.get(cv2.CAP_PROP_FRAME_COUNT))
298
+
299
+ # sample_fps = read_fps(fps, extract_fps if extract_fps < fps else fps)
300
+ # num_frames_after = get_num_frames(
301
+ # num_frames, fps, extract_fps if extract_fps < fps else fps
302
+ # )
303
+
304
+ # if dali:
305
+ # if get_stride(
306
+ # fps, extract_fps if extract_fps < fps else fps
307
+ # ) != get_stride(input_fps, extract_fps):
308
+ # sample_fps = fps / get_stride(input_fps, extract_fps)
309
+ # num_frames_dali = math.ceil(
310
+ # num_frames / get_stride(input_fps, extract_fps)
311
+ # )
312
+ # else:
313
+ # num_frames_dali = num_frames_after
314
+
315
+ # # video_id = os.path.splitext(video["path"])[0]
316
+ # video_id = os.path.join(video_dir, video["path"])
317
+
318
+ # events = []
319
+ # for annotation in video_annotations:
320
+ # if dali:
321
+ # if get_stride(
322
+ # fps, extract_fps if extract_fps < fps else fps
323
+ # ) != get_stride(input_fps, extract_fps):
324
+ # adj_frame = (
325
+ # float(annotation["position"])
326
+ # / 1000
327
+ # * (fps / get_stride(input_fps, extract_fps))
328
+ # )
329
+ # else:
330
+ # adj_frame = float(annotation["position"]) / 1000 * sample_fps
331
+ # if int(adj_frame) == 0:
332
+ # adj_frame = 1
333
+ # else:
334
+ # adj_frame = float(annotation["position"]) / 1000 * sample_fps
335
+ # events.append(
336
+ # {
337
+ # "frame": int(adj_frame),
338
+ # "label": annotation["label"],
339
+ # "team": annotation["team"],
340
+ # "visibility": annotation["visibility"],
341
+ # }
342
+ # )
343
+
344
+ # num_events += len(events)
345
+ # events.sort(key=lambda x: x["frame"])
346
+
347
+ # labels_e2e.append(
348
+ # {
349
+ # "events": events,
350
+ # "fps": sample_fps,
351
+ # "num_frames": num_frames_dali if dali else num_frames_after,
352
+ # "num_frames_base": num_frames,
353
+ # "num_events": len(events),
354
+ # "width": width,
355
+ # "height": height,
356
+ # "video": video_id,
357
+ # "path": video["path"],
358
+ # }
359
+ # )
360
+ # assert len(video_annotations) == num_events
361
+ # classes = classes_by_label_dir[0]
362
+ # for classes_tmp in classes_by_label_dir:
363
+ # assert classes == classes_tmp
364
+ # labels_e2e.sort(key=lambda x: x["video"])
365
+ # return labels_e2e
366
+
367
def construct_labels(path, extract_fps):
    """This method is used when the input of the dataset is a video file instead of a json file.
    It creates a pseudo json by processing the video to get metadatas.

    Args:
        path (string): The path of the video file.
        extract_fps (int): The fps at which we want to extract frames.

    Returns:
        List(dict): The pseudo json object: a single record without events,
            using the same keys as the records built by
            ``annotationstoe2eformat``.
        (int): stride at which we will process the video.
    """
    vc = cv2.VideoCapture(path)
    try:
        fps = vc.get(cv2.CAP_PROP_FPS)
        num_frames = int(vc.get(cv2.CAP_PROP_FRAME_COUNT))
        width = int(vc.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(vc.get(cv2.CAP_PROP_FRAME_HEIGHT))
    finally:
        # Only metadata is read; release the capture handle explicitly.
        vc.release()

    # Never extract at a higher rate than the video itself.
    target_fps = extract_fps if extract_fps < fps else fps

    sample_fps = read_fps(fps, target_fps)
    num_frames_after = get_num_frames(num_frames, fps, target_fps)

    # NOTE(review): the previous body stopped after computing the values above
    # and returned None, although the docstring promises (pseudo json, stride).
    # The record layout mirrors annotationstoe2eformat - confirm against callers.
    pseudo_labels = [{
        "events": [],
        "fps": sample_fps,
        "num_frames": num_frames_after,
        "num_frames_base": num_frames,
        "num_events": 0,
        "width": width,
        "height": height,
        "video": path,
        "path": os.path.basename(path),
    }]

    return pseudo_labels, get_stride(fps, target_fps)
388
+
389
+
390
+ # def get_repartition_gpu():
391
+ # """Returns the distribution of gpus that will be used by pipelines for dali."""
392
+ # x = torch.cuda.device_count() - 1
393
+ # print("Number of gpus:", x)
394
+ # if x == 1:
395
+ # return [0], [0]
396
+ # if x == 2:
397
+ # return [0, 1], [0, 1]
398
+ # elif x == 3:
399
+ # return [0, 1], [1, 2]
400
+ # elif x > 3:
401
+ # return [0, 1, 2, 3], [0, 2, 1, 3]
402
+
403
def get_repartition_gpu(max_train_gpus=2):
    """Return the GPU indices assigned to training and to DALI.

    Args:
        max_train_gpus: Upper bound on the number of GPU indices returned.

    Returns:
        tuple: ``(train_gpus, dali_gpus)``, two separate lists holding the
        same indices ``0 .. min(device_count, max_train_gpus) - 1``. Both are
        empty when ``torch.cuda.device_count()`` reports no device.
    """
    available = torch.cuda.device_count()
    if not available:
        return [], []

    usable = min(available, max_train_gpus)
    train_gpus = [index for index in range(usable)]
    # DALI gets its own list object holding the same device indices.
    return train_gpus, list(train_gpus)
413
+
414
+
415
def check_config(cfg, split="train"):
    """Check for inconsistencies and missing elements in the config.
    The checks are different regarding the methods.

    The config is modified in place: ``cfg.TRAIN.repartitions`` (when
    ``cfg.dali`` is true), ``cfg.TRAIN.start_valid_epoch``,
    ``cfg.DATA.crop_dim`` and ``cfg.DATA.classes`` may be rewritten.

    Args:
        cfg (dict): Config dictionary.
        split (str): One of "train", "valid" or "test"; selects which
            ``cfg.DATA.<split>.path`` is inspected for a class list.

    Raises:
        AssertionError: If one of the consistency checks fails.
        ValueError: If ``split`` is not one of the supported names.
    """
    # NOTE(review): the nesting of the statements under the runner check below
    # was reconstructed from a listing without indentation - confirm it.
    from opensportslib.core.utils.config import load_json, load_classes
    from omegaconf import ListConfig
    if cfg.MODEL.runner.type == "runner_e2e":
        if cfg.dali == True:
            # GPU index lists for training / DALI, capped by cfg.SYSTEM.GPU.
            cfg.TRAIN.repartitions = get_repartition_gpu(cfg.SYSTEM.GPU)
        assert cfg.DATA.modality in ["rgb"]
        assert cfg.MODEL.backbone.type in [
            # From torchvision
            "rn18",
            "rn18_tsm",
            "rn18_gsm",
            "rn50",
            "rn50_tsm",
            "rn50_gsm",
            # From timm (following its naming conventions)
            "rny002",
            "rny002_tsm",
            "rny002_gsm",
            "rny008",
            "rny008_tsm",
            "rny008_gsm",
            # From timm
            "convnextt",
            "convnextt_tsm",
            "convnextt_gsm",
        ]
        assert cfg.MODEL.head.type in ["", "gru", "deeper_gru", "mstcn", "asformer"]
        # assert cfg.dataset.batch_size % cfg.training.acc_grad_iter == 0
        # The batch size must be a multiple of the gradient accumulation steps.
        assert cfg.DATA.train.dataloader.batch_size % cfg.TRAIN.acc_grad_iter == 0
        assert cfg.TRAIN.criterion_valid in ["map", "loss"]
        # Trainer and scheduler settings must agree with each other.
        assert cfg.TRAIN.num_epochs == cfg.TRAIN.scheduler.num_epochs
        assert cfg.TRAIN.acc_grad_iter == cfg.TRAIN.scheduler.acc_grad_iter

        # Pick the annotation path of the requested split.
        if split=="train":
            data_path = cfg.DATA.train.path
        elif split=="valid":
            data_path = cfg.DATA.valid.path
        elif split=="test":
            data_path = cfg.DATA.test.path
        else:
            raise ValueError(f"Unknown split {split}")

        # Default: start validating base_num_valid_epochs before the last epoch.
        if cfg.TRAIN.start_valid_epoch is None:
            cfg.TRAIN.start_valid_epoch = (
                cfg.TRAIN.num_epochs - cfg.TRAIN.base_num_valid_epochs
            )
        # A missing or non-positive crop size is normalised to None.
        if cfg.DATA.crop_dim is None or cfg.DATA.crop_dim <= 0:
            cfg.DATA.crop_dim = None
        # Prefer the class list stored in the split's annotation JSON.
        # NOTE(review): the file is parsed twice (condition and loop).
        if (
            data_path != None
            and os.path.isfile(data_path)
            and data_path.endswith(".json")
            and "labels" in load_json(data_path).keys()
        ):
            # Only the labels of the last task listed are kept; `classes`
            # stays undefined when the "labels" section is empty.
            for task_name, task_data in load_json(data_path)["labels"].items():
                classes = task_data.get("labels", {})
            #classes = load_json(cfg.DATA.test.path)["labels"]["action"]["labels"]
        else:
            # Otherwise the config itself must provide the class list.
            assert isinstance(cfg.DATA.classes, (list, ListConfig))
            classes = cfg.DATA.classes

        #print(classes)
        cfg.DATA.classes = load_classes(classes)
@@ -0,0 +1,26 @@
1
+ # opensportslib/core/utils/seed.py
2
+
3
+ import random
4
+ import numpy as np
5
+ import torch
6
+ import os
7
+
8
def set_reproducibility(seed=42):
    """Seed the random generators and switch on deterministic execution.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and CUDA), sets the
    cuDNN flags to deterministic / non-benchmark mode, exports
    ``CUBLAS_WORKSPACE_CONFIG`` and ``PYTHONHASHSEED``, and asks PyTorch for
    deterministic algorithms in warn-only mode.

    Args:
        seed: Seed value given to every generator.
    """
    # Seed every generator with the same value.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    os.environ.update({
        'CUBLAS_WORKSPACE_CONFIG': ':4096:8',
        'PYTHONHASHSEED': str(seed),
    })

    # warn_only: operations without a deterministic implementation emit a
    # warning instead of raising.
    torch.use_deterministic_algorithms(True, warn_only=True)
20
+
21
+
22
def seed_worker(worker_id):
    """Seed NumPy and ``random`` inside a DataLoader worker.

    The seed is PyTorch's initial seed reduced to its low 32 bits.

    Args:
        worker_id: Index of the worker (unused).
    """
    low_32_bits = 0xFFFFFFFF  # same value as taking the seed modulo 2**32
    worker_seed = torch.initial_seed() & low_32_bits
    for apply_seed in (np.random.seed, random.seed):
        apply_seed(worker_seed)