opensportslib 0.0.1.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- opensportslib/__init__.py +18 -0
- opensportslib/apis/__init__.py +21 -0
- opensportslib/apis/classification.py +361 -0
- opensportslib/apis/localization.py +228 -0
- opensportslib/config/classification.yaml +104 -0
- opensportslib/config/classification_tracking.yaml +103 -0
- opensportslib/config/graph_tracking_classification/avgpool.yaml +79 -0
- opensportslib/config/graph_tracking_classification/gin.yaml +79 -0
- opensportslib/config/graph_tracking_classification/graphconv.yaml +79 -0
- opensportslib/config/graph_tracking_classification/graphsage.yaml +79 -0
- opensportslib/config/graph_tracking_classification/maxpool.yaml +79 -0
- opensportslib/config/graph_tracking_classification/noedges.yaml +79 -0
- opensportslib/config/localization.yaml +132 -0
- opensportslib/config/sngar_frames.yaml +98 -0
- opensportslib/core/__init__.py +0 -0
- opensportslib/core/loss/__init__.py +0 -0
- opensportslib/core/loss/builder.py +40 -0
- opensportslib/core/loss/calf.py +258 -0
- opensportslib/core/loss/ce.py +23 -0
- opensportslib/core/loss/combine.py +42 -0
- opensportslib/core/loss/nll.py +25 -0
- opensportslib/core/optimizer/__init__.py +0 -0
- opensportslib/core/optimizer/builder.py +38 -0
- opensportslib/core/sampler/weighted_sampler.py +104 -0
- opensportslib/core/scheduler/__init__.py +0 -0
- opensportslib/core/scheduler/builder.py +77 -0
- opensportslib/core/trainer/__init__.py +0 -0
- opensportslib/core/trainer/classification_trainer.py +1131 -0
- opensportslib/core/trainer/localization_trainer.py +1009 -0
- opensportslib/core/utils/checkpoint.py +238 -0
- opensportslib/core/utils/config.py +199 -0
- opensportslib/core/utils/data.py +85 -0
- opensportslib/core/utils/ddp.py +77 -0
- opensportslib/core/utils/default_args.py +110 -0
- opensportslib/core/utils/load_annotations.py +485 -0
- opensportslib/core/utils/seed.py +26 -0
- opensportslib/core/utils/video_processing.py +389 -0
- opensportslib/core/utils/wandb.py +110 -0
- opensportslib/datasets/__init__.py +0 -0
- opensportslib/datasets/builder.py +42 -0
- opensportslib/datasets/classification_dataset.py +582 -0
- opensportslib/datasets/localization_dataset.py +813 -0
- opensportslib/datasets/utils/__init__.py +15 -0
- opensportslib/datasets/utils/tracking.py +615 -0
- opensportslib/metrics/classification_metric.py +176 -0
- opensportslib/metrics/localization_metric.py +1482 -0
- opensportslib/models/__init__.py +0 -0
- opensportslib/models/backbones/builder.py +590 -0
- opensportslib/models/base/e2e.py +252 -0
- opensportslib/models/base/tracking.py +73 -0
- opensportslib/models/base/vars.py +29 -0
- opensportslib/models/base/video.py +130 -0
- opensportslib/models/base/video_mae.py +60 -0
- opensportslib/models/builder.py +43 -0
- opensportslib/models/heads/builder.py +266 -0
- opensportslib/models/neck/builder.py +210 -0
- opensportslib/models/utils/common.py +176 -0
- opensportslib/models/utils/impl/__init__.py +0 -0
- opensportslib/models/utils/impl/asformer.py +390 -0
- opensportslib/models/utils/impl/calf.py +74 -0
- opensportslib/models/utils/impl/gsm.py +112 -0
- opensportslib/models/utils/impl/gtad.py +347 -0
- opensportslib/models/utils/impl/tsm.py +123 -0
- opensportslib/models/utils/litebase.py +59 -0
- opensportslib/models/utils/modules.py +120 -0
- opensportslib/models/utils/shift.py +135 -0
- opensportslib/models/utils/utils.py +276 -0
- opensportslib-0.0.1.dev2.dist-info/METADATA +566 -0
- opensportslib-0.0.1.dev2.dist-info/RECORD +73 -0
- opensportslib-0.0.1.dev2.dist-info/WHEEL +5 -0
- opensportslib-0.0.1.dev2.dist-info/licenses/LICENSE +661 -0
- opensportslib-0.0.1.dev2.dist-info/licenses/LICENSE-COMMERCIAL +5 -0
- opensportslib-0.0.1.dev2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1009 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright 2022 James Hong, Haotian Zhang, Matthew Fisher, Michael Gharbi,
|
|
3
|
+
Kayvon Fatahalian
|
|
4
|
+
|
|
5
|
+
Redistribution and use in source and binary forms, with or without modification,
|
|
6
|
+
are permitted provided that the following conditions are met:
|
|
7
|
+
|
|
8
|
+
1. Redistributions of source code must retain the above copyright notice, this
|
|
9
|
+
list of conditions and the following disclaimer.
|
|
10
|
+
|
|
11
|
+
2. Redistributions in binary form must reproduce the above copyright notice,
|
|
12
|
+
this list of conditions and the following disclaimer in the documentation and/or
|
|
13
|
+
other materials provided with the distribution.
|
|
14
|
+
|
|
15
|
+
3. Neither the name of the copyright holder nor the names of its contributors
|
|
16
|
+
may be used to endorse or promote products derived from this software without
|
|
17
|
+
specific prior written permission.
|
|
18
|
+
|
|
19
|
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
|
20
|
+
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
|
21
|
+
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
22
|
+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
|
|
23
|
+
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
|
24
|
+
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
|
25
|
+
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
|
|
26
|
+
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
|
27
|
+
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
|
28
|
+
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
29
|
+
"""
|
|
30
|
+
from opensportslib.metrics.localization_metric import *
|
|
31
|
+
from opensportslib.core.optimizer.builder import build_optimizer
|
|
32
|
+
from opensportslib.core.optimizer.builder import build_optimizer
|
|
33
|
+
from opensportslib.core.scheduler.builder import build_scheduler
|
|
34
|
+
from opensportslib.core.utils.config import store_json
|
|
35
|
+
from opensportslib.datasets.builder import build_dataset
|
|
36
|
+
import os
|
|
37
|
+
import torch
|
|
38
|
+
import wandb
|
|
39
|
+
import time
|
|
40
|
+
import json
|
|
41
|
+
import tqdm
|
|
42
|
+
import numpy as np
|
|
43
|
+
from opensportslib.core.utils.config import load_gz_json, load_json
|
|
44
|
+
from abc import ABC, abstractmethod
|
|
45
|
+
import logging
|
|
46
|
+
logger = logging.getLogger(__name__)
|
|
47
|
+
|
|
48
|
+
def build_trainer(cfg, model=None, default_args=None, resume_from=None):
    """Build a trainer from config dict.

    Only ``cfg.TRAIN.type == "trainer_e2e"`` is supported. When
    ``resume_from`` is given, the model, optimizer, scaler and scheduler
    states plus the best-epoch bookkeeping are restored from that checkpoint.

    Args:
        cfg: Config object. Reads ``cfg.TRAIN.type``, ``cfg.TRAIN.num_epochs``,
            ``cfg.TRAIN.optimizer``, ``cfg.TRAIN.scheduler``,
            ``cfg.TRAIN.criterion_valid`` and ``cfg.SYSTEM.work_dir``.
        model: The model used for training. Required for the E2E trainer,
            because its training does not rely on PyTorch Lightning.
            Default: None.
        default_args (dict | None): Must provide ``"work_dir"``, ``"dali"``,
            ``"repartitions"``, ``"cfg_test"`` and ``"cfg_valid_data_frames"``.
            Default: None.
        resume_from (str | None): Path of a checkpoint written by
            ``Trainer_e2e.save_checkpoint`` to resume from. Default: None.

    Returns:
        Trainer_e2e: The constructed trainer.

    Raises:
        ValueError: If the trainer type is unsupported, ``model`` or
            ``default_args`` is missing, the checkpoint file does not exist,
            or the checkpoint already reached ``cfg.TRAIN.num_epochs``.
    """
    trainer_type = cfg.TRAIN.type
    if trainer_type != "trainer_e2e":
        # Without this guard the function fell through to `return trainer`
        # with `trainer` never assigned.
        raise ValueError(f"Unknown trainer type: {trainer_type!r}")
    if model is None:
        raise ValueError("A model is required to build the 'trainer_e2e' trainer")
    if default_args is None:
        raise ValueError("default_args is required to build the 'trainer_e2e' trainer")

    print(cfg.SYSTEM.work_dir)
    checkpoint_dir = default_args["work_dir"]
    start_epoch = 0
    checkpoint = None  # stays None unless we resume from a file
    logging.info(f"Checkpoint directory: {checkpoint_dir}")

    # Handle checkpoint loading
    if resume_from is not None:
        if not os.path.isfile(resume_from):
            raise ValueError(f"Checkpoint file not found: {resume_from}")

        logging.info(f"Loading checkpoint from: {resume_from}")
        checkpoint = torch.load(resume_from)

        # Load model state
        model.load(checkpoint['model_state_dict'])
        logging.info("Model state loaded successfully")

        # Training resumes at the epoch after the one stored in the checkpoint.
        start_epoch = checkpoint['epoch'] + 1
        logging.info(f"Resuming from epoch {start_epoch}")

        # Refuse to "resume" a run that already reached the target epochs.
        if start_epoch >= cfg.TRAIN.num_epochs:
            logging.error(f"Model already trained for {start_epoch} epochs")
            logging.error(f"Target epochs in config: {cfg.TRAIN.num_epochs}")
            logging.error("Please increase num_epochs in config to continue training")
            raise ValueError("Need to increase num_epochs to continue training")

        logging.info(f"Will continue training from epoch {start_epoch} to {cfg.TRAIN.num_epochs}")

    logging.info("Building optimizer...")
    optimizer, scaler = build_optimizer(model._get_params(), cfg.TRAIN.optimizer)

    # Best effort: restore optimizer/scaler state, fall back to fresh state.
    if checkpoint is not None and 'optimizer_state_dict' in checkpoint:
        try:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            scaler.load_state_dict(checkpoint['scaler_state_dict'])
            logging.info("Optimizer and scaler states loaded")
        except Exception as e:
            logging.warning(f"Could not load optimizer state: {e}")
            logging.warning("Will start with fresh optimizer state")

    logging.info("Building scheduler...")
    lr_scheduler = build_scheduler(optimizer, cfg.TRAIN.scheduler, default_args)

    # Best effort: restore scheduler state, fall back to fresh state.
    if checkpoint is not None and 'lr_state_dict' in checkpoint:
        try:
            lr_scheduler.load_state_dict(checkpoint['lr_state_dict'])
            logging.info("Scheduler state loaded")
        except Exception as e:
            logging.warning(f"Could not load scheduler state: {e}")
            logging.warning("Will start with fresh scheduler state")

    trainer = Trainer_e2e(
        cfg,
        model,
        optimizer,
        scaler,
        lr_scheduler,
        default_args["work_dir"],
        default_args["dali"],
        default_args["repartitions"],
        default_args["cfg_test"],
        default_args["cfg_valid_data_frames"],
        start_epoch=start_epoch,
    )

    # Restore best-epoch bookkeeping so "best" comparisons continue correctly.
    if checkpoint is not None:
        trainer.best_epoch = checkpoint.get('best_epoch', 0)
        trainer.best_criterion_valid = checkpoint.get(
            'best_criterion_valid',
            0 if cfg.TRAIN.criterion_valid == "map" else float("inf"),
        )
        logging.info(f"Restored best epoch: {trainer.best_epoch}")

    return trainer
|
|
141
|
+
|
|
142
|
+
class Trainer(ABC):
    """Abstract base for trainers.

    Concrete subclasses must override :meth:`train`; the base constructor
    accepts no arguments and keeps no state.
    """

    def __init__(self):
        """Create the trainer; the base class stores nothing."""

    @abstractmethod
    def train(self):
        """Run training. Subclasses must provide the implementation."""
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
class Trainer_e2e(Trainer):
    """Trainer class used for the e2e model.

    Args:
        args: Full config object; ``args.TRAIN`` provides ``criterion_valid``,
            ``num_epochs``, ``acc_grad_iter``, ``start_valid_epoch`` and
            ``valid_map_every``.
        model: The model wrapper to train (``epoch``, ``state_dict`` and
            ``_model`` are used).
        optimizer (torch.optim.Optimizer): The optimizer updating the model
            parameters.
        scaler: The gradient scaler for mixed precision training.
        lr_scheduler: The learning rate scheduler.
        work_dir (str): Folder in which checkpoints, losses and predictions
            are saved.
        dali (bool): Whether videos are processed with DALI or OpenCV.
        repartitions (List[int] | None): GPUs used for data processing.
            Default: None.
        cfg_test: Config for inference and evaluation of the test split, used
            once training is done. Default: None.
        cfg_valid_data_frames: Config for inference and evaluation of the
            valid split. Used every ``valid_map_every`` epochs when the
            validation criterion is ``"map"``, and for the final evaluation.
            Default: None.
        start_epoch (int): Epoch at which training starts (or resumes).
            Default: 0.
    """

    def __init__(
        self,
        args,
        model,
        optimizer,
        scaler,
        lr_scheduler,
        work_dir,
        dali,
        repartitions=None,
        cfg_test=None,
        cfg_valid_data_frames=None,
        start_epoch=0,
    ):
        self.config = args
        self.losses = []
        self.best_epoch = 0
        # "map" is maximised (starts at 0); any other criterion is minimised.
        self.best_criterion_valid = 0 if args.TRAIN.criterion_valid == "map" else float("inf")

        self.num_epochs = args.TRAIN.num_epochs
        self.epoch = start_epoch
        self.model = model

        self.optimizer = optimizer
        self.scaler = scaler
        self.lr_scheduler = lr_scheduler

        self.acc_grad_iter = args.TRAIN.acc_grad_iter

        self.start_valid_epoch = args.TRAIN.start_valid_epoch
        self.criterion_valid = args.TRAIN.criterion_valid
        self.valid_map_every = args.TRAIN.valid_map_every
        self.dali = dali

        self.repartitions = repartitions
        self.cfg_test = cfg_test
        self.cfg_valid_data_frames = cfg_valid_data_frames

        self.best_checkpoint_path = None

        self.save_dir = work_dir
        os.makedirs(self.save_dir, exist_ok=True)
        try:
            wandb.watch(self.model, log="gradients", log_freq=100)
        except Exception as err:
            # Best effort: W&B tracking is optional and must not stop training.
            logging.debug(f"wandb.watch skipped: {err}")

    @staticmethod
    def _wandb_log(payload):
        """Send ``payload`` to W&B only when a run is active (``wandb.run`` set)."""
        if wandb.run is not None:
            wandb.log(payload)

    def save_checkpoint(self, epoch, is_best=False):
        """Save the training state to ``latest_checkpoint.pt``.

        The saved dict holds the keys that ``build_trainer(..., resume_from=...)``
        reads back.

        Args:
            epoch (int): The epoch that just finished.
            is_best (bool): Whether this epoch improved the validation
                criterion; if so the state is also written to
                ``best_checkpoint.pt``.
        """
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scaler_state_dict': self.scaler.state_dict(),
            'lr_state_dict': self.lr_scheduler.state_dict(),
            'best_epoch': self.best_epoch,
            'best_criterion_valid': self.best_criterion_valid,
        }

        os.makedirs(self.save_dir, exist_ok=True)
        # A single "latest" file is overwritten every epoch.
        latest_path = os.path.join(self.save_dir, "latest_checkpoint.pt")
        torch.save(checkpoint, latest_path)
        logging.info(f"Latest checkpoint saved: {latest_path}")

        if is_best:
            best_path = os.path.join(self.save_dir, "best_checkpoint.pt")
            self.best_checkpoint_path = best_path
            torch.save(checkpoint, best_path)
            logging.info(f"Best checkpoint saved: {best_path}")

    def train(self, train_loader, valid_loader, classes):
        """Training loop with checkpoint management.

        Each epoch trains on ``train_loader``, computes the validation loss on
        ``valid_loader`` and, for the ``"map"`` criterion, periodically runs
        inference on the validation frames. The best epoch (per
        ``criterion_valid``) is saved to ``best_checkpoint.pt``. Afterwards,
        the best checkpoint is evaluated on the validation split.

        Args:
            train_loader: Training data loader.
            valid_loader: Validation data loader (used for the loss).
            classes: Class definitions forwarded to dataset building and
                inference.
        """
        valid_frames_data = None
        if self.criterion_valid == "map":
            data_obj_valid = build_dataset(self.config, split="valid_data_frames")
            valid_frames_data = data_obj_valid.building_dataset(
                data_obj_valid.cfg,
                None,
                {"repartitions": self.repartitions, "classes": classes},
            )

        try:
            for epoch in range(self.epoch, self.num_epochs):
                self._run_epoch(epoch, train_loader, valid_loader, classes, valid_frames_data)
            logging.info(f"Training completed. Best epoch: {self.best_epoch}")
        finally:
            # DALI loaders/datasets expose `delete()`; release them even when
            # training is interrupted by an exception.
            if self.dali:
                train_loader.delete()
                valid_loader.delete()
                if valid_frames_data is not None:
                    valid_frames_data.delete()

        if self.save_dir is not None:
            self._run_final_evaluation(classes, eval_splits=("valid",))

    def _run_epoch(self, epoch, train_loader, valid_loader, classes, valid_frames_data):
        """Train and validate one epoch, update the best epoch, log and checkpoint."""
        train_loss = self.model.epoch(
            train_loader,
            self.dali,
            self.optimizer,
            self.scaler,
            lr_scheduler=self.lr_scheduler,
            acc_grad_iter=self.acc_grad_iter,
        )

        # Without an optimizer, `model.epoch` is used as a validation pass.
        valid_loss = self.model.epoch(
            valid_loader, self.dali, acc_grad_iter=self.acc_grad_iter
        )
        summary = (
            f"[Epoch {epoch+1}/{self.num_epochs}] Train loss: {train_loss:.5f} Valid loss: {valid_loss:.5f}"
        )
        print(summary)
        logging.info(summary)

        valid_mAP = 0
        is_best = False

        if self.criterion_valid == "loss":
            if valid_loss < self.best_criterion_valid:
                self.best_criterion_valid = valid_loss
                self.best_epoch = epoch
                is_best = True
                print("New best epoch!")
        elif self.criterion_valid == "map":
            # mAP is only computed from `start_valid_epoch`, every `valid_map_every` epochs.
            if epoch >= self.start_valid_epoch and epoch % self.valid_map_every == 0:
                pred_file = None
                if self.save_dir is not None:
                    pred_file = os.path.join(self.save_dir, f"pred-valid_{epoch:03d}")
                    os.makedirs(self.save_dir, exist_ok=True)
                valid_mAP = infer_and_process_predictions_e2e(
                    self.model,
                    self.dali,
                    valid_frames_data,
                    "VALID",
                    classes,
                    pred_file,
                    dataloader_params=self.cfg_valid_data_frames.dataloader,
                )
                if valid_mAP > self.best_criterion_valid:
                    self.best_criterion_valid = valid_mAP
                    self.best_epoch = epoch
                    is_best = True
                    print("New best epoch!")
        else:
            print("Unknown criterion:", self.criterion_valid)

        self.losses.append(
            {
                "epoch": epoch,
                "train": train_loss,
                "valid": valid_loss,
                "valid_mAP": valid_mAP,
            }
        )

        # ---------------- W&B LOG ----------------
        self._wandb_log({
            "epoch": epoch + 1,
            "train/loss": train_loss,
            "valid/loss": valid_loss,
            "valid/mAP": valid_mAP,
            "lr": self.optimizer.param_groups[0]["lr"],
            "best/mAP": self.best_criterion_valid if self.criterion_valid == "map" else None,
            "best/loss": self.best_criterion_valid if self.criterion_valid == "loss" else None,
        })

        if self.save_dir is not None:
            os.makedirs(self.save_dir, exist_ok=True)
            store_json(
                os.path.join(self.save_dir, "loss.json"),
                self.losses,
                pretty=True,
            )
            self.save_checkpoint(epoch, is_best)

    def _run_final_evaluation(self, classes, eval_splits=("valid", "test")):
        """Run final evaluation using the best model.

        Args:
            classes: Class definitions forwarded to dataset building and
                inference.
            eval_splits (Iterable[str]): Splits to evaluate. ``"valid"`` uses
                ``cfg_valid_data_frames`` (dataset split
                ``"valid_data_frames"``) and ``"test"`` uses ``cfg_test``.
                Unsupported names, splits without a config, and splits whose
                ``path`` does not exist are skipped.
        """
        from opensportslib.core.utils.checkpoint import load_checkpoint, localization_remap

        # best_checkpoint.pt only exists if some epoch improved the criterion.
        best_checkpoint_path = os.path.join(self.save_dir, "best_checkpoint.pt")
        if os.path.isfile(best_checkpoint_path):
            self.model._model, _, _, _ = load_checkpoint(
                model=self.model._model,
                path=best_checkpoint_path,
                key_remap_fn=localization_remap,
            )
            logging.info(f"Loaded best model from epoch {self.best_epoch}")
        else:
            logging.warning(
                f"Best checkpoint not found at {best_checkpoint_path}; "
                "evaluating the current model weights instead"
            )

        # Requested split -> (dataset split name, split config).
        split_configs = {
            "valid": ("valid_data_frames", self.cfg_valid_data_frames),
            "test": ("test", self.cfg_test),
        }

        for requested in eval_splits:
            if requested not in split_configs:
                logging.warning(f"Unsupported evaluation split {requested!r}; skipping")
                continue
            split, cfg_tmp = split_configs[requested]
            if cfg_tmp is None:
                logging.info(f"No config for split {requested!r}; skipping")
                continue
            if not os.path.exists(cfg_tmp.path):
                continue

            data_obj = build_dataset(self.config, split=split)
            split_data = data_obj.building_dataset(
                data_obj.cfg,
                None,
                {"repartitions": self.repartitions, "classes": classes},
            )
            split_data.print_info()

            pred_file = None
            if self.save_dir is not None:
                pred_file = os.path.join(
                    self.save_dir, f"pred-{split}_{self.best_epoch:03d}"
                )

            infer_and_process_predictions_e2e(
                self.model,
                self.dali,
                split_data,
                split.upper(),
                classes,
                pred_file,
                calc_stats=split != "challenge",
                dataloader_params=cfg_tmp.dataloader,
            )

            if self.dali:
                split_data.delete()

        logging.info(f"Final evaluation completed. Best epoch: {self.best_epoch}")
|
|
426
|
+
|
|
427
|
+
|
|
428
|
+
def build_inferer(cfg, model, default_args=None):
    """Build an inferer from config dict.

    Args:
        cfg: Config object exposing ``cfg.runner.type``. Supported values:
            ``"runner_JSON"``, ``"runner_pooling"``, ``"runner_CALF"`` and
            ``"runner_e2e"``.
        model: The model that will be used to infer.
        default_args (dict | None, optional): Not used by this builder; kept
            for API compatibility. Default: None.

    Returns:
        Inferer: The constructed inferer.

    Raises:
        ValueError: If ``cfg.runner.type`` is not a supported runner (the
            function previously failed with an unbound ``inferer``).
    """
    # Runner type -> name of the Inferer method that performs inference.
    infer_methods = {
        "runner_JSON": "infer_JSON",
        "runner_pooling": "infer_SN",
        "runner_CALF": "infer_SN",
        "runner_e2e": "infer_E2E",
    }
    runner_type = cfg.runner.type
    try:
        infer_spotting = infer_methods[runner_type]
    except KeyError:
        raise ValueError(f"Unknown runner type: {runner_type!r}") from None
    return Inferer(cfg=cfg, model=model, infer_Spotting=infer_spotting)
|
|
451
|
+
|
|
452
|
+
class Inferer:
    """Run inference with the method selected by ``infer_Spotting``.

    Args:
        cfg: Config the inferer was built from (stored as ``cfg_model``).
        model: The model that will be used to infer.
        infer_Spotting (str): Name of the inference method: ``"infer_JSON"``,
            ``"infer_SN"`` or ``"infer_E2E"``.
    """

    def __init__(self, cfg, model, infer_Spotting):
        """Initialize the Inferer class.

        Args:
            cfg: Config dict the inferer was built from.
            model: The model that will be used to infer.
            infer_Spotting (str): The method that is used to infer.
        """
        self.cfg_model = cfg
        self.model = model
        self.infer_Spotting = infer_Spotting

    def infer(self, cfg, data):
        """Infer actions from data with the configured method.

        Args:
            cfg: Config passed through to the selected inference method.
            data: The data from which we will infer.

        Returns:
            Whatever the selected ``infer_*`` method returns.

        Raises:
            ValueError: If ``infer_Spotting`` names no known method (this
                previously returned None silently).
        """
        if self.infer_Spotting == "infer_JSON":
            return self.infer_JSON(cfg, self.model, data)
        if self.infer_Spotting == "infer_SN":
            return self.infer_SN(cfg, self.model, data)
        if self.infer_Spotting == "infer_E2E":
            return self.infer_E2E(cfg, self.model, data)
        raise ValueError(f"Unknown inference method: {self.infer_Spotting!r}")

    def infer_common(self, cfg, model, data):
        """Infer actions from data using a given model.

        Args:
            cfg (dict): Config dict. It should at least contain the key "type".
            model: The model that will be used to infer.
            data: The data from which we will infer.

        Returns:
            None. This is currently a placeholder with no implementation.
        """
        # Run Inference on Dataset
        # TODO: not implemented yet; the JSON/SN paths therefore return None.
        return None

    def infer_JSON(self, cfg, model, data):
        """Infer actions from data using a given model for NetVlad/CALF methods.

        Args:
            cfg (dict): Config dict. It should at least contain the key "type".
            model: The model that will be used to infer.
            data: The data from which we will infer.

        Returns:
            The result of :meth:`infer_common`.
        """
        return self.infer_common(cfg, model, data)

    def infer_SN(self, cfg, model, data):
        """Infer actions from data using a given model for the SNV2 data.

        Args:
            cfg (dict): Config dict. It should at least contain the key "type".
            model: The model that will be used to infer.
            data: The data from which we will infer.

        Returns:
            The result of :meth:`infer_common`.
        """
        return self.infer_common(cfg, model, data)

    def infer_E2E(self, cfg, model, data):
        """Infer actions from data using a given model for the e2espot method.

        Args:
            cfg: Config object. Reads ``cfg.SYSTEM.work_dir``,
                ``cfg.DATA.test.results``, ``cfg.DATA.classes``,
                ``cfg.DATA.test.dataloader`` and, optionally, ``cfg.dali``.
            model: The model that will be used to infer.
            data: The data from which we will infer.

        Returns:
            str | None: ``<work_dir>/<results>.recall.json.gz``, the path of
            the high-recall predictions file. Returns None when
            ``cfg.SYSTEM.work_dir`` is None, because no prediction file path
            can be built (this previously raised ``TypeError``).
        """
        pred_file = None
        if cfg.SYSTEM.work_dir is not None:
            pred_file = os.path.join(cfg.SYSTEM.work_dir, cfg.DATA.test.results)
        mAP = infer_and_process_predictions_e2e(
            model,
            getattr(cfg, "dali", False),
            data,
            "infer",
            cfg.DATA.classes,
            pred_file,
            True,
            cfg.DATA.test.dataloader,
            return_pred=False,
        )
        # W&B is optional: only log when a run is active.
        if wandb.run is not None:
            wandb.log({
                "test/Avg_mAP": mAP,
            })

        if pred_file is None:
            logging.warning("cfg.SYSTEM.work_dir is None; no prediction file path to return")
            return None

        pred_json_file = pred_file + ".json"
        pred_recall_file = pred_file + ".recall.json.gz"
        logging.info("Predictions saved")
        logging.info(pred_json_file)
        logging.info("High recall predictions saved")
        logging.info(pred_recall_file)
        return pred_recall_file
|
|
561
|
+
|
|
562
|
+
|
|
563
|
+
def build_evaluator(cfg, default_args=None):
    """Build an evaluator from config dict.

    Args:
        cfg: Config object exposing ``cfg.MODEL.runner.type``. Supported
            values: ``"runner_JSON"``, ``"runner_pooling"``, ``"runner_CALF"``
            and ``"runner_e2e"``.
        default_args (dict | None, optional): Not used by this builder; kept
            for API compatibility. Default: None.

    Returns:
        Evaluator: The constructed evaluator.

    Raises:
        ValueError: If ``cfg.MODEL.runner.type`` is not a supported runner
            (the function previously failed with an unbound ``evaluator``).
    """
    # Runner type -> name of the Evaluator method that performs evaluation.
    evaluate_methods = {
        "runner_JSON": "evaluate_pred_JSON",
        "runner_pooling": "evaluate_pred_SN",
        "runner_CALF": "evaluate_pred_SN",
        "runner_e2e": "evaluate_pred_E2E",
    }
    runner_type = cfg.MODEL.runner.type
    try:
        evaluate_spotting = evaluate_methods[runner_type]
    except KeyError:
        raise ValueError(f"Unknown runner type: {runner_type!r}") from None
    return Evaluator(cfg=cfg, evaluate_Spotting=evaluate_spotting)
|
|
584
|
+
|
|
585
|
+
|
|
586
|
+
class Evaluator:
|
|
587
|
+
"""Evaluator class that is used to make easier the process of evaluate since there is only
|
|
588
|
+
one evaluate method that uses the evaluate_Spotting method.
|
|
589
|
+
|
|
590
|
+
Args:
|
|
591
|
+
cfg (dict): Config dict.
|
|
592
|
+
evaluate_Spotting (method): The method that is used to evaluate.
|
|
593
|
+
"""
|
|
594
|
+
|
|
595
|
+
def __init__(self, cfg, evaluate_Spotting):
|
|
596
|
+
self.cfg = cfg
|
|
597
|
+
self.extract_fps = getattr(cfg.DATA, "extract_fps", 2)
|
|
598
|
+
self.evaluate_Spotting = evaluate_Spotting
|
|
599
|
+
|
|
600
|
+
def evaluate(self, cfg_testset, json_gz_file=None):
|
|
601
|
+
"""Evaluate predictions.
|
|
602
|
+
|
|
603
|
+
Args:
|
|
604
|
+
cfg_testset (dict): Config dict that contains informations for the predictions.
|
|
605
|
+
"""
|
|
606
|
+
if self.evaluate_Spotting == "evaluate_pred_JSON":
|
|
607
|
+
return self.evaluate_pred_JSON(cfg_testset, self.cfg.SYSTEM.work_dir, json_gz_file, metric=cfg_testset.metric)
|
|
608
|
+
elif self.evaluate_Spotting == "evaluate_pred_SN":
|
|
609
|
+
return self.evaluate_pred_SN(cfg_testset, self.cfg.SYSTEM.work_dir, json_gz_file, metric=cfg_testset.metric)
|
|
610
|
+
elif self.evaluate_Spotting == "evaluate_pred_E2E":
|
|
611
|
+
return self.evaluate_pred_E2E(cfg_testset, self.cfg.SYSTEM.work_dir, json_gz_file, metric=cfg_testset.metric)
|
|
612
|
+
|
|
613
|
+
|
|
614
|
+
# def evaluate_common_JSON(self, cfg, results, metric):
|
|
615
|
+
# if cfg.path == None:
|
|
616
|
+
# return
|
|
617
|
+
# with open(cfg.path) as f:
|
|
618
|
+
# GT_data = json.load(f)
|
|
619
|
+
|
|
620
|
+
# print(results)
|
|
621
|
+
# pred_path_is_json = False
|
|
622
|
+
# if results.endswith(".json"):
|
|
623
|
+
# pred_path_is_json = True
|
|
624
|
+
# with open(results) as f:
|
|
625
|
+
# pred_data = json.load(f)
|
|
626
|
+
|
|
627
|
+
# targets_numpy = list()
|
|
628
|
+
# detections_numpy = list()
|
|
629
|
+
# closests_numpy = list()
|
|
630
|
+
|
|
631
|
+
# if "labels" in GT_data.keys():
|
|
632
|
+
# classes = GT_data["labels"]
|
|
633
|
+
# else:
|
|
634
|
+
# assert isinstance(cfg.classes, list) or os.path.isfile(cfg.classes)
|
|
635
|
+
# if isinstance(cfg.classes, list):
|
|
636
|
+
# classes = cfg.classes
|
|
637
|
+
|
|
638
|
+
# classes = sorted(classes)
|
|
639
|
+
# EVENT_DICTIONARY = {x: i for i, x in enumerate(classes)}
|
|
640
|
+
# INVERSE_EVENT_DICTIONARY = {i: x for i, x in enumerate(classes)}
|
|
641
|
+
|
|
642
|
+
# if "videos" in GT_data.keys():
|
|
643
|
+
# videos = GT_data["videos"]
|
|
644
|
+
# else:
|
|
645
|
+
# videos = [GT_data]
|
|
646
|
+
|
|
647
|
+
# for game in tqdm.tqdm(videos):
|
|
648
|
+
# print(game.keys())
|
|
649
|
+
# # fetch labels
|
|
650
|
+
# labels = game["annotations"]
|
|
651
|
+
# if not pred_path_is_json:
|
|
652
|
+
# try:
|
|
653
|
+
# pred_file = os.path.join(results, os.path.splitext(game["path"])[0], "results_spotting.json")
|
|
654
|
+
# print(pred_file)
|
|
655
|
+
# with open(pred_file) as f:
|
|
656
|
+
# pred_data = json.load(f)
|
|
657
|
+
# except FileNotFoundError:
|
|
658
|
+
# continue
|
|
659
|
+
# predictions = pred_data["predictions"]
|
|
660
|
+
# # convert labels to dense vector
|
|
661
|
+
# dense_labels = label2vector(
|
|
662
|
+
# labels,
|
|
663
|
+
# num_classes=len(classes),
|
|
664
|
+
# EVENT_DICTIONARY=EVENT_DICTIONARY,
|
|
665
|
+
# framerate=(
|
|
666
|
+
# pred_data["fps"] if "fps" in pred_data.keys() else self.extract_fps
|
|
667
|
+
# ),
|
|
668
|
+
# )
|
|
669
|
+
# print(dense_labels.shape)
|
|
670
|
+
# # convert predictions to vector
|
|
671
|
+
# dense_predictions = predictions2vector(
|
|
672
|
+
# predictions,
|
|
673
|
+
# vector_size=game["num_frames"] if "num_frames" in game.keys() else None,
|
|
674
|
+
# framerate=(
|
|
675
|
+
# pred_data["fps"] if "fps" in pred_data.keys() else self.extract_fps
|
|
676
|
+
# ),
|
|
677
|
+
# num_classes=len(classes),
|
|
678
|
+
# EVENT_DICTIONARY=EVENT_DICTIONARY,
|
|
679
|
+
# )
|
|
680
|
+
# print(dense_predictions.shape)
|
|
681
|
+
|
|
682
|
+
# targets_numpy.append(dense_labels)
|
|
683
|
+
# detections_numpy.append(dense_predictions)
|
|
684
|
+
|
|
685
|
+
# closest_numpy = np.zeros(dense_labels.shape) - 1
|
|
686
|
+
# # Get the closest action index
|
|
687
|
+
# closests_numpy.append(get_closest_action_index(dense_labels, closest_numpy))
|
|
688
|
+
|
|
689
|
+
# if targets_numpy:
|
|
690
|
+
# return compute_performances_mAP(
|
|
691
|
+
# metric,
|
|
692
|
+
# targets_numpy,
|
|
693
|
+
# detections_numpy,
|
|
694
|
+
# closests_numpy,
|
|
695
|
+
# INVERSE_EVENT_DICTIONARY,
|
|
696
|
+
# )
|
|
697
|
+
# else:
|
|
698
|
+
# logging.warning("No predictions found for evaluation. Returning None.")
|
|
699
|
+
# return None
|
|
700
|
+
|
|
701
|
+
|
|
702
|
+
|
|
703
|
+
def evaluate_common_JSON(self, cfg, results, metric):
    """Evaluate JSON spotting predictions against a JSON ground truth.

    Ground truth (``cfg.path``) may be in the legacy layout (top-level
    ``"videos"`` whose items have ``"path"`` and ``"annotations"``) or in the
    v2 layout (top-level ``"data"`` whose items have ``"inputs"`` and
    ``"events"``).

    ``results`` may be:

    - a single ``.json`` / ``.json.gz`` file. A legacy file (``"predictions"``
      key) is used for every ground-truth video. A v2 file (``"data"`` key) is
      matched to ground-truth videos by ``inputs[0]["path"]``.
    - a folder holding ``<video path without extension>/results_spotting.json``
      for each video, each file in either layout.

    Videos for which no usable predictions are found are skipped.

    Args:
        cfg: Config object. ``cfg.path`` is the ground-truth JSON path.
            ``cfg.classes`` is used only when the ground truth has no
            ``"labels"`` entry.
        results: Path of the prediction file or folder.
        metric: Metric name forwarded to ``compute_performances_mAP``.

    Returns:
        The value returned by ``compute_performances_mAP``, or ``None`` when
        ``cfg.path`` is ``None`` or no video had predictions.
    """

    def _v2_events_to_predictions(item):
        # Convert the "events" of one v2 prediction item into the legacy
        # prediction dicts passed to ``predictions2vector``.
        return [
            {
                "label": e.get("label"),
                "gameTime": e.get("gameTime"),
                "confidence": e.get("confidence"),
                "position": int(e.get("position_ms")),
                "frame": e.get("frame"),
            }
            for e in item.get("events", [])
        ]

    if cfg.path is None:
        return None

    # --------------------------------------------------
    # LOAD GT
    # --------------------------------------------------
    with open(cfg.path, encoding="utf-8") as f:
        GT_data = json.load(f)

    # --------------------------------------------------
    # LOAD PRED FILE (json / json.gz); folders are read per video below
    # --------------------------------------------------
    pred_data = None
    pred_path_is_file = results.endswith((".json", ".json.gz"))
    if pred_path_is_file:
        pred_data = load_gz_json(results) if results.endswith(".gz") else load_json(results)

    # A v2 prediction file is a dict with a top-level "data" list.
    pred_is_v2 = isinstance(pred_data, dict) and "data" in pred_data

    # --------------------------------------------------
    # CLASSES
    # --------------------------------------------------
    if isinstance(GT_data.get("labels"), dict):
        # v2 GT: {"labels": {<head>: {"labels": [...]}, ...}}; first head is used.
        classes = list(GT_data["labels"].values())[0]["labels"]
    elif "labels" in GT_data:
        classes = GT_data["labels"]
    else:
        classes = cfg.classes

    classes = sorted(classes)
    EVENT_DICTIONARY = {x: i for i, x in enumerate(classes)}
    INVERSE_EVENT_DICTIONARY = {i: x for i, x in enumerate(classes)}

    # --------------------------------------------------
    # GT VIDEOS
    # --------------------------------------------------
    if "videos" in GT_data:
        videos = GT_data["videos"]
        gt_is_v2 = False
    else:
        videos = GT_data["data"]
        gt_is_v2 = True

    # --------------------------------------------------
    # BUILD PRED LOOKUP IF V2 (video path -> prediction item)
    # --------------------------------------------------
    pred_lookup = {}
    if pred_is_v2:
        for item in pred_data["data"]:
            pred_lookup[item["inputs"][0]["path"]] = item

    targets_numpy = []
    detections_numpy = []
    closests_numpy = []

    # ==================================================
    # LOOP
    # ==================================================
    for game in tqdm.tqdm(videos):

        # ---------------- GT ----------------
        if gt_is_v2:
            video_path = game["inputs"][0]["path"]
            labels = [
                {
                    "label": e.get("label"),
                    "gameTime": e.get("gameTime"),
                    "position": int(e.get("position_ms")),
                }
                for e in game.get("events", [])
            ]
        else:
            video_path = game["path"]
            labels = game["annotations"]

        # ---------------- PRED ----------------
        if pred_path_is_file:
            if pred_is_v2:
                # ===== V2 PRED =====
                if video_path not in pred_lookup:
                    continue
                item = pred_lookup[video_path]
                fps = item["inputs"][0].get("fps", self.extract_fps)
                predictions = _v2_events_to_predictions(item)
            else:
                # ===== OLD PRED =====
                if "predictions" not in pred_data:
                    continue
                predictions = pred_data["predictions"]
                fps = pred_data.get("fps", self.extract_fps)
        else:
            # ===== FOLDER MODE =====
            pred_file = os.path.join(results, os.path.splitext(video_path)[0], "results_spotting.json")
            if not os.path.exists(pred_file):
                continue

            with open(pred_file, encoding="utf-8") as f:
                pred_data_local = json.load(f)

            if "data" in pred_data_local:
                # v2 file inside folder: one item per video file.
                if not pred_data_local["data"]:
                    logging.warning("Empty 'data' in %s, skipping.", pred_file)
                    continue
                item = pred_data_local["data"][0]
                fps = item["inputs"][0].get("fps", self.extract_fps)
                predictions = _v2_events_to_predictions(item)
            elif "predictions" in pred_data_local:
                predictions = pred_data_local["predictions"]
                fps = pred_data_local.get("fps", self.extract_fps)
            else:
                # Same treatment as the single-file mode: skip, don't crash.
                logging.warning("No predictions in %s, skipping.", pred_file)
                continue

        # ---------------- VECTORS ----------------
        dense_labels = label2vector(labels, num_classes=len(classes), EVENT_DICTIONARY=EVENT_DICTIONARY, framerate=fps)

        dense_predictions = predictions2vector(
            predictions,
            vector_size=None,
            framerate=fps,
            num_classes=len(classes),
            EVENT_DICTIONARY=EVENT_DICTIONARY,
        )

        targets_numpy.append(dense_labels)
        detections_numpy.append(dense_predictions)

        closest_numpy = np.zeros(dense_labels.shape) - 1
        closests_numpy.append(get_closest_action_index(dense_labels, closest_numpy))

    # --------------------------------------------------
    # METRICS
    # --------------------------------------------------
    if targets_numpy:
        return compute_performances_mAP(
            metric,
            targets_numpy,
            detections_numpy,
            closests_numpy,
            INVERSE_EVENT_DICTIONARY,
        )
    logging.warning("No predictions found.")
    return None
|
|
868
|
+
|
|
869
|
+
def evaluate_pred_E2E(self, cfg, work_dir, pred_path, metric="loose"):
    """Evaluate predictions inferred with the E2E method and return the mAPs.

    When ``pred_path`` is an existing ``.json`` / ``.gz`` file, it is loaded
    first. A v2 prediction file (dict with a ``"data"`` list) is converted to
    the internal list format. List predictions are then optionally filtered
    with NMS (``cfg.nms_window > 0``) and written per video under
    ``work_dir/<pred_path without .gz/.json suffix>`` by
    ``store_eval_files_json``. That output is what gets evaluated. Any other
    ``pred_path`` is passed to ``evaluate_common_JSON`` unchanged.

    Args:
        cfg: Config object. Must contain ``nms_window`` when ``pred_path`` is
            a prediction file. ``path`` (ground truth) and ``classes`` are
            read by ``evaluate_common_JSON``.
        work_dir: Folder under which processed prediction files are stored.
        pred_path: Either a folder containing prediction files, or a raw
            prediction file (``.json`` / ``.json.gz``) to process first.
        metric (string): Metric used to evaluate, one of
            ["loose", "tight", "at1", "at2", "at3", "at4", "at5"].
            Default: "loose".

    Returns:
        The different mAPs computed (whatever ``evaluate_common_JSON``
        returns).
    """

    def _strip_pred_suffix(path):
        # Remove a trailing ".gz", then a trailing ".json", so "x.json.gz",
        # "x.json" and "x.gz" all become "x". Only the suffix is touched;
        # splitting on ".gz"/".json" would also cut at directory names that
        # merely contain those substrings.
        if path.endswith(".gz"):
            path = path[: -len(".gz")]
        if path.endswith(".json"):
            path = path[: -len(".json")]
        return path

    results = pred_path

    is_pred_file = os.path.isfile(results) and results.endswith((".gz", ".json"))
    if not is_pred_file:
        return self.evaluate_common_JSON(cfg, results, metric)

    pred = (load_gz_json if results.endswith(".gz") else load_json)(results)

    # --------------------------------------------------
    # SUPPORT NEW V2 FORMAT (dict) -> internal list format
    # --------------------------------------------------
    if isinstance(pred, dict) and "data" in pred:
        pred = [
            {
                "video": item["inputs"][0]["path"],
                "fps": item["inputs"][0].get("fps", self.extract_fps),
                "events": [
                    {
                        "frame": ev.get("frame"),
                        "label": ev.get("label"),
                        "confidence": ev.get("confidence"),
                        "position": int(ev.get("position_ms")),
                        "gameTime": ev.get("gameTime"),
                    }
                    for ev in item.get("events", [])
                ],
            }
            for item in pred["data"]
        ]

    nms_window = cfg.nms_window
    if isinstance(pred, list):
        if nms_window > 0:
            logging.info("Applying NMS: " + str(nms_window))
            pred = non_maximum_supression(pred, nms_window)

        eval_dir = os.path.join(work_dir, _strip_pred_suffix(pred_path))
        only_one_file = store_eval_files_json(pred, eval_dir)
        logging.info("Done processing prediction files!")
        # A single video is evaluated from its file; several from the folder.
        if only_one_file:
            results = os.path.join(eval_dir, "results_spotting.json")
        else:
            results = eval_dir
    return self.evaluate_common_JSON(cfg, results, metric)
|
|
933
|
+
|
|
934
|
+
|
|
935
|
+
def evaluate_pred_JSON(self, cfg, work_dir, pred_path, metric="loose"):
    """Evaluate JSON prediction files against the ground truth and return the mAPs.

    The prediction location is resolved as ``os.path.join(work_dir, pred_path)``
    and handed to ``evaluate_common_JSON``.

    Args:
        cfg (dict): Must contain ``path`` (path of the ground-truth data). Must
            contain ``classes`` (list of classes) when the ground-truth data
            does not list them.
        work_dir: The folder under which the prediction files are stored.
        pred_path: Either a folder containing prediction files, or a single
            json file to evaluate.
        metric (string): Metric used to evaluate, one of
            ["loose", "tight", "at1", "at2", "at3", "at4", "at5"].
            Default: "loose".

    Returns:
        The different mAPs computed.
    """
    full_pred_path = os.path.join(work_dir, pred_path)
    return self.evaluate_common_JSON(cfg, full_pred_path, metric)
|
|
951
|
+
|
|
952
|
+
|
|
953
|
+
def evaluate_pred_SN(self, cfg, work_dir, pred_path, metric="loose"):
    """Evaluate predictions made on SoccerNet-v2 splits and print a results table.

    Intended only for the SNV2 dataset. Challenge splits are not evaluated
    locally: a pointer to eval.ai is printed and ``None`` is returned.

    Args:
        cfg (dict): Must contain ``split`` and ``data_root``. An optional
            ``version`` attribute is forwarded to ``evaluate`` (default 2).
        work_dir: The folder under which the prediction files are stored.
        pred_path: Prediction folder, relative to ``work_dir``.
        metric (string): Metric used to evaluate, one of
            ["loose", "tight", "at1", "at2", "at3", "at4", "at5"].
            Default: "loose".

    Returns:
        The results dict produced by ``evaluate``, or ``None`` for challenge
        splits.
    """
    # Challenge sets have no public labels; they are scored on EvalAI.
    if "challenge" in cfg.split:
        print("Visit eval.ai to evaluate performances on Challenge set")
        return None

    scores = evaluate(
        SoccerNet_path=cfg.data_root,
        Predictions_path=os.path.join(work_dir, pred_path),
        split=cfg.split,
        prediction_file="results_spotting.json",
        version=getattr(cfg, "version", 2),
        metric=metric,
    )

    def _pct(value):
        # Fraction -> percentage string with two decimals.
        return "{:0.2f}".format(value * 100)

    rows = [
        (
            INVERSE_EVENT_DICTIONARY_V2[idx],
            _pct(any_map),
            _pct(scores["a_mAP_per_class_visible"][idx]),
            _pct(scores["a_mAP_per_class_unshown"][idx]),
        )
        for idx, any_map in enumerate(scores["a_mAP_per_class"])
    ]
    rows.append(
        (
            "Average mAP",
            _pct(scores["a_mAP"]),
            _pct(scores["a_mAP_visible"]),
            _pct(scores["a_mAP_unshown"]),
        )
    )

    logging.info("Best Performance at end of training ")
    logging.info("Metric: " + metric)
    print(tabulate(rows, headers=["", "Any", "Visible", "Unseen"]))

    return scores
|
|
1008
|
+
|
|
1009
|
+
|