opensportslib 0.0.1.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- opensportslib/__init__.py +18 -0
- opensportslib/apis/__init__.py +21 -0
- opensportslib/apis/classification.py +361 -0
- opensportslib/apis/localization.py +228 -0
- opensportslib/config/classification.yaml +104 -0
- opensportslib/config/classification_tracking.yaml +103 -0
- opensportslib/config/graph_tracking_classification/avgpool.yaml +79 -0
- opensportslib/config/graph_tracking_classification/gin.yaml +79 -0
- opensportslib/config/graph_tracking_classification/graphconv.yaml +79 -0
- opensportslib/config/graph_tracking_classification/graphsage.yaml +79 -0
- opensportslib/config/graph_tracking_classification/maxpool.yaml +79 -0
- opensportslib/config/graph_tracking_classification/noedges.yaml +79 -0
- opensportslib/config/localization.yaml +132 -0
- opensportslib/config/sngar_frames.yaml +98 -0
- opensportslib/core/__init__.py +0 -0
- opensportslib/core/loss/__init__.py +0 -0
- opensportslib/core/loss/builder.py +40 -0
- opensportslib/core/loss/calf.py +258 -0
- opensportslib/core/loss/ce.py +23 -0
- opensportslib/core/loss/combine.py +42 -0
- opensportslib/core/loss/nll.py +25 -0
- opensportslib/core/optimizer/__init__.py +0 -0
- opensportslib/core/optimizer/builder.py +38 -0
- opensportslib/core/sampler/weighted_sampler.py +104 -0
- opensportslib/core/scheduler/__init__.py +0 -0
- opensportslib/core/scheduler/builder.py +77 -0
- opensportslib/core/trainer/__init__.py +0 -0
- opensportslib/core/trainer/classification_trainer.py +1131 -0
- opensportslib/core/trainer/localization_trainer.py +1009 -0
- opensportslib/core/utils/checkpoint.py +238 -0
- opensportslib/core/utils/config.py +199 -0
- opensportslib/core/utils/data.py +85 -0
- opensportslib/core/utils/ddp.py +77 -0
- opensportslib/core/utils/default_args.py +110 -0
- opensportslib/core/utils/load_annotations.py +485 -0
- opensportslib/core/utils/seed.py +26 -0
- opensportslib/core/utils/video_processing.py +389 -0
- opensportslib/core/utils/wandb.py +110 -0
- opensportslib/datasets/__init__.py +0 -0
- opensportslib/datasets/builder.py +42 -0
- opensportslib/datasets/classification_dataset.py +582 -0
- opensportslib/datasets/localization_dataset.py +813 -0
- opensportslib/datasets/utils/__init__.py +15 -0
- opensportslib/datasets/utils/tracking.py +615 -0
- opensportslib/metrics/classification_metric.py +176 -0
- opensportslib/metrics/localization_metric.py +1482 -0
- opensportslib/models/__init__.py +0 -0
- opensportslib/models/backbones/builder.py +590 -0
- opensportslib/models/base/e2e.py +252 -0
- opensportslib/models/base/tracking.py +73 -0
- opensportslib/models/base/vars.py +29 -0
- opensportslib/models/base/video.py +130 -0
- opensportslib/models/base/video_mae.py +60 -0
- opensportslib/models/builder.py +43 -0
- opensportslib/models/heads/builder.py +266 -0
- opensportslib/models/neck/builder.py +210 -0
- opensportslib/models/utils/common.py +176 -0
- opensportslib/models/utils/impl/__init__.py +0 -0
- opensportslib/models/utils/impl/asformer.py +390 -0
- opensportslib/models/utils/impl/calf.py +74 -0
- opensportslib/models/utils/impl/gsm.py +112 -0
- opensportslib/models/utils/impl/gtad.py +347 -0
- opensportslib/models/utils/impl/tsm.py +123 -0
- opensportslib/models/utils/litebase.py +59 -0
- opensportslib/models/utils/modules.py +120 -0
- opensportslib/models/utils/shift.py +135 -0
- opensportslib/models/utils/utils.py +276 -0
- opensportslib-0.0.1.dev2.dist-info/METADATA +566 -0
- opensportslib-0.0.1.dev2.dist-info/RECORD +73 -0
- opensportslib-0.0.1.dev2.dist-info/WHEEL +5 -0
- opensportslib-0.0.1.dev2.dist-info/licenses/LICENSE +661 -0
- opensportslib-0.0.1.dev2.dist-info/licenses/LICENSE-COMMERCIAL +5 -0
- opensportslib-0.0.1.dev2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,238 @@
|
|
|
1
|
+
# opensportslib/core/utils/checkpoint.py
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
import os
|
|
5
|
+
|
|
6
|
+
def localization_remap(key):
    """Map a legacy localization checkpoint key onto the current module layout.

    Keys starting with ``_features`` are placed under ``backbone.`` and keys
    starting with ``_pred_fine`` under ``head.``; every other key is returned
    unchanged.
    """
    prefix_table = (("_features", "backbone."), ("_pred_fine", "head."))
    for legacy_start, new_prefix in prefix_table:
        if key.startswith(legacy_start):
            return new_prefix + key
    return key
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def save_checkpoint(model, path, processor=None, tokenizer=None, optimizer=None, epoch=None):
    """
    Save a model checkpoint to ``path``.

    Uses HuggingFace ``save_pretrained`` when the model provides it and the call
    succeeds; otherwise falls back to ``torch.save`` of a plain checkpoint dict.

    Args:
        model (torch.nn.Module or HF PreTrainedModel): model to save.
        path (str): destination. For the torch fallback this is the checkpoint
            file path (e.g. ``/.../checkpoint.pt``); ``save_pretrained`` is called
            with the same value.
        processor (optional): HF processor / feature extractor; saved only when
            the HF model save succeeded.
        tokenizer (optional): HF tokenizer; saved only when the HF model save
            succeeded.
        optimizer (torch.optim.Optimizer, optional): stored under
            ``"optimizer_state_dict"`` in the torch fallback.
        epoch (int, optional): stored under ``"epoch"`` in the torch fallback.
    """
    # os.path.dirname("ckpt.pt") == "" and os.makedirs("") raises, so only
    # create the parent directory when the path actually has one.
    parent_dir = os.path.dirname(path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    hf_saved = False

    # 1) Try to save HuggingFace model if available
    if hasattr(model, "save_pretrained"):
        try:
            model.save_pretrained(path)
            hf_saved = True
            print(f"[Checkpoint] HuggingFace model saved to {path}")
        except Exception as e:
            # Best effort: fall through to the torch checkpoint below.
            print(f"[Checkpoint] Warning: could not save HF model: {e}")
            hf_saved = False

    # 2) Save processor / tokenizer if provided (only if HF save succeeded)
    if hf_saved:
        if processor is not None:
            try:
                processor.save_pretrained(path)
                print(f"[Checkpoint] Processor saved to {path}")
            except Exception as e:
                print(f"[Checkpoint] Warning: could not save processor: {e}")

        if tokenizer is not None:
            try:
                tokenizer.save_pretrained(path)
                print(f"[Checkpoint] Tokenizer saved to {path}")
            except Exception as e:
                print(f"[Checkpoint] Warning: could not save tokenizer: {e}")

    # 3) Fallback: Save a PyTorch checkpoint if HF save is unavailable or failed
    if not hf_saved:
        checkpoint = {}
        if hasattr(model, "state_dict"):
            checkpoint["model_state_dict"] = model.state_dict()
        else:
            print("[Checkpoint] Warning: model has no state_dict(), skipping model_state_dict.")

        if optimizer is not None:
            checkpoint["optimizer_state_dict"] = optimizer.state_dict()
        if epoch is not None:
            checkpoint["epoch"] = epoch

        torch.save(checkpoint, path)
        print(f"[Checkpoint] Torch checkpoint saved at: {path}")
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def _resolve_checkpoint_file(path, hf_filename, hf_token):
    """Return a local filesystem path for the checkpoint referenced by ``path``.

    ``path`` is first interpreted as a local path (after ``~`` expansion). Only
    if nothing exists there is it treated as a HuggingFace Hub repo id, and a
    checkpoint file is downloaded from that repo.

    Raises:
        FileNotFoundError: if ``path`` is neither an existing local path nor a
            reachable HF repo containing a checkpoint-like file.
    """
    import sys

    local_path = os.path.abspath(os.path.expanduser(path))
    if os.path.exists(local_path):
        print(f"[Local] Using local checkpoint: {local_path}")
        return local_path

    try:
        from huggingface_hub import hf_hub_download, whoami, login, list_repo_files

        # Authenticate only when the Hub is actually needed, so loading a local
        # file never triggers a network call or an interactive login prompt.
        if hf_token is None:
            try:
                whoami()
            except Exception:
                if sys.stdin.isatty():
                    login()

        print(f"[HF] Inspecting repo: {path}")
        files = list_repo_files(path, token=hf_token)

        candidates = [
            f for f in files
            if f.endswith((".pt", ".pth", ".pth.tar", ".bin"))
        ]
        if not candidates:
            raise FileNotFoundError(
                f"No checkpoint file found in HF repo {path}. "
                f"Files: {files}"
            )

        # Honour the requested file when the repo has it; otherwise use the
        # first checkpoint-like file.
        filename = hf_filename if hf_filename in files else candidates[0]
        print(f"[HF] Using checkpoint file: {filename}")

        ckpt_path = hf_hub_download(
            repo_id=path,
            filename=filename,
            token=hf_token,
        )
        print(f"[HF] Loaded from cache: {ckpt_path}")
        return ckpt_path
    except Exception as hf_error:
        raise FileNotFoundError(
            f"Checkpoint not found on HuggingFace OR locally: {local_path}"
        ) from hf_error


def _extract_state_dict(checkpoint):
    """Return the model weights stored in a loaded checkpoint object.

    Raises:
        ValueError: if ``checkpoint`` is not a dict.
    """
    if not isinstance(checkpoint, dict):
        raise ValueError("Checkpoint format not recognized")
    if "model_state_dict" in checkpoint:
        return checkpoint["model_state_dict"]
    if "state_dict" in checkpoint:
        return checkpoint["state_dict"]
    # Bare state dict: keep only tensor entries (drops metadata such as "epoch").
    return {k: v for k, v in checkpoint.items() if isinstance(v, torch.Tensor)}


def _align_module_prefix(state_dict, model_keys):
    """Add or strip the ``module.`` key prefix so ``state_dict`` matches the model.

    The prefix is detected on the first key of each side only (as before); an
    empty state dict or model is handled without error.
    """
    ckpt_has_module = next(iter(state_dict), "").startswith("module.")
    model_has_module = next(iter(model_keys), "").startswith("module.")

    if ckpt_has_module and not model_has_module:
        return {k.replace("module.", "", 1): v for k, v in state_dict.items()}
    if model_has_module and not ckpt_has_module:
        return {f"module.{k}": v for k, v in state_dict.items()}
    return state_dict


def load_checkpoint(
    model,
    path,
    optimizer=None,
    scheduler=None,
    device=None,
    key_remap_fn=None,
    hf_filename="model.pth.tar",  # preferred file name when loading from an HF repo
    hf_token=None,  # optional (for private repos / non-interactive envs)
):
    """
    Load weights (and optionally optimizer/scheduler state) into ``model``.

    ``path`` may be:
      - a local .pt/.pth/.tar file (``~`` is expanded), or
      - a HuggingFace repo id, used only when no local file exists at ``path``.

    Auth behavior (HF only):
      - If logged in -> no token needed
      - If not logged in -> interactive ``login()`` prompt when stdin is a TTY
      - If non-interactive -> pass ``hf_token``

    Args:
        model (torch.nn.Module): receives the weights via
            ``load_state_dict(strict=False)`` and is then moved to ``device``.
        path (str): local checkpoint path or HF repo id.
        optimizer (optional): loaded from ``"optimizer"`` or
            ``"optimizer_state_dict"`` when present.
        scheduler (optional): loaded from ``"scheduler"`` or
            ``"scheduler_state_dict"`` when present.
        device (optional): ``map_location`` and target device; defaults to CUDA
            when available, else CPU.
        key_remap_fn (callable, optional): applied to every checkpoint key after
            the ``module.`` prefix alignment.
        hf_filename (str): file to fetch from an HF repo if the repo contains
            it; otherwise the first .pt/.pth/.pth.tar/.bin file is used.
        hf_token (str, optional): HF access token.

    Returns:
        tuple: ``(model, optimizer, scheduler, epoch)``; ``epoch`` is the
        checkpoint's ``"epoch"`` entry or None.

    Raises:
        FileNotFoundError: if the checkpoint cannot be found locally or on HF.
        ValueError: if the loaded checkpoint is not a dict.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    ckpt_path = _resolve_checkpoint_file(path, hf_filename, hf_token)

    # SECURITY NOTE(review): weights_only=False unpickles arbitrary objects, so
    # only load checkpoints from trusted sources (including Hub downloads).
    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)

    # ---------------- MODEL STATE ----------------
    state_dict = _extract_state_dict(checkpoint)
    state_dict = _align_module_prefix(state_dict, model.state_dict().keys())

    # Optional custom remap
    if key_remap_fn:
        state_dict = {key_remap_fn(k): v for k, v in state_dict.items()}

    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    print("\n--- MISSING KEYS ---")
    for k in missing[:20]:
        print(k)

    print("\n--- UNEXPECTED KEYS ---")
    for k in unexpected[:20]:
        print(k)
    model.to(device)

    # _extract_state_dict has already guaranteed that checkpoint is a dict.
    # ---------------- EPOCH ----------------
    epoch = checkpoint.get("epoch")

    # ---------------- OPTIMIZER ----------------
    if optimizer is not None:
        opt_state = checkpoint.get("optimizer") or checkpoint.get("optimizer_state_dict")
        if opt_state:
            optimizer.load_state_dict(opt_state)

    # ---------------- SCHEDULER ----------------
    if scheduler is not None:
        sch_state = checkpoint.get("scheduler") or checkpoint.get("scheduler_state_dict")
        if sch_state:
            scheduler.load_state_dict(sch_state)

    print(f"[Checkpoint] Loaded from {ckpt_path} | epoch: {epoch}")
    print(f"Missing keys: {len(missing)}")
    print(f"Unexpected keys: {len(unexpected)}")

    return model, optimizer, scheduler, epoch
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
def load_huggingface_checkpoint(config, path, device):
    """Load a VideoMAE checkpoint by delegating to ``load_video_mae_checkpoint``.

    The VideoMAE module is imported lazily, so it is only loaded when this
    function is actually called.
    """
    from opensportslib.models.base.video_mae import load_video_mae_checkpoint as _load_video_mae

    return _load_video_mae(config, ckpt_path=path, device=device)
|
|
238
|
+
|
|
@@ -0,0 +1,199 @@
|
|
|
1
|
+
|
|
2
|
+
import os
|
|
3
|
+
import re
|
|
4
|
+
import json
|
|
5
|
+
import gzip
|
|
6
|
+
import yaml
|
|
7
|
+
|
|
8
|
+
def dict_to_namespace(d, skip_keys=("classes",)):
    """
    Recursively turn nested dicts into ``SimpleNamespace`` objects.

    Lists are converted element by element; any other value is returned as is.
    Values stored under a key listed in ``skip_keys`` (``"classes"`` by
    default) are kept exactly as they are, without recursion.
    """
    from types import SimpleNamespace

    if isinstance(d, list):
        return [dict_to_namespace(item, skip_keys) for item in d]
    if not isinstance(d, dict):
        return d
    converted = {
        key: value if key in skip_keys else dict_to_namespace(value, skip_keys)
        for key, value in d.items()
    }
    return SimpleNamespace(**converted)
|
|
27
|
+
|
|
28
|
+
def namespace_to_dict(ns):
    """Shallow-convert a namespace into a dict.

    Only one level is unwrapped: an attribute value that has a ``__dict__`` is
    replaced by ``vars(value)``, which is the object's own attribute dict and
    not a copy. Anything nested deeper is left untouched.
    """
    result = {}
    for name, value in vars(ns).items():
        result[name] = vars(value) if hasattr(value, "__dict__") else value
    return result
|
|
30
|
+
|
|
31
|
+
def namespace_to_omegaconf(ns):
    """
    Recursively convert a SimpleNamespace (or dict/list tree of them) to an OmegaConf object.

    Namespaces and dicts become plain mappings, lists stay lists, and every
    other value is passed through; the resulting tree is given to
    ``OmegaConf.create``.
    """
    from omegaconf import OmegaConf
    from types import SimpleNamespace

    def _as_plain(node):
        if isinstance(node, list):
            return [_as_plain(item) for item in node]
        if isinstance(node, SimpleNamespace):
            node = vars(node)
        if isinstance(node, dict):
            return {key: _as_plain(val) for key, val in node.items()}
        return node

    return OmegaConf.create(_as_plain(ns))
|
|
49
|
+
|
|
50
|
+
def load_config(config_path):
    """
    Load a YAML or JSON configuration file and convert it to namespaces.

    Args:
        config_path (str or os.PathLike): file ending in ``.yaml``, ``.yml`` or
            ``.json``. The file is decoded as UTF-8.

    Returns:
        The parsed content passed through ``dict_to_namespace``.

    Raises:
        ValueError: if the extension is not one of the supported ones.
    """
    config_path = os.fspath(config_path)
    print(config_path)
    # Decode explicitly as UTF-8: JSON is UTF-8 by spec and the YAML configs are
    # assumed to be too, so the result must not depend on the platform locale.
    if config_path.endswith((".yaml", ".yml")):
        with open(config_path, "r", encoding="utf-8") as f:
            cfg_dict = yaml.safe_load(f)
    elif config_path.endswith(".json"):
        with open(config_path, "r", encoding="utf-8") as f:
            cfg_dict = json.load(f)
    else:
        raise ValueError("Unsupported config format. Use YAML or JSON.")
    return dict_to_namespace(cfg_dict)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def load_config_omega(path):
    """
    Load a config file with ``OmegaConf.load`` without resolving interpolations.

    The loaded object goes through ``dict_to_namespace``. That helper only
    converts plain ``dict``/``list`` instances, so an OmegaConf container is
    presumably returned unchanged (verify). This keeps it resolvable later by
    ``resolve_config_omega``.
    """
    from omegaconf import OmegaConf

    raw_cfg = OmegaConf.load(path)
    return dict_to_namespace(raw_cfg)
|
|
74
|
+
|
|
75
|
+
def resolve_config_omega(cfg):
    """
    Resolve OmegaConf interpolations and convert the result to namespaces.

    Any object that is not a ``DictConfig`` is returned untouched. A
    ``DictConfig`` has its interpolations resolved in place, which also mutates
    the caller's object. The resolved container is then converted with
    ``dict_to_namespace``.
    """
    from omegaconf import OmegaConf, DictConfig

    if isinstance(cfg, DictConfig):
        OmegaConf.resolve(cfg)
        return dict_to_namespace(OmegaConf.to_container(cfg, resolve=True))
    return cfg
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def expand(path):
    """Return ``path`` with a leading ``~`` expanded, made absolute and normalised."""
    home_expanded = os.path.expanduser(path)
    return os.path.abspath(home_expanded)
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def load_json(fpath):
    """Read and parse the JSON file at ``fpath``.

    The file is decoded as UTF-8 (the encoding required for JSON by RFC 8259)
    rather than with the platform's locale encoding.
    """
    with open(fpath, encoding="utf-8") as fp:
        return json.load(fp)
|
|
95
|
+
|
|
96
|
+
def load_gz_json(fpath):
    """Read and parse a gzip-compressed JSON file.

    The file is decoded as UTF-8, so any valid JSON text loads. ASCII-only
    files (such as those written with ``encoding="ascii"``) still work, because
    ASCII is a subset of UTF-8.
    """
    with gzip.open(fpath, "rt", encoding="utf-8") as fp:
        return json.load(fp)
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def store_json(fpath, obj, pretty=False):
    """Serialise ``obj`` as JSON into ``fpath``.

    With ``pretty=True`` the output is indented by 4 spaces. Keys are never
    sorted, so they are written in the dict's own order.
    """
    dump_kwargs = {"indent": 4, "sort_keys": False} if pretty else {}
    with open(fpath, "w") as fp:
        json.dump(obj, fp, **dump_kwargs)
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def store_gz_json(fpath, obj):
    """Serialise ``obj`` as JSON into a gzip-compressed, ASCII-encoded file."""
    with gzip.open(fpath, mode="wt", encoding="ascii") as out_stream:
        json.dump(obj, out_stream)
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def load_text(fpath):
    """Load the non-empty lines of a text file.

    Args:
        fpath (str): path of the file. It is decoded as UTF-8 rather than with
            the platform locale.

    Returns:
        list[str]: each line with surrounding whitespace stripped. Lines that
        are empty after stripping are skipped.
    """
    with open(fpath, "r", encoding="utf-8") as fp:
        stripped = (line.strip() for line in fp)
        return [line for line in stripped if line]
|
|
132
|
+
|
|
133
|
+
def load_classes(input):  # noqa: A002 - parameter name kept for callers using input=
    """Build a class-name -> index mapping, with indices starting at 1.

    Args:
        input: a sequence of class names (list, tuple, or an OmegaConf
            ``ListConfig``), or the path of a text file with one class per line
            (read with ``load_text``).

    Returns:
        dict: each class name mapped to its 1-based position.
    """
    sequence_types = (list, tuple)
    try:
        from omegaconf import ListConfig
    except ImportError:
        # Without omegaconf installed there can be no ListConfig inputs.
        pass
    else:
        sequence_types += (ListConfig,)

    names = input if isinstance(input, sequence_types) else load_text(input)
    return {name: idx for idx, name in enumerate(names, start=1)}
|
|
146
|
+
|
|
147
|
+
def clear_files(dir_name, re_str, exclude=None):
    """Delete the entries of ``dir_name`` whose names match the regex ``re_str``.

    Args:
        dir_name (str): directory to scan (not recursive).
        re_str (str or re.Pattern): pattern applied with ``match``, so it is
            anchored at the start of the file name.
        exclude (container of str, optional): file names to keep even when they
            match. Defaults to keeping nothing.
    """
    # None instead of a shared mutable [] default.
    keep = exclude if exclude is not None else ()
    # Compile once instead of looking the pattern up for every file.
    pattern = re.compile(re_str)
    for file_name in os.listdir(dir_name):
        if pattern.match(file_name) and file_name not in keep:
            os.remove(os.path.join(dir_name, file_name))
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def _print_info_helper(src_file, labels):
    """Print summary statistics about the videos listed in a label file.

    Args:
        src_file (str): name of the source file; used only in the message.
        labels (list[dict]): one dict per video, each with a ``"num_frames"``
            count and an ``"events"`` list.
    """
    num_frames = sum(x["num_frames"] for x in labels)
    num_events = sum(len(x["events"]) for x in labels)
    # An empty label list (or zero frames) used to raise ZeroDivisionError.
    non_bg_pct = num_events / num_frames * 100 if num_frames else 0.0
    print(
        "{} : {} videos, {} frames, {:0.5f}% non-bg".format(
            src_file, len(labels), num_frames, non_bg_pct
        )
    )
|
|
169
|
+
|
|
170
|
+
def select_device(config):
    """Pick the torch device described by ``config.device``.

    ``config.device`` is case-insensitive and may be:
      - ``"auto"``: CUDA if available, else CPU;
      - ``"cuda"``: GPU ``config.gpu_id`` (default 0), which is also made the
        current CUDA device;
      - ``"cpu"``.

    Raises:
        AssertionError: if ``"cuda"`` is requested but CUDA is unavailable.
        ValueError: for any other mode string.
    """
    import torch

    mode = config.device.lower()

    if mode == "auto":
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    elif mode == "cuda":
        # NOTE(review): asserts are stripped under `python -O`; kept so the
        # exception type callers may rely on stays AssertionError.
        assert torch.cuda.is_available(), "CUDA requested but not available"
        gpu_id = getattr(config, "gpu_id", 0)
        torch.cuda.set_device(gpu_id)
        device = torch.device(f"cuda:{gpu_id}")
    elif mode == "cpu":
        device = torch.device("cpu")
    else:
        raise ValueError(f"Unknown device mode: {mode}")

    print(f"Using device: {device}")
    # "auto" has already been mapped to cuda/cpu above, so a torch.device can
    # never have type "auto"; only the "cuda" check is meaningful.
    if device.type == "cuda":
        print(f"GPU: {torch.cuda.get_device_name(device)}")

    return device
|
|
194
|
+
|
|
195
|
+
def is_local_path(p):
    """Return True if ``p`` looks like a local checkpoint path.

    ``p`` counts as local when it exists on disk or ends in ``.pt``, ``.pth``
    or ``.tar``. Falsy inputs (``None``, ``""``) give ``False``; previously
    the falsy input itself was returned.
    """
    if not p:
        return False
    return os.path.exists(p) or p.endswith((".pt", ".pth", ".tar"))
|
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
from collections import defaultdict
|
|
2
|
+
from torch.utils.data import Subset
|
|
3
|
+
import torch
|
|
4
|
+
import numpy as np
|
|
5
|
+
|
|
6
|
+
def balanced_subset(dataset, samples_per_class=5):
    """Return a ``Subset`` with at most ``samples_per_class`` items per label.

    Labels are read from ``dataset.samples[idx]["label"]``. For each label the
    earliest indices are kept, and labels appear in order of first occurrence.
    The discovered labels and the chosen indices are printed.
    """
    indices_by_label = defaultdict(list)
    for position in range(len(dataset)):
        indices_by_label[dataset.samples[position]["label"]].append(position)

    print(indices_by_label.keys())
    chosen = [
        idx
        for label_indices in indices_by_label.values()
        for idx in label_indices[:samples_per_class]
    ]
    print(chosen)
    return Subset(dataset, chosen)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def batch_tensor(tensor, dim=1, squeeze=False):
    """
    Fold dimension ``dim`` of ``tensor`` into the batch dimension 0 so its slices can be processed in parallel.

    The result has ``shape[0] * shape[dim]`` rows and size 1 at ``dim``. With
    ``squeeze=True`` that singleton dimension is removed. ``unbatch_tensor()``
    is the reverse operation.
    """
    new_shape = list(tensor.shape)
    new_shape[0] = tensor.shape[0] * tensor.shape[dim]
    new_shape[dim] = 1
    folded = tensor.transpose(0, dim).reshape(new_shape)
    return folded.squeeze_(dim) if squeeze else folded
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def unbatch_tensor(tensor, batch_size, dim=1, unsqueeze=False):
    """
    Reverse ``batch_tensor()``: split ``tensor`` along dim 0 into chunks of ``batch_size`` rows and concatenate them along ``dim``.

    Args:
        tensor: tensor whose first dimension is presumably a multiple of
            ``batch_size``, e.g. the output of ``batch_tensor``.
        batch_size (int): the real batch size.
        dim (int): dimension along which the chunks are concatenated.
        unsqueeze (bool): insert a new dimension at ``dim`` before unbatching
            (presumably for input produced with ``squeeze=True``).

    Returns:
        A contiguous tensor. The input tensor is left unmodified.
    """
    num_chunks = tensor.shape[0] // batch_size
    if unsqueeze:
        # Out-of-place: the previous unsqueeze_ silently changed the caller's tensor shape.
        tensor = tensor.unsqueeze(dim)
    return torch.cat(torch.chunk(tensor, num_chunks, dim=0), dim=dim).contiguous()
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def tracking_collate_fn(batch):
    """
    Collate tracking samples into one PyG batch.

    All graphs of all samples (samples in order, graphs in their stored order)
    are merged with PyG's ``Batch.from_data_list``, which offsets node ids in
    ``edge_index`` per graph. ``seq_len`` is taken from the first sample.
    """
    from torch_geometric.data import Batch

    n_samples = len(batch)
    first_seq_len = batch[0]['seq_len']

    flat_graphs = [graph for sample in batch for graph in sample['graphs']]
    merged = Batch.from_data_list(flat_graphs)

    return {
        'x': merged.x,
        'edge_index': merged.edge_index,
        'batch': merged.batch,
        'batch_size': n_samples,
        'seq_len': first_seq_len,
        'labels': torch.tensor([sample['label'] for sample in batch], dtype=torch.long),
        'id': [sample['id'] for sample in batch],
    }
|
|
78
|
+
|
|
79
|
+
def mixup_data(x, y, alpha=0.2):
    """Mixup augmentation: blend each sample with a randomly permuted partner.

    Returns ``(mixed_x, y_a, y_b, lam)``, where:
      - ``mixed_x = lam * x + (1 - lam) * x[perm]``;
      - ``y_a`` is ``y`` itself and ``y_b`` is ``y[perm]``;
      - ``lam`` is drawn from Beta(alpha, alpha), or is exactly 1.0 when ``alpha <= 0``.
    """
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1.0
    perm = torch.randperm(x.size(0)).to(x.device)
    mixed_x = lam * x + (1 - lam) * x[perm]
    return mixed_x, y, y[perm], lam
|
|
85
|
+
|
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
import torch.distributed as dist
|
|
2
|
+
def ddp_setup(rank, world_size, backend="nccl"):
    """Initialise the default ``torch.distributed`` process group.

    ``MASTER_ADDR`` and ``MASTER_PORT`` default to ``localhost`` and ``12355``.
    Values already present in the environment (e.g. set by a launcher) are
    kept, so an externally chosen rendezvous address is not overridden.

    Args:
        rank (int): rank of this process.
        world_size (int): total number of processes.
        backend (str): process-group backend. Defaults to ``"nccl"`` (GPU);
            use ``"gloo"`` for CPU-only runs.
    """
    import os

    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "12355")  # any free port
    dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
|
|
7
|
+
|
|
8
|
+
def ddp_cleanup():
    """Destroy the default process group if one has been initialised.

    This is a no-op when no process group exists, so it can be called
    unconditionally, for example from a ``finally`` block.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
import torch
|
|
13
|
+
from torch.utils.data import Sampler
|
|
14
|
+
import math
|
|
15
|
+
|
|
16
|
+
class DistributedWeightedSampler(Sampler):
    """Weighted random sampler that shards a dataset across distributed ranks.

    The weights are padded by cycling them until their length is a multiple
    of ``num_replicas``. Rank ``r`` owns padded positions
    ``r, r + num_replicas, ...`` and draws ``num_samples`` of them with
    ``torch.multinomial``, in proportion to their weights. Padded positions are
    mapped back to real dataset indices modulo the dataset size.

    Args:
        weights (sequence of float): one non-negative weight per dataset item.
        num_replicas (int, optional): world size; defaults to
            ``dist.get_world_size()``.
        rank (int, optional): this process's rank; defaults to ``dist.get_rank()``.
        replacement (bool): whether to sample with replacement.
        num_samples (int, optional): total samples across all ranks; each rank
            draws ``num_samples // num_replicas``. Defaults to the rank's shard size.
        seed (int): base seed; each epoch uses ``seed + epoch``.
    """

    def __init__(
        self,
        weights,
        num_replicas=None,
        rank=None,
        replacement=True,
        num_samples=None,
        seed=0
    ):
        if num_replicas is None:
            num_replicas = dist.get_world_size()
        if rank is None:
            rank = dist.get_rank()

        self.num_replicas = num_replicas
        self.rank = rank
        self.replacement = replacement
        self.seed = seed

        self.weights = torch.as_tensor(weights, dtype=torch.double)
        self.dataset_size = len(self.weights)

        # Split dataset across ranks
        self.num_samples_per_rank = math.ceil(self.dataset_size / self.num_replicas)
        self.total_size = self.num_samples_per_rank * self.num_replicas

        # Pad by cycling the weights. A single slice weights[:padding] is too
        # short when padding > dataset_size (more replicas than items), which
        # left some ranks indexing past the end of the weight tensor.
        padding = self.total_size - self.dataset_size
        if padding > 0:
            repeats = math.ceil(padding / self.dataset_size)
            self.weights = torch.cat([self.weights, self.weights.repeat(repeats)[:padding]])

        # Indices for this rank (positions in the padded weight tensor)
        self.rank_indices = list(range(self.rank, self.total_size, self.num_replicas))
        self.rank_weights = self.weights[self.rank_indices]

        # Number of samples to draw
        if num_samples is None:
            self.num_samples = len(self.rank_indices)
        else:
            self.num_samples = num_samples // self.num_replicas

        self.epoch = 0

    def set_epoch(self, epoch):
        """Set the epoch so every rank derives the same, epoch-specific seed."""
        self.epoch = epoch

    def __iter__(self):
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)

        sampled = torch.multinomial(
            self.rank_weights,
            self.num_samples,
            self.replacement,
            generator=g
        )

        # Map padded positions back onto real dataset indices (plain ints).
        return (self.rank_indices[i] % self.dataset_size for i in sampled.tolist())

    def __len__(self):
        return self.num_samples
