opensportslib 0.0.1.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- opensportslib/__init__.py +18 -0
- opensportslib/apis/__init__.py +21 -0
- opensportslib/apis/classification.py +361 -0
- opensportslib/apis/localization.py +228 -0
- opensportslib/config/classification.yaml +104 -0
- opensportslib/config/classification_tracking.yaml +103 -0
- opensportslib/config/graph_tracking_classification/avgpool.yaml +79 -0
- opensportslib/config/graph_tracking_classification/gin.yaml +79 -0
- opensportslib/config/graph_tracking_classification/graphconv.yaml +79 -0
- opensportslib/config/graph_tracking_classification/graphsage.yaml +79 -0
- opensportslib/config/graph_tracking_classification/maxpool.yaml +79 -0
- opensportslib/config/graph_tracking_classification/noedges.yaml +79 -0
- opensportslib/config/localization.yaml +132 -0
- opensportslib/config/sngar_frames.yaml +98 -0
- opensportslib/core/__init__.py +0 -0
- opensportslib/core/loss/__init__.py +0 -0
- opensportslib/core/loss/builder.py +40 -0
- opensportslib/core/loss/calf.py +258 -0
- opensportslib/core/loss/ce.py +23 -0
- opensportslib/core/loss/combine.py +42 -0
- opensportslib/core/loss/nll.py +25 -0
- opensportslib/core/optimizer/__init__.py +0 -0
- opensportslib/core/optimizer/builder.py +38 -0
- opensportslib/core/sampler/weighted_sampler.py +104 -0
- opensportslib/core/scheduler/__init__.py +0 -0
- opensportslib/core/scheduler/builder.py +77 -0
- opensportslib/core/trainer/__init__.py +0 -0
- opensportslib/core/trainer/classification_trainer.py +1131 -0
- opensportslib/core/trainer/localization_trainer.py +1009 -0
- opensportslib/core/utils/checkpoint.py +238 -0
- opensportslib/core/utils/config.py +199 -0
- opensportslib/core/utils/data.py +85 -0
- opensportslib/core/utils/ddp.py +77 -0
- opensportslib/core/utils/default_args.py +110 -0
- opensportslib/core/utils/load_annotations.py +485 -0
- opensportslib/core/utils/seed.py +26 -0
- opensportslib/core/utils/video_processing.py +389 -0
- opensportslib/core/utils/wandb.py +110 -0
- opensportslib/datasets/__init__.py +0 -0
- opensportslib/datasets/builder.py +42 -0
- opensportslib/datasets/classification_dataset.py +582 -0
- opensportslib/datasets/localization_dataset.py +813 -0
- opensportslib/datasets/utils/__init__.py +15 -0
- opensportslib/datasets/utils/tracking.py +615 -0
- opensportslib/metrics/classification_metric.py +176 -0
- opensportslib/metrics/localization_metric.py +1482 -0
- opensportslib/models/__init__.py +0 -0
- opensportslib/models/backbones/builder.py +590 -0
- opensportslib/models/base/e2e.py +252 -0
- opensportslib/models/base/tracking.py +73 -0
- opensportslib/models/base/vars.py +29 -0
- opensportslib/models/base/video.py +130 -0
- opensportslib/models/base/video_mae.py +60 -0
- opensportslib/models/builder.py +43 -0
- opensportslib/models/heads/builder.py +266 -0
- opensportslib/models/neck/builder.py +210 -0
- opensportslib/models/utils/common.py +176 -0
- opensportslib/models/utils/impl/__init__.py +0 -0
- opensportslib/models/utils/impl/asformer.py +390 -0
- opensportslib/models/utils/impl/calf.py +74 -0
- opensportslib/models/utils/impl/gsm.py +112 -0
- opensportslib/models/utils/impl/gtad.py +347 -0
- opensportslib/models/utils/impl/tsm.py +123 -0
- opensportslib/models/utils/litebase.py +59 -0
- opensportslib/models/utils/modules.py +120 -0
- opensportslib/models/utils/shift.py +135 -0
- opensportslib/models/utils/utils.py +276 -0
- opensportslib-0.0.1.dev2.dist-info/METADATA +566 -0
- opensportslib-0.0.1.dev2.dist-info/RECORD +73 -0
- opensportslib-0.0.1.dev2.dist-info/WHEEL +5 -0
- opensportslib-0.0.1.dev2.dist-info/licenses/LICENSE +661 -0
- opensportslib-0.0.1.dev2.dist-info/licenses/LICENSE-COMMERCIAL +5 -0
- opensportslib-0.0.1.dev2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,485 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import os
|
|
3
|
+
import tqdm
|
|
4
|
+
import logging
|
|
5
|
+
import cv2
|
|
6
|
+
import math
|
|
7
|
+
import torch
|
|
8
|
+
from opensportslib.core.utils.video_processing import get_stride, read_fps, get_num_frames
|
|
9
|
+
from opensportslib.core.utils.config import load_json
|
|
10
|
+
from collections import defaultdict
|
|
11
|
+
|
|
12
|
+
def load_annotations(
    annotations_path,
    task_key="action",
    exclude_labels=None,
    multiview=False,
    input_type="video",
    allow_missing_labels=False,
):
    """Load classification samples from an annotation JSON file.

    The JSON is read from ``data["labels"][task_key]["labels"]`` (the class
    names) and ``data["data"]`` (a list of items with ``id``, optional
    ``labels`` and ``inputs``).

    Args:
        annotations_path (str): Path to the annotation JSON (read as UTF-8).
        task_key (str): Which labelling task to read from each item.
        exclude_labels (Iterable[str] | None): Class names dropped both from
            the label map and from the samples. ``None`` or empty means
            ``[""]`` (only the empty label is excluded).
        multiview (bool): If True, item ids containing ``"_view"`` are grouped
            under the prefix before the last ``"_view"``, so the clips of all
            views end up in one sample.
        input_type (str): Only ``inputs`` entries with this ``type`` and a
            ``path`` are collected.
        allow_missing_labels (bool): If False (training mode), items whose
            label cannot be resolved to a known class are skipped. This covers
            a missing ``labels``/task entry as well as an unknown class name.
            If True, such items are kept with ``label=None``.

    Returns:
        tuple[list[dict], dict[str, int]]: Samples with keys ``video_paths``,
        ``label`` and ``id``, and the class-name -> index map.
    """
    with open(annotations_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # ``None``/empty falls back to excluding only the empty label.
    excluded = set(exclude_labels or [""])

    # Label list for the selected task, in file order.
    label_list = [
        lbl for lbl in data["labels"][task_key]["labels"] if lbl not in excluded
    ]
    label_map = {name: idx for idx, name in enumerate(label_list)}

    # Group by action id (without view suffix).
    grouped = defaultdict(lambda: {"video_paths": [], "label": None})

    for item in data["data"]:
        label_idx = None

        if "labels" in item and task_key in item["labels"]:
            action_label = item["labels"][task_key].get("label", None)
            if action_label in excluded:
                continue
            label_idx = label_map.get(action_label)

        # Training mode requires a resolvable label, whether the label entry
        # is absent or names an unknown class.
        if label_idx is None and not allow_missing_labels:
            continue

        item_id = item["id"]
        if multiview and "_view" in item_id:
            group_id = item_id.rsplit("_view", 1)[0]
        else:
            group_id = item_id

        clips = [
            inp["path"]
            for inp in item.get("inputs", [])
            if inp.get("type") == input_type and "path" in inp
        ]
        if not clips:
            continue

        sample = grouped[group_id]
        sample["video_paths"].extend(clips)
        if label_idx is not None:
            sample["label"] = label_idx
        sample["id"] = group_id

    return list(grouped.values()), label_map
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def load_annotations_(annotations_path, exclude_labels=None):
    """Load ``foul_type`` classification samples (legacy loader).

    Every ``video`` input with a non-empty path is kept. Paths are rewritten
    from ``Dataset/Train|Test|Valid`` to ``train|test|valid``, and ``.mp4`` is
    appended.

    Args:
        annotations_path (str): Path to the annotation JSON (read as UTF-8).
        exclude_labels (Iterable[str] | None): Class names to drop. ``None``
            or empty means ``["", "Challenge"]``.

    Returns:
        tuple[list[dict], dict[str, int]]: Samples with keys ``video_paths``
        and ``label``, and the class-name -> index map.

    Raises:
        KeyError: If an item lacks ``labels["foul_type"]["label"]``.
    """
    with open(annotations_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    excluded = set(exclude_labels or ["", "Challenge"])

    label_list = [
        name for name in data["labels"]["foul_type"]["labels"]
        if name not in excluded
    ]
    label_map = {name: idx for idx, name in enumerate(label_list)}
    samples = []

    for item in data["data"]:
        foul_label = item["labels"]["foul_type"]["label"]

        # Skip excluded labels and labels absent from the declared class list.
        if foul_label in excluded or foul_label not in label_map:
            continue
        label_idx = label_map[foul_label]

        # Collect *all* video clips, mapping dataset paths to local split dirs.
        all_clips = [
            c.get("path", "") for c in item.get("inputs", [])
            if c.get("type") == "video"
        ]
        all_clips = [
            p.replace("Dataset/Train", "train")
            .replace("Dataset/Test", "test")
            .replace("Dataset/Valid", "valid") + ".mp4"
            for p in all_clips if p
        ]
        if not all_clips:
            continue

        samples.append({
            "video_paths": all_clips,
            "label": label_idx,
        })

    # Report through logging instead of printing to stdout from a library.
    logging.info("Label map: %s", label_map)
    return samples, label_map
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def has_localization_events(annotation_path):
    """Return whether an annotation JSON contains any localization events.

    Args:
        annotation_path (str | None): Path to the annotation JSON (read as
            UTF-8, the encoding JSON mandates). ``None`` returns False without
            touching the filesystem.

    Returns:
        bool: True if at least one item of the top-level ``"data"`` list has
        a non-empty ``"events"`` value; False otherwise, including when
        ``"data"`` is absent.
    """
    if annotation_path is None:
        return False

    with open(annotation_path, encoding="utf-8") as f:
        data = json.load(f)

    return any(item.get("events") for item in data.get("data", []))
|
|
131
|
+
|
|
132
|
+
def _probe_video(video_path):
    """Return ``(width, height, fps, num_frames)`` for *video_path* as reported by OpenCV.

    The capture is released before returning (also if a property read raises),
    so probing many videos does not leave decoder handles open.
    """
    vc = cv2.VideoCapture(video_path)
    try:
        width = int(vc.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(vc.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = vc.get(cv2.CAP_PROP_FPS)
        num_frames = int(vc.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        vc.release()
    return width, height, fps, num_frames


def annotationstoe2eformat(
    label_files,
    video_dirs,
    input_fps,
    extract_fps,
    dali
):
    """
    Adapt SN Ball Action Spotting annotations to E2E format.

    Supports JSON with:
    - top-level "data"
    - video path in inputs[0]["path"]
    - events with "label" and "position_ms"

    Args:
        label_files (str | list[str]): Annotation JSON files
        video_dirs (str | list[str]): Root video directories, one per label file
        input_fps (int): Reference fps; with ``extract_fps`` it defines the
            frame stride used under DALI.
        extract_fps (int): FPS for frame extraction (capped at each video's fps)
        dali (bool): Whether using DALI or OpenCV

    Returns:
        tuple[list[dict], str]: One record per video, sorted by ``"video"``,
        with keys ``events`` (``frame``/``label`` dicts sorted by frame),
        ``fps``, ``num_frames``, ``num_frames_base``, ``num_events``,
        ``width``, ``height``, ``video`` (full path) and ``path`` (relative
        path); and the name of the first task listed under ``"labels"``.
    """
    if not isinstance(label_files, list):
        label_files = [label_files]
    if not isinstance(video_dirs, list):
        video_dirs = [video_dirs]

    assert len(label_files) == len(video_dirs), (
        f"{len(label_files)} label files but {len(video_dirs)} video dirs"
    )

    # Stride implied by the configured fps pair; loop-invariant and only
    # needed on the DALI path.
    dali_stride = get_stride(input_fps, extract_fps) if dali else None

    labels_e2e = []
    classes_by_label_dir = []
    task_name_list = []

    for label_path, video_dir in zip(label_files, video_dirs):
        logging.info(f"Processing {label_path} to e2e format")

        annotations = load_json(label_path)

        # ---- Class lists, one per task ----
        for task_name, task_data in annotations["labels"].items():
            task_name_list.append(task_name)
            classes_by_label_dir.append(task_data.get("labels", {}))

        # ---- Iterate videos ----
        for video in tqdm.tqdm(annotations["data"]):
            video_path = video["inputs"][0]["path"].replace(" ", "_")
            full_video_path = os.path.join(video_dir, video_path)
            assert os.path.isfile(full_video_path), full_video_path
            width, height, fps, num_frames = _probe_video(full_video_path)

            # ---- FPS handling: never sample faster than the source ----
            target_fps = extract_fps if extract_fps < fps else fps
            sample_fps = read_fps(fps, target_fps)
            num_frames_out = get_num_frames(num_frames, fps, target_fps)

            # Under DALI, when this video's own stride differs from the
            # configured one, rescale fps and frame count to that stride.
            if dali and get_stride(fps, target_fps) != dali_stride:
                sample_fps = fps / dali_stride
                num_frames_out = math.ceil(num_frames / dali_stride)

            # ---- Events: ms -> frame index at the effective sample fps ----
            events = []
            for ann in video.get("events", []):
                position_ms = float(ann["position_ms"])
                # Frame index 0 is mapped to 1.
                frame = int(position_ms / 1000 * sample_fps) or 1
                events.append({
                    "frame": frame,
                    "label": ann["label"],
                })
            events.sort(key=lambda ev: ev["frame"])

            labels_e2e.append({
                "events": events,
                "fps": sample_fps,
                "num_frames": num_frames_out,
                "num_frames_base": num_frames,
                "num_events": len(events),
                "width": width,
                "height": height,
                "video": full_video_path,
                "path": video_path,
            })

    # ---- Sanity checks: every file/task must share one class list ----
    base_classes = classes_by_label_dir[0]
    for classes in classes_by_label_dir:
        assert classes == base_classes, "class lists differ across annotation files/tasks"

    labels_e2e.sort(key=lambda x: x["video"])

    return labels_e2e, task_name_list[0]
|
|
258
|
+
|
|
259
|
+
# def annotationstoe2eformat(label_files, video_dirs, input_fps, extract_fps, dali):
|
|
260
|
+
# """Adapt annotations jsons to e2e format.
|
|
261
|
+
|
|
262
|
+
# Args:
|
|
263
|
+
# label_files (string,list[string]): Json files of annotations.
|
|
264
|
+
# label_dirs (string,list[string]): Data root folder of videos. Must match number of label files.
|
|
265
|
+
# input_fps (int): Fps of input videos.
|
|
266
|
+
# extract_fps (int): Fps at which we extract frames.
|
|
267
|
+
# dali (bool): WHether processing with dali or opencv.
|
|
268
|
+
# """
|
|
269
|
+
|
|
270
|
+
# if not isinstance(label_files, list):
|
|
271
|
+
# label_files = [label_files]
|
|
272
|
+
# if not isinstance(video_dirs, list):
|
|
273
|
+
# video_dirs = [video_dirs]
|
|
274
|
+
# assert len(label_files) == len(video_dirs)
|
|
275
|
+
|
|
276
|
+
# labels_e2e = list()
|
|
277
|
+
# classes_by_label_dir = []
|
|
278
|
+
# for label_dir, video_dir in zip(label_files, video_dirs):
|
|
279
|
+
# logging.info("Processing " + label_dir + " to e2e format.")
|
|
280
|
+
# videos = []
|
|
281
|
+
# annotations = load_json(label_dir)
|
|
282
|
+
# labels = annotations["labels"]
|
|
283
|
+
# classes_by_label_dir.append(labels)
|
|
284
|
+
# videos = annotations["videos"]
|
|
285
|
+
# for video in tqdm.tqdm(videos):
|
|
286
|
+
# if "annotations" in video.keys():
|
|
287
|
+
# video_annotations = video["annotations"]
|
|
288
|
+
# else:
|
|
289
|
+
# video_annotations = []
|
|
290
|
+
|
|
291
|
+
# num_events = 0
|
|
292
|
+
|
|
293
|
+
# vc = cv2.VideoCapture(os.path.join(video_dir, video["path"]))
|
|
294
|
+
# width = int(vc.get(cv2.CAP_PROP_FRAME_WIDTH))
|
|
295
|
+
# height = int(vc.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
|
296
|
+
# fps = vc.get(cv2.CAP_PROP_FPS)
|
|
297
|
+
# num_frames = int(vc.get(cv2.CAP_PROP_FRAME_COUNT))
|
|
298
|
+
|
|
299
|
+
# sample_fps = read_fps(fps, extract_fps if extract_fps < fps else fps)
|
|
300
|
+
# num_frames_after = get_num_frames(
|
|
301
|
+
# num_frames, fps, extract_fps if extract_fps < fps else fps
|
|
302
|
+
# )
|
|
303
|
+
|
|
304
|
+
# if dali:
|
|
305
|
+
# if get_stride(
|
|
306
|
+
# fps, extract_fps if extract_fps < fps else fps
|
|
307
|
+
# ) != get_stride(input_fps, extract_fps):
|
|
308
|
+
# sample_fps = fps / get_stride(input_fps, extract_fps)
|
|
309
|
+
# num_frames_dali = math.ceil(
|
|
310
|
+
# num_frames / get_stride(input_fps, extract_fps)
|
|
311
|
+
# )
|
|
312
|
+
# else:
|
|
313
|
+
# num_frames_dali = num_frames_after
|
|
314
|
+
|
|
315
|
+
# # video_id = os.path.splitext(video["path"])[0]
|
|
316
|
+
# video_id = os.path.join(video_dir, video["path"])
|
|
317
|
+
|
|
318
|
+
# events = []
|
|
319
|
+
# for annotation in video_annotations:
|
|
320
|
+
# if dali:
|
|
321
|
+
# if get_stride(
|
|
322
|
+
# fps, extract_fps if extract_fps < fps else fps
|
|
323
|
+
# ) != get_stride(input_fps, extract_fps):
|
|
324
|
+
# adj_frame = (
|
|
325
|
+
# float(annotation["position"])
|
|
326
|
+
# / 1000
|
|
327
|
+
# * (fps / get_stride(input_fps, extract_fps))
|
|
328
|
+
# )
|
|
329
|
+
# else:
|
|
330
|
+
# adj_frame = float(annotation["position"]) / 1000 * sample_fps
|
|
331
|
+
# if int(adj_frame) == 0:
|
|
332
|
+
# adj_frame = 1
|
|
333
|
+
# else:
|
|
334
|
+
# adj_frame = float(annotation["position"]) / 1000 * sample_fps
|
|
335
|
+
# events.append(
|
|
336
|
+
# {
|
|
337
|
+
# "frame": int(adj_frame),
|
|
338
|
+
# "label": annotation["label"],
|
|
339
|
+
# "team": annotation["team"],
|
|
340
|
+
# "visibility": annotation["visibility"],
|
|
341
|
+
# }
|
|
342
|
+
# )
|
|
343
|
+
|
|
344
|
+
# num_events += len(events)
|
|
345
|
+
# events.sort(key=lambda x: x["frame"])
|
|
346
|
+
|
|
347
|
+
# labels_e2e.append(
|
|
348
|
+
# {
|
|
349
|
+
# "events": events,
|
|
350
|
+
# "fps": sample_fps,
|
|
351
|
+
# "num_frames": num_frames_dali if dali else num_frames_after,
|
|
352
|
+
# "num_frames_base": num_frames,
|
|
353
|
+
# "num_events": len(events),
|
|
354
|
+
# "width": width,
|
|
355
|
+
# "height": height,
|
|
356
|
+
# "video": video_id,
|
|
357
|
+
# "path": video["path"],
|
|
358
|
+
# }
|
|
359
|
+
# )
|
|
360
|
+
# assert len(video_annotations) == num_events
|
|
361
|
+
# classes = classes_by_label_dir[0]
|
|
362
|
+
# for classes_tmp in classes_by_label_dir:
|
|
363
|
+
# assert classes == classes_tmp
|
|
364
|
+
# labels_e2e.sort(key=lambda x: x["video"])
|
|
365
|
+
# return labels_e2e
|
|
366
|
+
|
|
367
|
+
def construct_labels(path, extract_fps):
    """This method is used when the input of the dataset is a video file instead of a json file.

    It creates a pseudo json by processing the video to get metadata. The
    record has the same keys as the records built by
    ``annotationstoe2eformat``, with no events.

    Args:
        path (string): The path of the video file.
        extract_fps (int): The fps at which we want to extract frames
            (capped at the video's native fps).

    Returns:
        List(dict): The pseudo json object (a single record).
        (int): stride at which we will process the video, i.e.
            ``get_stride(native_fps, target_fps)``.
    """
    wanted_sample_fps = extract_fps
    vc = cv2.VideoCapture(path)
    try:
        fps = vc.get(cv2.CAP_PROP_FPS)
        num_frames = int(vc.get(cv2.CAP_PROP_FRAME_COUNT))
        width = int(vc.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(vc.get(cv2.CAP_PROP_FRAME_HEIGHT))
    finally:
        # Always free the decoder handle.
        vc.release()

    target_fps = wanted_sample_fps if wanted_sample_fps < fps else fps
    sample_fps = read_fps(fps, target_fps)
    num_frames_after = get_num_frames(num_frames, fps, target_fps)
    stride = get_stride(fps, target_fps)

    # The previous version computed these values but returned None.
    labels_e2e = [{
        "events": [],
        "fps": sample_fps,
        "num_frames": num_frames_after,
        "num_frames_base": num_frames,
        "num_events": 0,
        "width": width,
        "height": height,
        "video": path,
        "path": path,
    }]
    return labels_e2e, stride
|
|
388
|
+
|
|
389
|
+
|
|
390
|
+
# def get_repartition_gpu():
|
|
391
|
+
# """Returns the distribution of gpus that will be used by pipelines for dali."""
|
|
392
|
+
# x = torch.cuda.device_count() - 1
|
|
393
|
+
# print("Number of gpus:", x)
|
|
394
|
+
# if x == 1:
|
|
395
|
+
# return [0], [0]
|
|
396
|
+
# if x == 2:
|
|
397
|
+
# return [0, 1], [0, 1]
|
|
398
|
+
# elif x == 3:
|
|
399
|
+
# return [0, 1], [1, 2]
|
|
400
|
+
# elif x > 3:
|
|
401
|
+
# return [0, 1, 2, 3], [0, 2, 1, 3]
|
|
402
|
+
|
|
403
|
+
def get_repartition_gpu(max_train_gpus=2):
    """Pick the CUDA device indices used for training and for DALI pipelines.

    Args:
        max_train_gpus (int): Upper bound on how many devices to use.

    Returns:
        tuple[list[int], list[int]]: ``(train_gpus, dali_gpus)``. Both hold
        the same first ``min(device_count, max_train_gpus)`` indices as
        separate lists; both are empty when no CUDA device is visible.
    """
    available = torch.cuda.device_count()
    if not available:
        return [], []

    used = min(available, max_train_gpus)
    train_ids = list(range(used))
    # DALI pipelines share the training devices, as an independent list.
    return train_ids, list(train_ids)
|
|
413
|
+
|
|
414
|
+
|
|
415
|
+
def check_config(cfg, split="train"):
    """Check for incoherences and missing elements in a config, filling derived fields.

    Checks are only applied when ``cfg.MODEL.runner.type == "runner_e2e"``;
    for any other runner the config is left untouched.

    Args:
        cfg: Attribute-style config, presumably an OmegaConf ``DictConfig``
            (``ListConfig`` is accepted for ``DATA.classes``). It is modified in
            place: ``TRAIN.repartitions`` (when DALI is on),
            ``TRAIN.start_valid_epoch``, ``DATA.crop_dim`` and ``DATA.classes``
            may be rewritten.
        split (str): Split whose ``path`` is inspected for class names:
            ``"train"``, ``"valid"`` or ``"test"``.

    Raises:
        AssertionError: If an option has an unsupported value, or related
            options are inconsistent.
        ValueError: If ``split`` is unknown, or the split's annotation JSON has
            an empty ``"labels"`` mapping.
    """
    from opensportslib.core.utils.config import load_json, load_classes
    from omegaconf import ListConfig

    if cfg.MODEL.runner.type != "runner_e2e":
        return

    if cfg.dali:
        cfg.TRAIN.repartitions = get_repartition_gpu(cfg.SYSTEM.GPU)

    assert cfg.DATA.modality in ["rgb"]
    assert cfg.MODEL.backbone.type in [
        # From torchvision
        "rn18",
        "rn18_tsm",
        "rn18_gsm",
        "rn50",
        "rn50_tsm",
        "rn50_gsm",
        # From timm (following its naming conventions)
        "rny002",
        "rny002_tsm",
        "rny002_gsm",
        "rny008",
        "rny008_tsm",
        "rny008_gsm",
        # From timm
        "convnextt",
        "convnextt_tsm",
        "convnextt_gsm",
    ]
    assert cfg.MODEL.head.type in ["", "gru", "deeper_gru", "mstcn", "asformer"]
    # Batch size must be a multiple of acc_grad_iter (presumably the number of
    # gradient-accumulation steps).
    assert cfg.DATA.train.dataloader.batch_size % cfg.TRAIN.acc_grad_iter == 0
    assert cfg.TRAIN.criterion_valid in ["map", "loss"]
    # The scheduler section must agree with the training section.
    assert cfg.TRAIN.num_epochs == cfg.TRAIN.scheduler.num_epochs
    assert cfg.TRAIN.acc_grad_iter == cfg.TRAIN.scheduler.acc_grad_iter

    if split == "train":
        data_path = cfg.DATA.train.path
    elif split == "valid":
        data_path = cfg.DATA.valid.path
    elif split == "test":
        data_path = cfg.DATA.test.path
    else:
        raise ValueError(f"Unknown split {split}")

    if cfg.TRAIN.start_valid_epoch is None:
        cfg.TRAIN.start_valid_epoch = (
            cfg.TRAIN.num_epochs - cfg.TRAIN.base_num_valid_epochs
        )
    if cfg.DATA.crop_dim is None or cfg.DATA.crop_dim <= 0:
        cfg.DATA.crop_dim = None

    # Parse the annotation JSON once (it used to be loaded twice).
    annotations = None
    if (
        data_path is not None
        and os.path.isfile(data_path)
        and data_path.endswith(".json")
    ):
        annotations = load_json(data_path)

    if annotations is not None and "labels" in annotations:
        # Use the first task's class list, consistent with
        # annotationstoe2eformat, which reports the first task name.
        first_task = next(iter(annotations["labels"].values()), None)
        if first_task is None:
            raise ValueError(f"No label tasks declared in {data_path}")
        classes = first_task.get("labels", {})
    else:
        assert isinstance(cfg.DATA.classes, (list, ListConfig))
        classes = cfg.DATA.classes

    cfg.DATA.classes = load_classes(classes)
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
# opensportslib/core/utils/seed.py
|
|
2
|
+
|
|
3
|
+
import random
|
|
4
|
+
import numpy as np
|
|
5
|
+
import torch
|
|
6
|
+
import os
|
|
7
|
+
|
|
8
|
+
def set_reproducibility(seed=42):
    """Make runs repeatable: seed all RNGs and request deterministic kernels.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (CPU plus the
    current and all CUDA devices). It also disables cuDNN autotuning, sets
    ``CUBLAS_WORKSPACE_CONFIG`` and ``PYTHONHASHSEED`` in ``os.environ``, and
    turns on ``torch.use_deterministic_algorithms`` in warn-only mode.

    Note: Python reads ``PYTHONHASHSEED`` at interpreter start-up, so setting
    it here only influences child processes launched afterwards.

    Args:
        seed (int): Seed value applied everywhere.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    env = os.environ
    env['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
    env['PYTHONHASHSEED'] = str(seed)

    torch.use_deterministic_algorithms(True, warn_only=True)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def seed_worker(worker_id):
    """Seed Python's and NumPy's RNGs from PyTorch's initial seed.

    Intended as a ``DataLoader`` ``worker_init_fn``. ``worker_id`` is part of
    that callback signature but is not used here.
    """
    derived_seed = torch.initial_seed() % (1 << 32)  # NumPy needs a 32-bit seed
    random.seed(derived_seed)
    np.random.seed(derived_seed)
|