opensportslib 0.0.1.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- opensportslib/__init__.py +18 -0
- opensportslib/apis/__init__.py +21 -0
- opensportslib/apis/classification.py +361 -0
- opensportslib/apis/localization.py +228 -0
- opensportslib/config/classification.yaml +104 -0
- opensportslib/config/classification_tracking.yaml +103 -0
- opensportslib/config/graph_tracking_classification/avgpool.yaml +79 -0
- opensportslib/config/graph_tracking_classification/gin.yaml +79 -0
- opensportslib/config/graph_tracking_classification/graphconv.yaml +79 -0
- opensportslib/config/graph_tracking_classification/graphsage.yaml +79 -0
- opensportslib/config/graph_tracking_classification/maxpool.yaml +79 -0
- opensportslib/config/graph_tracking_classification/noedges.yaml +79 -0
- opensportslib/config/localization.yaml +132 -0
- opensportslib/config/sngar_frames.yaml +98 -0
- opensportslib/core/__init__.py +0 -0
- opensportslib/core/loss/__init__.py +0 -0
- opensportslib/core/loss/builder.py +40 -0
- opensportslib/core/loss/calf.py +258 -0
- opensportslib/core/loss/ce.py +23 -0
- opensportslib/core/loss/combine.py +42 -0
- opensportslib/core/loss/nll.py +25 -0
- opensportslib/core/optimizer/__init__.py +0 -0
- opensportslib/core/optimizer/builder.py +38 -0
- opensportslib/core/sampler/weighted_sampler.py +104 -0
- opensportslib/core/scheduler/__init__.py +0 -0
- opensportslib/core/scheduler/builder.py +77 -0
- opensportslib/core/trainer/__init__.py +0 -0
- opensportslib/core/trainer/classification_trainer.py +1131 -0
- opensportslib/core/trainer/localization_trainer.py +1009 -0
- opensportslib/core/utils/checkpoint.py +238 -0
- opensportslib/core/utils/config.py +199 -0
- opensportslib/core/utils/data.py +85 -0
- opensportslib/core/utils/ddp.py +77 -0
- opensportslib/core/utils/default_args.py +110 -0
- opensportslib/core/utils/load_annotations.py +485 -0
- opensportslib/core/utils/seed.py +26 -0
- opensportslib/core/utils/video_processing.py +389 -0
- opensportslib/core/utils/wandb.py +110 -0
- opensportslib/datasets/__init__.py +0 -0
- opensportslib/datasets/builder.py +42 -0
- opensportslib/datasets/classification_dataset.py +582 -0
- opensportslib/datasets/localization_dataset.py +813 -0
- opensportslib/datasets/utils/__init__.py +15 -0
- opensportslib/datasets/utils/tracking.py +615 -0
- opensportslib/metrics/classification_metric.py +176 -0
- opensportslib/metrics/localization_metric.py +1482 -0
- opensportslib/models/__init__.py +0 -0
- opensportslib/models/backbones/builder.py +590 -0
- opensportslib/models/base/e2e.py +252 -0
- opensportslib/models/base/tracking.py +73 -0
- opensportslib/models/base/vars.py +29 -0
- opensportslib/models/base/video.py +130 -0
- opensportslib/models/base/video_mae.py +60 -0
- opensportslib/models/builder.py +43 -0
- opensportslib/models/heads/builder.py +266 -0
- opensportslib/models/neck/builder.py +210 -0
- opensportslib/models/utils/common.py +176 -0
- opensportslib/models/utils/impl/__init__.py +0 -0
- opensportslib/models/utils/impl/asformer.py +390 -0
- opensportslib/models/utils/impl/calf.py +74 -0
- opensportslib/models/utils/impl/gsm.py +112 -0
- opensportslib/models/utils/impl/gtad.py +347 -0
- opensportslib/models/utils/impl/tsm.py +123 -0
- opensportslib/models/utils/litebase.py +59 -0
- opensportslib/models/utils/modules.py +120 -0
- opensportslib/models/utils/shift.py +135 -0
- opensportslib/models/utils/utils.py +276 -0
- opensportslib-0.0.1.dev2.dist-info/METADATA +566 -0
- opensportslib-0.0.1.dev2.dist-info/RECORD +73 -0
- opensportslib-0.0.1.dev2.dist-info/WHEEL +5 -0
- opensportslib-0.0.1.dev2.dist-info/licenses/LICENSE +661 -0
- opensportslib-0.0.1.dev2.dist-info/licenses/LICENSE-COMMERCIAL +5 -0
- opensportslib-0.0.1.dev2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
import importlib

# Map of lazily-exposed attribute names to the submodules that back them.
# Resolving them on demand keeps `import opensportslib` cheap: the heavy
# dependencies pulled in by each subpackage only load when first touched.
_LAZY_SUBMODULES = {
    "model": "opensportslib.apis",
    "metrics": "opensportslib.metrics",
    "datasets": "opensportslib.datasets",
    "core": "opensportslib.core",
}