|
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
def get_default_args_data_train_e2e_dali(cfg):
    """Dataset kwargs for the DALI end-to-end training split."""
    classes = cfg.DATA.classes
    train_cfg = cfg.TRAIN
    return dict(
        classes=classes,
        train=True,
        acc_grad_iter=train_cfg.acc_grad_iter,
        num_epochs=train_cfg.num_epochs,
        repartitions=train_cfg.repartitions,
    )
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def get_default_args_data_valid_e2e_dali(cfg):
    """Dataset kwargs for the DALI end-to-end validation split."""
    classes = cfg.DATA.classes
    train_cfg = cfg.TRAIN
    return dict(
        classes=classes,
        train=False,
        acc_grad_iter=train_cfg.acc_grad_iter,
        num_epochs=train_cfg.num_epochs,
        repartitions=train_cfg.repartitions,
    )
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def get_default_args_data_train_e2e_opencv(cfg):
    """Dataset kwargs for the OpenCV end-to-end training split."""
    classes = cfg.DATA.classes
    return dict(classes=classes, train=True)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def get_default_args_data_valid_e2e_opencv(cfg):
    """Dataset kwargs for the OpenCV end-to-end validation split."""
    classes = cfg.DATA.classes
    return dict(classes=classes, train=False)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def get_default_args_data_train():
    """Dataset kwargs for a non end-to-end training split (a fresh dict per call)."""
    return dict(train=True)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def get_default_args_data_valid():
    """Dataset kwargs for a non end-to-end validation split (a fresh dict per call)."""
    return dict(train=False)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def get_default_args_data_valid_data_frames_e2e_dali(cfg):
    """Dataset kwargs for DALI end-to-end frame-level evaluation splits."""
    classes = cfg.DATA.classes
    repartitions = cfg.TRAIN.repartitions
    return dict(classes=classes, repartitions=repartitions)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def get_default_args_data_valid_data_frames_e2e_opencv(cfg):
    """Dataset kwargs for OpenCV end-to-end frame-level evaluation splits."""
    classes = cfg.DATA.classes
    return dict(classes=classes)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def get_default_args_dataset(split, cfg):
    """Return the default dataset-builder kwargs for ``split``.

    ``"train"`` and ``"valid"`` get train/valid kwargs. ``"valid_data_frames"``,
    ``"test"`` and ``"challenge"`` get frame-level kwargs, but only for the
    ``runner_e2e`` runner; other runners get None for these splits. Inside the
    e2e runner, ``cfg.dali`` (treated as False when absent) selects the DALI or
    OpenCV variant. Unknown splits give None without reading ``cfg``.
    """
    if split not in ("train", "valid", "valid_data_frames", "test", "challenge"):
        return None

    is_e2e = cfg.MODEL.runner.type == "runner_e2e"

    if split == "train":
        if not is_e2e:
            return get_default_args_data_train()
        if getattr(cfg, "dali", False):
            return get_default_args_data_train_e2e_dali(cfg)
        return get_default_args_data_train_e2e_opencv(cfg)

    if split == "valid":
        if not is_e2e:
            return get_default_args_data_valid()
        if getattr(cfg, "dali", False):
            return get_default_args_data_valid_e2e_dali(cfg)
        return get_default_args_data_valid_e2e_opencv(cfg)

    # Remaining splits: valid_data_frames / test / challenge.
    if not is_e2e:
        return None
    if getattr(cfg, "dali", False):
        return get_default_args_data_valid_data_frames_e2e_dali(cfg)
    return get_default_args_data_valid_data_frames_e2e_opencv(cfg)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def get_default_args_model(cfg):
    """Model-builder kwargs: the class list for ``E2E`` models, otherwise None."""
    if cfg.MODEL.type != "E2E":
        return None
    return {"classes": cfg.DATA.classes}
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def get_default_args_trainer(cfg, len_train_loader):
    """Build trainer kwargs from the config.

    For ``cfg.TRAIN.type == "trainer_e2e"`` the kwargs hold the loader length,
    work dir, DALI flag, repartitions (only when DALI is used) and the
    test / valid_data_frames data configs. ``cfg.dali`` is optional and treated
    as False when absent, consistent with the ``getattr(cfg, "dali", False)``
    reads used for dataset kwargs. Any other trainer type only gets ``work_dir``.
    """
    if cfg.TRAIN.type != "trainer_e2e":
        return {"work_dir": cfg.SYSTEM.work_dir}

    use_dali = getattr(cfg, "dali", False)
    return {
        "len_train_loader": len_train_loader,
        "work_dir": cfg.SYSTEM.work_dir,
        "dali": use_dali,
        "repartitions": cfg.TRAIN.repartitions if use_dali else None,
        "cfg_test": cfg.DATA.test,
        # "cfg_challenge" (cfg.DATA.challenge) is intentionally not passed.
        "cfg_valid_data_frames": cfg.DATA.valid_data_frames,
    }
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def get_default_args_train(model, train_loader, valid_loader, classes, trainer_type):
    """Return the keyword arguments for the training call of ``trainer_type``.

    - ``trainer_CALF`` / ``trainer_pooling``: ``model``, ``train_dataloaders``
      and ``val_dataloaders``.
    - ``trainer_e2e``: ``train_loader``, ``valid_loader`` and ``classes``.
    - Any other trainer type: None.
    """
    if trainer_type in ("trainer_CALF", "trainer_pooling"):
        return dict(
            model=model,
            train_dataloaders=train_loader,
            val_dataloaders=valid_loader,
        )
    if trainer_type == "trainer_e2e":
        return dict(
            train_loader=train_loader,
            valid_loader=valid_loader,
            classes=classes,
        )
    return None
|