def __getattr__(name):
    """Return the lazily-imported submodule registered under *name*.

    Implements PEP 562 module-level attribute access: the submodule is
    imported on first access and cached in sys.modules thereafter.

    Raises:
        AttributeError: if *name* is not one of the lazy attributes.
    """
    if name in _LAZY_SUBMODULES:
        return importlib.import_module(_LAZY_SUBMODULES[name])
    raise AttributeError(f"module 'opensportslib' has no attribute '{name}'")


__all__ = ["model", "metrics", "datasets", "core"]
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
# opensportslib/apis/__init__.py

"""User-facing factory functions for the task APIs."""

import warnings

# Import task APIs
from opensportslib.apis.classification import ClassificationAPI
from opensportslib.apis.localization import LocalizationAPI

# NOTE(review): this silences *every* warning in any process that imports
# the package, including the user's own code — consider narrowing to the
# specific categories this library actually needs to suppress.
warnings.filterwarnings("ignore")


# Factory functions for user-facing calls
def classification(config=None, data_dir=None, save_dir=None):
    """Build a :class:`ClassificationAPI` instance.

    Args:
        config: path to the YAML configuration file (required by the API).
        data_dir: optional override for the dataset root directory.
        save_dir: optional override for the checkpoint output directory.
    """
    return ClassificationAPI(config=config, data_dir=data_dir, save_dir=save_dir)


def localization(config=None, data_dir=None, save_dir=None):
    """Build a :class:`LocalizationAPI` instance.

    Args:
        config: path to the YAML configuration file (required by the API).
        data_dir: optional override for the dataset root directory.
        save_dir: optional override for the checkpoint output directory.
    """
    return LocalizationAPI(config=config, data_dir=data_dir, save_dir=save_dir)


# Expose only these
__all__ = [
    "classification",
    "localization",
]
|
|
@@ -0,0 +1,361 @@
|
|
|
1
|
+
# opensportslib/apis/classification.py
|
|
2
|
+
|
|
3
|
+
"""public API for classification tasks.
|
|
4
|
+
|
|
5
|
+
supports three dataset/task combinations:
|
|
6
|
+
- MVFoul (video-based foul classification)
|
|
7
|
+
- SN-GAR with video modality
|
|
8
|
+
- SN-GAR with tracking modality
|
|
9
|
+
|
|
10
|
+
handles single-GPU and multi-GPU (DDP) training and inference,
|
|
11
|
+
delegating heavy lifting to Trainer_Classification.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
import os
|
|
15
|
+
import logging
|
|
16
|
+
from opensportslib.core.utils.config import expand
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class ClassificationAPI:
    """Top-level entry point for classification training and inference.

    Loads a YAML config, resolves paths, and exposes train() /
    infer() methods that transparently handle single-GPU and
    DDP execution.

    Args:
        config: path to the YAML configuration file.
        data_dir: override for DATA.data_dir in the config.
            If None, the value from the config is used.
        save_dir: override for the checkpoint output directory.
            Falls back to SYSTEM.save_dir, then "./checkpoints".

    Raises:
        ValueError: if no config path is supplied.
    """

    def __init__(self, config=None, data_dir=None, save_dir=None):
        from opensportslib.core.utils.config import (
            load_config_omega
        )
        import uuid

        if config is None:
            raise ValueError("config path is required")

        config_path = expand(config)
        self.config = load_config_omega(config_path)

        # let the caller override the dataset root directory.
        self.config.DATA.data_dir = expand(
            data_dir or self.config.DATA.data_dir
        )

        # a stable run id shared with any spawned DDP workers via the
        # environment (workers re-read RUN_ID when they init wandb).
        self.run_id = os.environ.get("RUN_ID") or str(uuid.uuid4())[:8]
        os.environ["RUN_ID"] = self.run_id

        # checkpoint output directory; never derived from BASE_DIR so the
        # user always has explicit control over where weights are written.
        self.save_dir = expand(
            save_dir or self.config.SYSTEM.save_dir or "./checkpoints"
        )
        save_filename = os.path.join(self.config.MODEL.backbone.type, self.run_id)
        self.config.SYSTEM.save_dir = os.path.join(self.save_dir, save_filename)
        os.makedirs(self.config.SYSTEM.save_dir, exist_ok=True)

        # DDP rank; used here only to restrict logging to the main process.
        rank = int(os.environ.get("RANK", 0))
        self.trainer = None
        self.best_checkpoint = None

        # BUG FIX: the log file must live inside the directory we actually
        # create. The original created save_dir/log_dir but handed the bare
        # (possibly non-existent) log_dir to FileHandler, which raises
        # FileNotFoundError on the first log write.
        log_dir = expand(self.config.SYSTEM.log_dir or "./log_dir")
        log_path = os.path.join(self.config.SYSTEM.save_dir, log_dir)
        os.makedirs(log_path, exist_ok=True)
        logging.basicConfig(
            level=logging.INFO,
            format="%(asctime)s | %(levelname)s | %(message)s",
            handlers=[
                logging.FileHandler(os.path.join(log_path, "train.log")),
                logging.StreamHandler(),
            ],
            force=True,
        )
        if rank == 0:
            logging.info(f"DATA DIR : {self.config.DATA.data_dir}")
            logging.info(f"MODEL SAVEDIR: {self.config.SYSTEM.save_dir}")
83
|
+
# -----------------------------------------------------------------
|
|
84
|
+
# internal DDP worker
|
|
85
|
+
# -----------------------------------------------------------------
|
|
86
|
+
@staticmethod
|
|
87
|
+
def _worker_ddp(
|
|
88
|
+
rank,
|
|
89
|
+
world_size,
|
|
90
|
+
mode,
|
|
91
|
+
config,
|
|
92
|
+
return_queue=None,
|
|
93
|
+
train_set=None,
|
|
94
|
+
valid_set=None,
|
|
95
|
+
test_set=None,
|
|
96
|
+
pretrained=None,
|
|
97
|
+
use_wandb=False
|
|
98
|
+
):
|
|
99
|
+
"""execute a single training or inference job on one GPU.
|
|
100
|
+
|
|
101
|
+
spawned once per GPU by train() / infer(). Each process gets
|
|
102
|
+
its own Trainer_Classification instance so that no mutable
|
|
103
|
+
state is shared across ranks.
|
|
104
|
+
|
|
105
|
+
Args:
|
|
106
|
+
rank: GPU rank (0-indexed).
|
|
107
|
+
world_size: total number of participating GPUs.
|
|
108
|
+
mode: "train" or "infer".
|
|
109
|
+
return_queue: multiprocessing.Queue used by rank-0 to
|
|
110
|
+
return metrics to the calling process (inference only).
|
|
111
|
+
train_set: path to the training set annotations file.
|
|
112
|
+
valid_set: path to the validation set annotations file.
|
|
113
|
+
test_set: path to the test set annotations file.
|
|
114
|
+
pretrained: path to a saved checkpoint for warm-starting
|
|
115
|
+
or inference.
|
|
116
|
+
"""
|
|
117
|
+
import torch
|
|
118
|
+
from opensportslib.core.trainer.classification_trainer import Trainer_Classification
|
|
119
|
+
from opensportslib.core.utils.ddp import ddp_setup, ddp_cleanup
|
|
120
|
+
from opensportslib.core.utils.wandb import init_wandb
|
|
121
|
+
from opensportslib.core.utils.seed import set_reproducibility
|
|
122
|
+
from opensportslib.datasets.builder import build_dataset
|
|
123
|
+
from opensportslib.models.builder import build_model
|
|
124
|
+
import logging
|
|
125
|
+
|
|
126
|
+
# configure logging for each spawned process
|
|
127
|
+
logging.basicConfig(
|
|
128
|
+
level=logging.INFO,
|
|
129
|
+
format=f"[RANK {rank}] %(asctime)s | %(levelname)s | %(message)s",
|
|
130
|
+
force=True,
|
|
131
|
+
)
|
|
132
|
+
# silence non-rank0 processes
|
|
133
|
+
if rank != 0:
|
|
134
|
+
logging.getLogger().setLevel(logging.ERROR)
|
|
135
|
+
|
|
136
|
+
if rank == 0:
|
|
137
|
+
init_wandb(config, run_id=os.environ["RUN_ID"], use_wandb=use_wandb)
|
|
138
|
+
|
|
139
|
+
# reproducibility:
|
|
140
|
+
# we default to reproducible training, but allow the user to
|
|
141
|
+
# disable this via SYSTEM.use_seed=False in the config.
|
|
142
|
+
if getattr(config.SYSTEM, "use_seed", False):
|
|
143
|
+
set_reproducibility(config.SYSTEM.seed)
|
|
144
|
+
|
|
145
|
+
is_ddp = world_size > 1
|
|
146
|
+
if is_ddp:
|
|
147
|
+
torch.cuda.set_device(rank)
|
|
148
|
+
ddp_setup(rank, world_size)
|
|
149
|
+
device = torch.device(f"cuda:{rank}")
|
|
150
|
+
else:
|
|
151
|
+
from opensportslib.core.utils.config import select_device
|
|
152
|
+
device = select_device(config.SYSTEM)
|
|
153
|
+
|
|
154
|
+
# each process creates a fresh trainer to avoid shared mutable state.
|
|
155
|
+
trainer = Trainer_Classification(config)
|
|
156
|
+
trainer.device = device
|
|
157
|
+
|
|
158
|
+
# build or restore the model.
|
|
159
|
+
if pretrained:
|
|
160
|
+
model, processor, scheduler, epoch = trainer.load(pretrained)
|
|
161
|
+
else:
|
|
162
|
+
model, processor = build_model(config, device)
|
|
163
|
+
|
|
164
|
+
trainer.model = model
|
|
165
|
+
|
|
166
|
+
if mode == "train":
|
|
167
|
+
train_data = build_dataset(
|
|
168
|
+
config, train_set, processor, split="train"
|
|
169
|
+
)
|
|
170
|
+
valid_data = build_dataset(
|
|
171
|
+
config, valid_set, processor, split="valid"
|
|
172
|
+
)
|
|
173
|
+
best_ckpt = trainer.train(
|
|
174
|
+
model, train_data, valid_data,
|
|
175
|
+
rank=rank, world_size=world_size
|
|
176
|
+
)
|
|
177
|
+
# SEND BACK CHECKPOINT FROM RANK 0
|
|
178
|
+
if rank == 0 and return_queue is not None:
|
|
179
|
+
best_ckpt = getattr(trainer.trainer, "best_checkpoint_path", None)
|
|
180
|
+
return_queue.put(best_ckpt)
|
|
181
|
+
|
|
182
|
+
elif mode == "infer":
|
|
183
|
+
test_data = build_dataset(
|
|
184
|
+
config, test_set, processor, split="test"
|
|
185
|
+
)
|
|
186
|
+
|
|
187
|
+
metrics = trainer.infer(
|
|
188
|
+
test_data, rank=rank, world_size=world_size
|
|
189
|
+
)
|
|
190
|
+
|
|
191
|
+
if rank == 0 and return_queue is not None:
|
|
192
|
+
return_queue.put(metrics)
|
|
193
|
+
|
|
194
|
+
if is_ddp:
|
|
195
|
+
ddp_cleanup()
|
|
196
|
+
|
|
197
|
+
# -----------------------------------------------------------------
|
|
198
|
+
# public training interface
|
|
199
|
+
# -----------------------------------------------------------------
|
|
200
|
+
|
|
201
|
+
def train(
|
|
202
|
+
self,
|
|
203
|
+
train_set=None,
|
|
204
|
+
valid_set=None,
|
|
205
|
+
test_set=None,
|
|
206
|
+
pretrained=None,
|
|
207
|
+
use_ddp=False,
|
|
208
|
+
use_wandb=True
|
|
209
|
+
):
|
|
210
|
+
"""run a full training loop.
|
|
211
|
+
|
|
212
|
+
Args:
|
|
213
|
+
train_set: path to training annotationns. defaults to the
|
|
214
|
+
value in the loaded config.
|
|
215
|
+
valid_set: path to validation annotations.
|
|
216
|
+
test_set: currently unused.
|
|
217
|
+
pretrained: optional checkpoint path for warm-starting.
|
|
218
|
+
use_ddp: if True and more than one GPU is visible,
|
|
219
|
+
spawn one process per GPU via torch.multiprocessing.spawn.
|
|
220
|
+
"""
|
|
221
|
+
import torch
|
|
222
|
+
import torch.multiprocessing as mp
|
|
223
|
+
from opensportslib.core.utils.config import (
|
|
224
|
+
resolve_config_omega
|
|
225
|
+
)
|
|
226
|
+
|
|
227
|
+
train_set = expand(train_set or self.config.DATA.annotations.train)
|
|
228
|
+
valid_set = expand(valid_set or self.config.DATA.annotations.valid)
|
|
229
|
+
|
|
230
|
+
self.config = resolve_config_omega(self.config)
|
|
231
|
+
logging.info("Configuration:")
|
|
232
|
+
logging.info(self.config)
|
|
233
|
+
|
|
234
|
+
world_size = torch.cuda.device_count() or self.config.SYSTEM.GPU
|
|
235
|
+
use_ddp = use_ddp and world_size > 1
|
|
236
|
+
|
|
237
|
+
ctx = mp.get_context("spawn")
|
|
238
|
+
queue = ctx.Queue()
|
|
239
|
+
|
|
240
|
+
if use_ddp:
|
|
241
|
+
logging.info(f"Launching DDP on {world_size} GPUs")
|
|
242
|
+
mp.spawn(
|
|
243
|
+
ClassificationAPI._worker_ddp,
|
|
244
|
+
args=(
|
|
245
|
+
world_size, "train", self.config, queue,
|
|
246
|
+
train_set, valid_set, None, pretrained, use_wandb
|
|
247
|
+
),
|
|
248
|
+
nprocs=world_size,
|
|
249
|
+
)
|
|
250
|
+
else:
|
|
251
|
+
logging.info("Single GPU training")
|
|
252
|
+
ClassificationAPI._worker_ddp(
|
|
253
|
+
rank=0,
|
|
254
|
+
world_size=1,
|
|
255
|
+
mode="train",
|
|
256
|
+
config=self.config,
|
|
257
|
+
return_queue=queue,
|
|
258
|
+
train_set=train_set,
|
|
259
|
+
valid_set=valid_set,
|
|
260
|
+
pretrained=pretrained,
|
|
261
|
+
use_wandb=use_wandb
|
|
262
|
+
)
|
|
263
|
+
|
|
264
|
+
self.best_checkpoint = queue.get()
|
|
265
|
+
return self.best_checkpoint
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
def infer(
|
|
269
|
+
self,
|
|
270
|
+
test_set=None,
|
|
271
|
+
pretrained=None,
|
|
272
|
+
predictions=None,
|
|
273
|
+
use_ddp=False,
|
|
274
|
+
use_wandb=True
|
|
275
|
+
):
|
|
276
|
+
"""run inference or evaluate saved predictions.
|
|
277
|
+
|
|
278
|
+
when "predictions" is None, the model runs a forward pass
|
|
279
|
+
over the test set and returns live metrics. when "predictions"
|
|
280
|
+
points to a saved prediction file, only the evaluation step runs
|
|
281
|
+
(no GPU needed).
|
|
282
|
+
|
|
283
|
+
Args:
|
|
284
|
+
test_set: path to test annotations.
|
|
285
|
+
pretrained: checkpoint path (required when running live inference).
|
|
286
|
+
predictions: path to a previously saved prediction file.
|
|
287
|
+
if provided, evaluation is run offline without a model.
|
|
288
|
+
use_ddp: if True, distribute inference across all visible GPUs.
|
|
289
|
+
|
|
290
|
+
Returns:
|
|
291
|
+
a metrics dictionary produced by the trainer.
|
|
292
|
+
"""
|
|
293
|
+
import torch
|
|
294
|
+
import torch.multiprocessing as mp
|
|
295
|
+
from opensportslib.core.utils.config import (
|
|
296
|
+
resolve_config_omega
|
|
297
|
+
)
|
|
298
|
+
|
|
299
|
+
test_set = expand(test_set or self.config.DATA.annotations.test)
|
|
300
|
+
|
|
301
|
+
self.config = resolve_config_omega(self.config)
|
|
302
|
+
logging.info("Configuration:")
|
|
303
|
+
logging.info(self.config)
|
|
304
|
+
|
|
305
|
+
if pretrained is None and predictions is None:
|
|
306
|
+
if hasattr(self, "best_checkpoint"):
|
|
307
|
+
pretrained = self.best_checkpoint
|
|
308
|
+
logging.info(f"Using last trained checkpoint: {pretrained}")
|
|
309
|
+
else:
|
|
310
|
+
raise ValueError("No pretrained checkpoint provided and no training run found.")
|
|
311
|
+
|
|
312
|
+
if not predictions:
|
|
313
|
+
# live inference: run the model on test data.
|
|
314
|
+
world_size = torch.cuda.device_count()
|
|
315
|
+
use_ddp = use_ddp and world_size > 1
|
|
316
|
+
|
|
317
|
+
ctx = mp.get_context("spawn")
|
|
318
|
+
queue = ctx.Queue()
|
|
319
|
+
|
|
320
|
+
if use_ddp:
|
|
321
|
+
mp.spawn(
|
|
322
|
+
ClassificationAPI._worker_ddp,
|
|
323
|
+
args=(
|
|
324
|
+
world_size, "infer", self.config, queue,
|
|
325
|
+
None, None, test_set, pretrained, use_wandb
|
|
326
|
+
),
|
|
327
|
+
nprocs=world_size,
|
|
328
|
+
)
|
|
329
|
+
else:
|
|
330
|
+
ClassificationAPI._worker_ddp(
|
|
331
|
+
rank=0,
|
|
332
|
+
world_size=1,
|
|
333
|
+
mode="infer",
|
|
334
|
+
config=self.config,
|
|
335
|
+
return_queue=queue,
|
|
336
|
+
test_set=test_set,
|
|
337
|
+
pretrained=pretrained,
|
|
338
|
+
use_wandb=use_wandb
|
|
339
|
+
)
|
|
340
|
+
|
|
341
|
+
# rank-0 pushes metrics into the queue; retrieve them here.
|
|
342
|
+
metrics = queue.get()
|
|
343
|
+
else:
|
|
344
|
+
# offline evaluation from a saved prediction file.
|
|
345
|
+
from opensportslib.datasets.builder import build_dataset
|
|
346
|
+
from opensportslib.core.trainer.classification_trainer import Trainer_Classification
|
|
347
|
+
|
|
348
|
+
self.trainer = Trainer_Classification(self.config)
|
|
349
|
+
test_data = build_dataset(
|
|
350
|
+
self.config, test_set, None, split="test"
|
|
351
|
+
)
|
|
352
|
+
metrics = self.trainer.evaluate(
|
|
353
|
+
pred_path=predictions,
|
|
354
|
+
gt_path=test_set,
|
|
355
|
+
class_names=test_data.label_map,
|
|
356
|
+
exclude_labels=test_data.exclude_labels
|
|
357
|
+
)
|
|
358
|
+
|
|
359
|
+
logging.info(f"TEST METRICS : {metrics}")
|
|
360
|
+
|
|
361
|
+
return metrics
|
|
@@ -0,0 +1,228 @@
|
|
|
1
|
+
from opensportslib.core.utils.config import expand
|
|
2
|
+
import os
|
|
3
|
+
import logging
|
|
4
|
+
import time
|
|
5
|
+
|
|
6
|
+
class LocalizationAPI:
    """Top-level entry point for action-localization training and inference.

    Loads a YAML config, resolves paths, and exposes train() / infer()
    methods running on a single device.

    Args:
        config: path to the YAML configuration file (required).
        data_dir: override for DATA.data_dir; the config value is used
            when None.
        save_dir: override for the checkpoint output directory;
            falls back to SYSTEM.save_dir, then "./checkpoints".

    Raises:
        ValueError: if no config path is supplied.
    """

    def __init__(self, config=None, data_dir=None, save_dir=None):
        from opensportslib.core.utils.config import load_config_omega
        import uuid

        if config is None:
            raise ValueError("config path is required")

        # Load config; expand() resolves user/relative paths first so the
        # OmegaConf interpolations see absolute locations.
        config_path = expand(config)
        self.config = load_config_omega(config_path)
        # User must control dataset folder
        self.config.DATA.data_dir = expand(data_dir or self.config.DATA.data_dir)
        print(self.config.DATA.classes)

        # User controls model saving location (never use BASE_DIR).
        # A stable run id is shared through the environment so later
        # train()/infer() calls (and wandb) agree on it.
        self.run_id = os.environ.get("RUN_ID") or str(uuid.uuid4())[:8]
        os.environ["RUN_ID"] = self.run_id

        self.save_dir = expand(
            save_dir or self.config.SYSTEM.save_dir or "./checkpoints"
        )
        save_filename = os.path.join(self.config.MODEL.backbone.type, self.run_id)
        self.config.SYSTEM.save_dir = os.path.join(self.save_dir, save_filename)
        os.makedirs(self.config.SYSTEM.save_dir, exist_ok=True)

        # BUG FIX: the log file must live inside the directory we actually
        # create. The original created save_dir/log_dir but handed the bare
        # (possibly non-existent) log_dir to FileHandler, which raises
        # FileNotFoundError on the first log write.
        log_dir = expand(self.config.SYSTEM.log_dir or "./log_dir")
        log_path = os.path.join(self.config.SYSTEM.save_dir, log_dir)
        os.makedirs(log_path, exist_ok=True)
        logging.basicConfig(
            level=logging.INFO,
            format="%(asctime)s | %(levelname)s | %(message)s",
            handlers=[
                logging.FileHandler(os.path.join(log_path, "train.log")),
                logging.StreamHandler(),  # still prints to console
            ],
        )

        self.best_checkpoint = None

        print("CONFIG PATH :", config_path)
        print("DATA DIR :", self.config.DATA.data_dir)
        print("SAVEDIR:", self.config.SYSTEM.save_dir)
        print("Classes :", self.config.DATA.classes)
|
57
|
+
def train(self, train_set=None, valid_set=None, pretrained=None, use_ddp=False, use_wandb=True):
|
|
58
|
+
from opensportslib.datasets.builder import build_dataset
|
|
59
|
+
from opensportslib.models.builder import build_model
|
|
60
|
+
from opensportslib.core.trainer.localization_trainer import build_trainer
|
|
61
|
+
from opensportslib.core.utils.default_args import get_default_args_trainer, get_default_args_train
|
|
62
|
+
from opensportslib.core.utils.config import select_device, resolve_config_omega
|
|
63
|
+
from opensportslib.core.utils.load_annotations import check_config
|
|
64
|
+
from opensportslib.core.utils.wandb import init_wandb
|
|
65
|
+
import random
|
|
66
|
+
import numpy as np
|
|
67
|
+
import torch
|
|
68
|
+
# # Load model
|
|
69
|
+
# if pretrained:
|
|
70
|
+
# self.model, self.processor, _ = self.trainer.load(expand(pretrained))
|
|
71
|
+
# else:
|
|
72
|
+
# self.model, self.processor = build_model(self.config, self.trainer.device)
|
|
73
|
+
# Expand annotation paths (user or config)
|
|
74
|
+
self.config.DATA.train.path = expand(train_set or self.config.DATA.train.path)
|
|
75
|
+
self.config.DATA.valid.path = expand(valid_set or self.config.DATA.valid.path)
|
|
76
|
+
|
|
77
|
+
self.config = resolve_config_omega(self.config)
|
|
78
|
+
check_config(self.config, split="train")
|
|
79
|
+
init_wandb(self.config, run_id=os.environ["RUN_ID"], use_wandb=use_wandb)
|
|
80
|
+
logging.info("Configuration:")
|
|
81
|
+
logging.info(self.config)
|
|
82
|
+
#print(self.config)
|
|
83
|
+
|
|
84
|
+
def set_seed(seed):
|
|
85
|
+
random.seed(seed) # Python random module
|
|
86
|
+
np.random.seed(seed) # NumPy
|
|
87
|
+
torch.manual_seed(seed) # PyTorch
|
|
88
|
+
torch.cuda.manual_seed(seed) # PyTorch CUDA
|
|
89
|
+
torch.cuda.manual_seed_all(seed) # Multi-GPU training
|
|
90
|
+
|
|
91
|
+
# Ensures deterministic behavior
|
|
92
|
+
torch.backends.cudnn.deterministic = True
|
|
93
|
+
torch.backends.cudnn.benchmark = False
|
|
94
|
+
|
|
95
|
+
# Ensures deterministic behavior for CUDA operations
|
|
96
|
+
torch.use_deterministic_algorithms(True, warn_only=True)
|
|
97
|
+
|
|
98
|
+
set_seed(self.config.SYSTEM.seed)
|
|
99
|
+
# Start Timing
|
|
100
|
+
start = time.time()
|
|
101
|
+
|
|
102
|
+
device = select_device(self.config.SYSTEM)
|
|
103
|
+
self.model = build_model(self.config, device=device)
|
|
104
|
+
print(self.model)
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
# Datasets
|
|
108
|
+
# Train
|
|
109
|
+
data_obj_train = build_dataset(self.config, split="train")
|
|
110
|
+
dataset_Train = data_obj_train.building_dataset(
|
|
111
|
+
cfg=data_obj_train.cfg,
|
|
112
|
+
gpu=self.config.SYSTEM.GPU,
|
|
113
|
+
default_args=data_obj_train.default_args,
|
|
114
|
+
)
|
|
115
|
+
train_loader = data_obj_train.building_dataloader(dataset_Train, cfg=data_obj_train.cfg.dataloader, gpu=self.config.SYSTEM.GPU, dali=True)
|
|
116
|
+
print(len(train_loader))
|
|
117
|
+
# Valid
|
|
118
|
+
data_obj_valid = build_dataset(self.config,split="valid")
|
|
119
|
+
dataset_Valid = data_obj_valid.building_dataset(
|
|
120
|
+
cfg=data_obj_valid.cfg,
|
|
121
|
+
gpu= self.config.SYSTEM.GPU,
|
|
122
|
+
default_args=data_obj_valid.default_args,
|
|
123
|
+
)
|
|
124
|
+
valid_loader = data_obj_valid.building_dataloader(dataset_Valid, cfg=data_obj_valid.cfg.dataloader, gpu=self.config.SYSTEM.GPU, dali=True)
|
|
125
|
+
print(len(valid_loader))
|
|
126
|
+
|
|
127
|
+
# Trainer
|
|
128
|
+
trainer = build_trainer(
|
|
129
|
+
cfg=self.config,
|
|
130
|
+
model=self.model,
|
|
131
|
+
default_args=get_default_args_trainer(self.config, len(train_loader)),
|
|
132
|
+
resume_from = pretrained
|
|
133
|
+
)
|
|
134
|
+
# Start training`
|
|
135
|
+
logging.info("Start training")
|
|
136
|
+
|
|
137
|
+
trainer.train(
|
|
138
|
+
**get_default_args_train(
|
|
139
|
+
self.model,
|
|
140
|
+
train_loader,
|
|
141
|
+
valid_loader,
|
|
142
|
+
self.config.DATA.classes,
|
|
143
|
+
self.config.TRAIN.type,
|
|
144
|
+
)
|
|
145
|
+
)
|
|
146
|
+
self.best_checkpoint = trainer.best_checkpoint_path
|
|
147
|
+
|
|
148
|
+
logging.info(f"Total Execution Time is {time.time()-start} seconds")
|
|
149
|
+
return self.best_checkpoint
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def infer(self, test_set=None, pretrained=None, predictions=None, use_ddp=False, use_wandb=True):
|
|
153
|
+
from opensportslib.datasets.builder import build_dataset
|
|
154
|
+
from opensportslib.models.builder import build_model
|
|
155
|
+
from opensportslib.core.trainer.localization_trainer import build_inferer, build_evaluator
|
|
156
|
+
from opensportslib.core.utils.config import select_device, resolve_config_omega, is_local_path
|
|
157
|
+
from opensportslib.core.utils.checkpoint import load_checkpoint, localization_remap
|
|
158
|
+
from opensportslib.core.utils.load_annotations import check_config, has_localization_events
|
|
159
|
+
from opensportslib.core.utils.wandb import init_wandb
|
|
160
|
+
import time
|
|
161
|
+
|
|
162
|
+
self.config.DATA.test.path = expand(test_set or self.config.DATA.test.path)
|
|
163
|
+
self.config.MODEL.multi_gpu = False
|
|
164
|
+
self.config = resolve_config_omega(self.config)
|
|
165
|
+
check_config(self.config, split="test")
|
|
166
|
+
init_wandb(self.config, run_id=os.environ["RUN_ID"], use_wandb=use_wandb)
|
|
167
|
+
logging.info("Configuration:")
|
|
168
|
+
logging.info(self.config)
|
|
169
|
+
# Start Timing
|
|
170
|
+
start = time.time()
|
|
171
|
+
if pretrained is None and predictions is None:
|
|
172
|
+
if hasattr(self, "best_checkpoint"):
|
|
173
|
+
pretrained = self.best_checkpoint
|
|
174
|
+
print(f"Using last trained checkpoint: {pretrained}")
|
|
175
|
+
else:
|
|
176
|
+
raise ValueError("No pretrained checkpoint provided and no training run found.")
|
|
177
|
+
|
|
178
|
+
if not predictions:
|
|
179
|
+
logging.info("No predictions provided, running inference.")
|
|
180
|
+
device = select_device(self.config.SYSTEM)
|
|
181
|
+
self.model = build_model(self.config, device=device)
|
|
182
|
+
print("Model type:", type(self.model))
|
|
183
|
+
print("Torch model type:", type(self.model._model))
|
|
184
|
+
# Load model
|
|
185
|
+
if pretrained:
|
|
186
|
+
#pretrained = expand(pretrained)
|
|
187
|
+
if is_local_path(pretrained):
|
|
188
|
+
self.config.SYSTEM.work_dir = os.path.dirname(os.path.abspath(pretrained))
|
|
189
|
+
|
|
190
|
+
self.model._model, _, _, epoch = load_checkpoint(model=self.model._model,
|
|
191
|
+
path=pretrained,
|
|
192
|
+
device=device,
|
|
193
|
+
key_remap_fn=localization_remap)
|
|
194
|
+
|
|
195
|
+
# Datasets
|
|
196
|
+
# Test
|
|
197
|
+
data_obj_test = build_dataset(self.config, split="test")
|
|
198
|
+
dataset_Test = data_obj_test.building_dataset(
|
|
199
|
+
cfg=data_obj_test.cfg,
|
|
200
|
+
gpu=self.config.SYSTEM.GPU,
|
|
201
|
+
default_args=data_obj_test.default_args,
|
|
202
|
+
)
|
|
203
|
+
test_loader = data_obj_test.building_dataloader(dataset_Test, cfg=data_obj_test.cfg.dataloader, gpu=self.config.SYSTEM.GPU, dali=True)
|
|
204
|
+
print(len(test_loader))
|
|
205
|
+
|
|
206
|
+
# # Inference
|
|
207
|
+
inferer = build_inferer(cfg=self.config.MODEL,
|
|
208
|
+
model=self.model)
|
|
209
|
+
json_gz_file = inferer.infer(cfg=self.config, data=test_loader)
|
|
210
|
+
|
|
211
|
+
#json_gz_file = self.config.DATA.test.results + ".recall.json.gz"
|
|
212
|
+
json_gz_file = predictions if predictions else json_gz_file
|
|
213
|
+
|
|
214
|
+
metrics = None
|
|
215
|
+
|
|
216
|
+
if has_localization_events(self.config.DATA.test.path):
|
|
217
|
+
logging.info("Ground truth labels detected → running evaluation")
|
|
218
|
+
|
|
219
|
+
evaluator = build_evaluator(cfg=self.config)
|
|
220
|
+
metrics = evaluator.evaluate(
|
|
221
|
+
cfg_testset=self.config.DATA.test,
|
|
222
|
+
json_gz_file=json_gz_file
|
|
223
|
+
)
|
|
224
|
+
else:
|
|
225
|
+
logging.info("No labels found in annotation file → skipping evaluation")
|
|
226
|
+
|
|
227
|
+
logging.info(f"Total Execution Time is {time.time()-start} seconds")
|
|
228
|
+
return metrics
|