opensportslib 0.0.1.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. opensportslib/__init__.py +18 -0
  2. opensportslib/apis/__init__.py +21 -0
  3. opensportslib/apis/classification.py +361 -0
  4. opensportslib/apis/localization.py +228 -0
  5. opensportslib/config/classification.yaml +104 -0
  6. opensportslib/config/classification_tracking.yaml +103 -0
  7. opensportslib/config/graph_tracking_classification/avgpool.yaml +79 -0
  8. opensportslib/config/graph_tracking_classification/gin.yaml +79 -0
  9. opensportslib/config/graph_tracking_classification/graphconv.yaml +79 -0
  10. opensportslib/config/graph_tracking_classification/graphsage.yaml +79 -0
  11. opensportslib/config/graph_tracking_classification/maxpool.yaml +79 -0
  12. opensportslib/config/graph_tracking_classification/noedges.yaml +79 -0
  13. opensportslib/config/localization.yaml +132 -0
  14. opensportslib/config/sngar_frames.yaml +98 -0
  15. opensportslib/core/__init__.py +0 -0
  16. opensportslib/core/loss/__init__.py +0 -0
  17. opensportslib/core/loss/builder.py +40 -0
  18. opensportslib/core/loss/calf.py +258 -0
  19. opensportslib/core/loss/ce.py +23 -0
  20. opensportslib/core/loss/combine.py +42 -0
  21. opensportslib/core/loss/nll.py +25 -0
  22. opensportslib/core/optimizer/__init__.py +0 -0
  23. opensportslib/core/optimizer/builder.py +38 -0
  24. opensportslib/core/sampler/weighted_sampler.py +104 -0
  25. opensportslib/core/scheduler/__init__.py +0 -0
  26. opensportslib/core/scheduler/builder.py +77 -0
  27. opensportslib/core/trainer/__init__.py +0 -0
  28. opensportslib/core/trainer/classification_trainer.py +1131 -0
  29. opensportslib/core/trainer/localization_trainer.py +1009 -0
  30. opensportslib/core/utils/checkpoint.py +238 -0
  31. opensportslib/core/utils/config.py +199 -0
  32. opensportslib/core/utils/data.py +85 -0
  33. opensportslib/core/utils/ddp.py +77 -0
  34. opensportslib/core/utils/default_args.py +110 -0
  35. opensportslib/core/utils/load_annotations.py +485 -0
  36. opensportslib/core/utils/seed.py +26 -0
  37. opensportslib/core/utils/video_processing.py +389 -0
  38. opensportslib/core/utils/wandb.py +110 -0
  39. opensportslib/datasets/__init__.py +0 -0
  40. opensportslib/datasets/builder.py +42 -0
  41. opensportslib/datasets/classification_dataset.py +582 -0
  42. opensportslib/datasets/localization_dataset.py +813 -0
  43. opensportslib/datasets/utils/__init__.py +15 -0
  44. opensportslib/datasets/utils/tracking.py +615 -0
  45. opensportslib/metrics/classification_metric.py +176 -0
  46. opensportslib/metrics/localization_metric.py +1482 -0
  47. opensportslib/models/__init__.py +0 -0
  48. opensportslib/models/backbones/builder.py +590 -0
  49. opensportslib/models/base/e2e.py +252 -0
  50. opensportslib/models/base/tracking.py +73 -0
  51. opensportslib/models/base/vars.py +29 -0
  52. opensportslib/models/base/video.py +130 -0
  53. opensportslib/models/base/video_mae.py +60 -0
  54. opensportslib/models/builder.py +43 -0
  55. opensportslib/models/heads/builder.py +266 -0
  56. opensportslib/models/neck/builder.py +210 -0
  57. opensportslib/models/utils/common.py +176 -0
  58. opensportslib/models/utils/impl/__init__.py +0 -0
  59. opensportslib/models/utils/impl/asformer.py +390 -0
  60. opensportslib/models/utils/impl/calf.py +74 -0
  61. opensportslib/models/utils/impl/gsm.py +112 -0
  62. opensportslib/models/utils/impl/gtad.py +347 -0
  63. opensportslib/models/utils/impl/tsm.py +123 -0
  64. opensportslib/models/utils/litebase.py +59 -0
  65. opensportslib/models/utils/modules.py +120 -0
  66. opensportslib/models/utils/shift.py +135 -0
  67. opensportslib/models/utils/utils.py +276 -0
  68. opensportslib-0.0.1.dev2.dist-info/METADATA +566 -0
  69. opensportslib-0.0.1.dev2.dist-info/RECORD +73 -0
  70. opensportslib-0.0.1.dev2.dist-info/WHEEL +5 -0
  71. opensportslib-0.0.1.dev2.dist-info/licenses/LICENSE +661 -0
  72. opensportslib-0.0.1.dev2.dist-info/licenses/LICENSE-COMMERCIAL +5 -0
  73. opensportslib-0.0.1.dev2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1009 @@
1
+ """
2
+ Copyright 2022 James Hong, Haotian Zhang, Matthew Fisher, Michael Gharbi,
3
+ Kayvon Fatahalian
4
+
5
+ Redistribution and use in source and binary forms, with or without modification,
6
+ are permitted provided that the following conditions are met:
7
+
8
+ 1. Redistributions of source code must retain the above copyright notice, this
9
+ list of conditions and the following disclaimer.
10
+
11
+ 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ this list of conditions and the following disclaimer in the documentation and/or
13
+ other materials provided with the distribution.
14
+
15
+ 3. Neither the name of the copyright holder nor the names of its contributors
16
+ may be used to endorse or promote products derived from this software without
17
+ specific prior written permission.
18
+
19
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
20
+ ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
21
+ WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
23
+ ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
24
+ (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
25
+ LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
26
+ ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
28
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ """
30
+ from opensportslib.metrics.localization_metric import *
31
+ from opensportslib.core.optimizer.builder import build_optimizer
32
+ from opensportslib.core.optimizer.builder import build_optimizer
33
+ from opensportslib.core.scheduler.builder import build_scheduler
34
+ from opensportslib.core.utils.config import store_json
35
+ from opensportslib.datasets.builder import build_dataset
36
+ import os
37
+ import torch
38
+ import wandb
39
+ import time
40
+ import json
41
+ import tqdm
42
+ import numpy as np
43
+ from opensportslib.core.utils.config import load_gz_json, load_json
44
+ from abc import ABC, abstractmethod
45
+ import logging
46
+ logger = logging.getLogger(__name__)
47
+
48
def build_trainer(cfg, model=None, default_args=None, resume_from=None):
    """Build a trainer from the configuration.

    Only the ``"trainer_e2e"`` trainer type is handled by this function.

    Args:
        cfg: Configuration object. ``cfg.TRAIN.type`` selects the trainer;
            ``cfg.TRAIN.num_epochs``, ``cfg.TRAIN.optimizer``,
            ``cfg.TRAIN.scheduler``, ``cfg.TRAIN.criterion_valid`` and
            ``cfg.SYSTEM.work_dir`` are also read.
        model: The model to train. Its ``load`` and ``_get_params`` methods
            are called here. Default: None.
        default_args (dict | None, optional): Extra construction arguments.
            The keys ``"work_dir"``, ``"dali"``, ``"repartitions"``,
            ``"cfg_test"`` and ``"cfg_valid_data_frames"`` are read.
            Default: None.
        resume_from (str | None, optional): Path of a checkpoint file to
            resume training from. Default: None.

    Returns:
        Trainer_e2e: The constructed trainer.

    Raises:
        ValueError: If the trainer type is not supported, if ``resume_from``
            is not an existing file, or if the checkpoint has already reached
            ``cfg.TRAIN.num_epochs``.
    """
    trainer_type = cfg.TRAIN.type
    if trainer_type != "trainer_e2e":
        # Fail with an explicit message instead of falling through to an
        # unbound local variable at the return statement.
        raise ValueError(f"Unsupported trainer type: {trainer_type!r}")

    print(cfg.SYSTEM.work_dir)
    checkpoint_dir = default_args["work_dir"]
    start_epoch = 0
    checkpoint = None
    logging.info(f"Checkpoint directory: {checkpoint_dir}")

    # Handle checkpoint loading
    if resume_from is not None:
        if not os.path.isfile(resume_from):
            raise ValueError(f"Checkpoint file not found: {resume_from}")

        logging.info(f"Loading checkpoint from: {resume_from}")
        checkpoint = torch.load(resume_from)

        # Load model state
        model.load(checkpoint['model_state_dict'])
        logging.info("Model state loaded successfully")

        # The checkpoint stores the last finished epoch; resume on the next one.
        start_epoch = checkpoint['epoch'] + 1
        logging.info(f"Resuming from epoch {start_epoch}")

        # Check if we've already reached target epochs
        if start_epoch >= cfg.TRAIN.num_epochs:
            logging.error(f"Model already trained for {start_epoch} epochs")
            logging.error(f"Target epochs in config: {cfg.TRAIN.num_epochs}")
            logging.error("Please increase num_epochs in config to continue training")
            raise ValueError("Need to increase num_epochs to continue training")

        logging.info(f"Will continue training from epoch {start_epoch} to {cfg.TRAIN.num_epochs}")

    logging.info("Building optimizer...")
    optimizer, scaler = build_optimizer(model._get_params(), cfg.TRAIN.optimizer)

    # Load optimizer state if available in checkpoint (best effort: a failure
    # only produces a warning and training continues with fresh state).
    if checkpoint is not None and 'optimizer_state_dict' in checkpoint:
        try:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            scaler.load_state_dict(checkpoint['scaler_state_dict'])
            logging.info("Optimizer and scaler states loaded")
        except Exception as e:
            logging.warning(f"Could not load optimizer state: {e}")
            logging.warning("Will start with fresh optimizer state")

    logging.info("Building scheduler...")
    lr_scheduler = build_scheduler(optimizer, cfg.TRAIN.scheduler, default_args)

    # Load scheduler state if available (best effort, same as above).
    if checkpoint is not None and 'lr_state_dict' in checkpoint:
        try:
            lr_scheduler.load_state_dict(checkpoint['lr_state_dict'])
            logging.info("Scheduler state loaded")
        except Exception as e:
            logging.warning(f"Could not load scheduler state: {e}")
            logging.warning("Will start with fresh scheduler state")

    trainer = Trainer_e2e(
        cfg,
        model,
        optimizer,
        scaler,
        lr_scheduler,
        default_args["work_dir"],
        default_args["dali"],
        default_args["repartitions"],
        default_args["cfg_test"],
        default_args["cfg_valid_data_frames"],
        start_epoch=start_epoch,
    )

    # Load training history if resuming
    if checkpoint is not None:
        default_best = 0 if cfg.TRAIN.criterion_valid == "map" else float("inf")
        trainer.best_epoch = checkpoint.get('best_epoch', 0)
        trainer.best_criterion_valid = checkpoint.get('best_criterion_valid', default_best)
        logging.info(f"Restored best epoch: {trainer.best_epoch}")

    return trainer
141
+
142
class Trainer(ABC):
    """Abstract base class for trainers; subclasses must implement ``train``."""

    def __init__(self):
        """Create the trainer; there is no state to set up at this level."""

    @abstractmethod
    def train(self):
        """Run the training procedure. Concrete trainers override this."""
149
+
150
+
151
+
152
class Trainer_e2e(Trainer):
    """Trainer used for the end-to-end (e2e) model.

    Args:
        args: Configuration object. Its ``TRAIN`` section is read for
            ``num_epochs``, ``acc_grad_iter``, ``start_valid_epoch``,
            ``criterion_valid`` and ``valid_map_every``.
        model: The model wrapper. This class calls ``model.epoch(...)`` and
            ``model.state_dict()`` and reads/writes ``model._model``.
        optimizer (torch.optim.Optimizer): The optimizer to update model parameters.
        scaler (torch.cuda.amp.GradScaler): The gradient scaler for mixed precision training.
        lr_scheduler: The learning rate scheduler.
        work_dir (string): The folder in which the different files will be saved.
        dali (bool): Whether videos are processed with dali or opencv.
        repartitions (List[int]): List of gpus used for data processing.
            Default: None.
        cfg_test (dict): Config of the test split, used by the final evaluation.
            Default: None.
        cfg_valid_data_frames (dict): Config of the valid split used for mAP
            validation during training and for the final evaluation.
            Default: None.
        start_epoch (int): Epoch index at which training starts (non-zero
            when resuming). Default: 0.
    """

    def __init__(
        self,
        args,
        model,
        optimizer,
        scaler,
        lr_scheduler,
        work_dir,
        dali,
        repartitions=None,
        cfg_test=None,
        cfg_valid_data_frames=None,
        start_epoch=0,
    ):
        self.config = args
        self.losses = []
        self.best_epoch = 0
        # mAP is maximised (start at 0); any other criterion is minimised.
        self.best_criterion_valid = 0 if args.TRAIN.criterion_valid == "map" else float("inf")

        self.num_epochs = args.TRAIN.num_epochs
        self.epoch = start_epoch
        self.model = model

        self.optimizer = optimizer
        self.scaler = scaler
        self.lr_scheduler = lr_scheduler

        self.acc_grad_iter = args.TRAIN.acc_grad_iter

        self.start_valid_epoch = args.TRAIN.start_valid_epoch
        self.criterion_valid = args.TRAIN.criterion_valid
        self.valid_map_every = args.TRAIN.valid_map_every
        self.dali = dali

        self.repartitions = repartitions
        self.cfg_test = cfg_test
        self.cfg_valid_data_frames = cfg_valid_data_frames

        self.best_checkpoint_path = None

        self.save_dir = work_dir
        os.makedirs(self.save_dir, exist_ok=True)
        try:
            wandb.watch(self.model, log="gradients", log_freq=100)
        except Exception as exc:
            # Gradient watching is optional; keep going but leave a trace
            # instead of swallowing the error silently.
            logging.debug(f"wandb.watch skipped: {exc}")

    def save_checkpoint(self, epoch, is_best=False):
        """Save the training state to ``latest_checkpoint.pt``.

        When ``is_best`` is true the same state is also written to
        ``best_checkpoint.pt`` and that path is remembered in
        ``self.best_checkpoint_path``.

        Args:
            epoch (int): Epoch index stored in the checkpoint.
            is_best (bool): Whether this epoch is the best one so far.
        """
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scaler_state_dict': self.scaler.state_dict(),
            'lr_state_dict': self.lr_scheduler.state_dict(),
            'best_epoch': self.best_epoch,
            'best_criterion_valid': self.best_criterion_valid,
        }

        os.makedirs(self.save_dir, exist_ok=True)

        # A single "latest" file is overwritten every epoch.
        latest_path = os.path.join(self.save_dir, "latest_checkpoint.pt")
        torch.save(checkpoint, latest_path)
        logging.info(f"Latest checkpoint saved: {latest_path}")

        if is_best:
            best_path = os.path.join(self.save_dir, "best_checkpoint.pt")
            self.best_checkpoint_path = best_path
            torch.save(checkpoint, best_path)
            logging.info(f"Best checkpoint saved: {best_path}")

    def train(self, train_loader, valid_loader, classes):
        """Training loop with checkpoint management.

        Runs one training and one validation pass per epoch, tracks the best
        epoch according to ``self.criterion_valid`` ("loss" or "map"), logs
        to W&B, stores ``loss.json`` and checkpoints in ``self.save_dir`` and
        finally evaluates the best model on the valid split.

        Args:
            train_loader: Loader passed to ``model.epoch`` for training.
            valid_loader: Loader passed to ``model.epoch`` for validation.
            classes: Classes forwarded to dataset building and inference.
        """
        valid_frames_dataset = None
        if self.criterion_valid == "map":
            data_obj_valid = build_dataset(self.config, split="valid_data_frames")
            valid_frames_dataset = data_obj_valid.building_dataset(
                data_obj_valid.cfg,
                None,
                {"repartitions": self.repartitions, "classes": classes},
            )

        for epoch in range(self.epoch, self.num_epochs):
            train_loss = self.model.epoch(
                train_loader,
                self.dali,
                self.optimizer,
                self.scaler,
                lr_scheduler=self.lr_scheduler,
                acc_grad_iter=self.acc_grad_iter,
            )

            # No optimizer is passed for the validation pass.
            valid_loss = self.model.epoch(
                valid_loader, self.dali, acc_grad_iter=self.acc_grad_iter
            )
            epoch_msg = (
                f"[Epoch {epoch+1}/{self.num_epochs}] Train loss: {train_loss:.5f} Valid loss: {valid_loss:.5f}"
            )
            print(epoch_msg)
            logging.info(epoch_msg)

            valid_mAP = 0
            is_best = False

            if self.criterion_valid == "loss":
                if valid_loss < self.best_criterion_valid:
                    self.best_criterion_valid = valid_loss
                    self.best_epoch = epoch
                    is_best = True
                    print("New best epoch!")
            elif self.criterion_valid == "map":
                # mAP validation only runs from start_valid_epoch onwards and
                # only every valid_map_every epochs.
                if epoch >= self.start_valid_epoch and epoch % self.valid_map_every == 0:
                    pred_file = None
                    if self.save_dir is not None:
                        pred_file = os.path.join(
                            self.save_dir, f"pred-valid_{epoch:03d}"
                        )
                        os.makedirs(self.save_dir, exist_ok=True)
                    valid_mAP = infer_and_process_predictions_e2e(
                        self.model,
                        self.dali,
                        valid_frames_dataset,
                        "VALID",
                        classes,
                        pred_file,
                        dataloader_params=self.cfg_valid_data_frames.dataloader,
                    )
                    if valid_mAP > self.best_criterion_valid:
                        self.best_criterion_valid = valid_mAP
                        self.best_epoch = epoch
                        is_best = True
                        print("New best epoch!")
            else:
                print("Unknown criterion:", self.criterion_valid)

            self.losses.append(
                {
                    "epoch": epoch,
                    "train": train_loss,
                    "valid": valid_loss,
                    "valid_mAP": valid_mAP,
                }
            )

            # ---------------- W&B LOG ----------------
            wandb.log({
                "epoch": epoch + 1,
                "train/loss": train_loss,
                "valid/loss": valid_loss,
                "valid/mAP": valid_mAP,
                "lr": self.optimizer.param_groups[0]["lr"],
                "best/mAP": self.best_criterion_valid if self.criterion_valid == "map" else None,
                "best/loss": self.best_criterion_valid if self.criterion_valid == "loss" else None,
            })

            if self.save_dir is not None:
                os.makedirs(self.save_dir, exist_ok=True)
                store_json(
                    os.path.join(self.save_dir, "loss.json"),
                    self.losses,
                    pretty=True
                )
                self.save_checkpoint(epoch, is_best)

        logging.info(f"Training completed. Best epoch: {self.best_epoch}")

        if self.dali:
            train_loader.delete()
            valid_loader.delete()
            if self.criterion_valid == "map":
                valid_frames_dataset.delete()

        if self.save_dir is not None:
            self._run_final_evaluation(classes, eval_splits=["valid"])

    def _run_final_evaluation(self, classes, eval_splits=("valid", "test")):
        """Run final evaluation using the best model.

        Loads ``best_checkpoint.pt`` from ``self.save_dir`` into
        ``self.model._model`` and runs inference on every requested split
        whose configured path exists.

        Args:
            classes: Classes forwarded to dataset building and inference.
            eval_splits: Iterable of split names; "valid" and "test" are
                recognised, anything else is skipped with a warning.
        """
        from opensportslib.core.utils.checkpoint import load_checkpoint, localization_remap

        best_checkpoint_path = os.path.join(self.save_dir, "best_checkpoint.pt")
        if not os.path.isfile(best_checkpoint_path):
            # No epoch was ever marked as best (for example mAP validation
            # never ran), so there is no best model to evaluate.
            logging.warning(
                f"Best checkpoint not found: {best_checkpoint_path}. Skipping final evaluation."
            )
            return

        self.model._model, _, _, _ = load_checkpoint(
            model=self.model._model,
            path=best_checkpoint_path,
            key_remap_fn=localization_remap,
        )
        logging.info(f"Loaded best model from epoch {self.best_epoch}")

        for split in eval_splits:
            if split == "valid":
                cfg_tmp = self.cfg_valid_data_frames
                dataset_split = "valid_data_frames"
            elif split == "test":
                cfg_tmp = self.cfg_test
                dataset_split = split
            else:
                # Without this guard an unknown split would reuse the config
                # of the previous iteration (or hit an unbound variable).
                logging.warning(f"Unknown evaluation split skipped: {split}")
                continue

            # A missing config or path is treated like a non-existent path.
            split_path = getattr(cfg_tmp, "path", None)
            if split_path is None or not os.path.exists(split_path):
                continue

            data_obj = build_dataset(self.config, split=dataset_split)
            split_data = data_obj.building_dataset(
                data_obj.cfg,
                None,
                {"repartitions": self.repartitions, "classes": classes},
            )
            split_data.print_info()

            pred_file = None
            if self.save_dir is not None:
                pred_file = os.path.join(
                    self.save_dir, f"pred-{dataset_split}_{self.best_epoch:03d}"
                )

            infer_and_process_predictions_e2e(
                self.model,
                self.dali,
                split_data,
                dataset_split.upper(),
                classes,
                pred_file,
                calc_stats=dataset_split != "challenge",
                dataloader_params=cfg_tmp.dataloader,
            )

            if self.dali:
                split_data.delete()

        logging.info(f"Final evaluation completed. Best epoch: {self.best_epoch}")
426
+
427
+
428
def build_inferer(cfg, model, default_args=None):
    """Build an inferer from the configuration.

    Args:
        cfg: Config object; ``cfg.runner.type`` selects the inference method.
        model: The model that will be used to infer.
        default_args (dict | None, optional): Default initialization arguments.
            Not used by this function. Default: None.

    Returns:
        Inferer: The constructed inferer.

    Raises:
        ValueError: If ``cfg.runner.type`` is not a supported runner type.
    """
    # Runner type -> name of the Inferer method that handles it.
    infer_methods = {
        "runner_JSON": "infer_JSON",
        "runner_pooling": "infer_SN",
        "runner_CALF": "infer_SN",
        "runner_e2e": "infer_E2E",
    }
    runner_type = cfg.runner.type
    if runner_type not in infer_methods:
        # Previously an unknown type fell through to an unbound local.
        raise ValueError(f"Unsupported runner type: {runner_type!r}")

    return Inferer(cfg=cfg, model=model, infer_Spotting=infer_methods[runner_type])
451
+
452
class Inferer:
    """Run inference with a model through one of several named methods."""

    # Names accepted for ``infer_Spotting``; each is a method of this class.
    _INFER_METHODS = ("infer_JSON", "infer_SN", "infer_E2E")

    def __init__(self, cfg, model, infer_Spotting):
        """Initialize the Inferer class.

        Args:
            cfg: Config object, stored as ``self.cfg_model``.
            model: The model that will be used to infer.
            infer_Spotting (str): Name of the method that is used to infer
                ("infer_JSON", "infer_SN" or "infer_E2E").
        """
        self.cfg_model = cfg
        self.model = model
        self.infer_Spotting = infer_Spotting

    def infer(self, cfg, data):
        """Infer actions from data with the configured method.

        Args:
            cfg: Config object forwarded to the selected method.
            data: The data from which we will infer.

        Returns:
            Whatever the selected inference method returns.

        Raises:
            ValueError: If ``self.infer_Spotting`` is not a known method name.
        """
        if self.infer_Spotting not in self._INFER_METHODS:
            # Make a misconfiguration visible instead of returning None.
            raise ValueError(f"Unknown inference method: {self.infer_Spotting!r}")
        return getattr(self, self.infer_Spotting)(cfg, self.model, data)

    def infer_common(self, cfg, model, data):
        """Shared inference entry point for the JSON and SN methods.

        This is currently a placeholder: it performs no inference and
        returns None.

        Args:
            cfg: Config object.
            model: The model that will be used to infer.
            data: The data from which we will infer.

        Returns:
            None.
        """
        return None

    def infer_JSON(self, cfg, model, data):
        """Infer actions from data using a given model for NetVlad/CALF methods.

        Delegates to ``infer_common``.

        Args:
            cfg: Config object.
            model: The model that will be used to infer.
            data: The data from which we will infer.

        Returns:
            The result of ``infer_common``.
        """
        return self.infer_common(cfg, model, data)

    def infer_SN(self, cfg, model, data):
        """Infer actions from data using a given model for the SNV2 data.

        Delegates to ``infer_common``.

        Args:
            cfg: Config object.
            model: The model that will be used to infer.
            data: The data from which we will infer.

        Returns:
            The result of ``infer_common``.
        """
        return self.infer_common(cfg, model, data)

    def infer_E2E(self, cfg, model, data):
        """Infer actions from data using a given model for the e2espot method.

        The prediction base path is ``cfg.DATA.test.results`` joined to
        ``cfg.SYSTEM.work_dir``; the returned mAP is logged to W&B.

        Args:
            cfg: Config object. Reads ``cfg.SYSTEM.work_dir``, ``cfg.DATA.classes``,
                ``cfg.DATA.test.results``, ``cfg.DATA.test.dataloader`` and the
                optional ``cfg.dali`` flag.
            model: The model that will be used to infer.
            data: The data from which we will infer.

        Returns:
            str | None: Path of the ``.recall.json.gz`` prediction file, or
            None when ``cfg.SYSTEM.work_dir`` is None and no prediction path
            exists.
        """
        pred_file = None
        if cfg.SYSTEM.work_dir is not None:
            pred_file = os.path.join(cfg.SYSTEM.work_dir, cfg.DATA.test.results)
        mAP = infer_and_process_predictions_e2e(
            model,
            getattr(cfg, "dali", False),
            data,
            "infer",
            cfg.DATA.classes,
            pred_file,
            True,
            cfg.DATA.test.dataloader,
            return_pred=False,
        )
        wandb.log({
            "test/Avg_mAP": mAP,
        })

        if pred_file is None:
            # Without a work_dir there is no path to report; building the
            # file names below would fail on None + str.
            logging.warning("No work_dir configured: no prediction file path to return")
            return None

        pred_json_file = pred_file + ".json"
        pred_recall_file = pred_file + ".recall.json.gz"
        logging.info("Predictions saved")
        logging.info(pred_json_file)
        logging.info("High recall predictions saved")
        logging.info(pred_recall_file)
        return pred_recall_file
561
+
562
+
563
def build_evaluator(cfg, default_args=None):
    """Build an evaluator from the configuration.

    Args:
        cfg: Config object; ``cfg.MODEL.runner.type`` selects the evaluation
            method.
        default_args (dict | None, optional): Default initialization arguments.
            Not used by this function. Default: None.

    Returns:
        Evaluator: The constructed evaluator.

    Raises:
        ValueError: If ``cfg.MODEL.runner.type`` is not a supported runner type.
    """
    # Runner type -> name of the Evaluator method that handles it.
    evaluate_methods = {
        "runner_JSON": "evaluate_pred_JSON",
        "runner_pooling": "evaluate_pred_SN",
        "runner_CALF": "evaluate_pred_SN",
        "runner_e2e": "evaluate_pred_E2E",
    }
    runner_type = cfg.MODEL.runner.type
    if runner_type not in evaluate_methods:
        # Previously an unknown type fell through to an unbound local.
        raise ValueError(f"Unsupported runner type: {runner_type!r}")

    return Evaluator(cfg=cfg, evaluate_Spotting=evaluate_methods[runner_type])
584
+
585
+
586
+ class Evaluator:
587
+ """Evaluator class that is used to make easier the process of evaluate since there is only
588
+ one evaluate method that uses the evaluate_Spotting method.
589
+
590
+ Args:
591
+ cfg (dict): Config dict.
592
+ evaluate_Spotting (method): The method that is used to evaluate.
593
+ """
594
+
595
+ def __init__(self, cfg, evaluate_Spotting):
596
+ self.cfg = cfg
597
+ self.extract_fps = getattr(cfg.DATA, "extract_fps", 2)
598
+ self.evaluate_Spotting = evaluate_Spotting
599
+
600
def evaluate(self, cfg_testset, json_gz_file=None):
    """Evaluate predictions with the configured evaluation method.

    Args:
        cfg_testset: Config of the evaluated split; its ``metric``
            attribute is forwarded to the evaluation method.
        json_gz_file: Optional prediction file forwarded to the
            evaluation method. Default: None.

    Returns:
        The result of the selected ``evaluate_pred_*`` method, or None when
        ``self.evaluate_Spotting`` names none of them.
    """
    method_name = self.evaluate_Spotting
    supported = ("evaluate_pred_JSON", "evaluate_pred_SN", "evaluate_pred_E2E")
    if method_name not in supported:
        return None

    handler = getattr(self, method_name)
    return handler(
        cfg_testset,
        self.cfg.SYSTEM.work_dir,
        json_gz_file,
        metric=cfg_testset.metric,
    )
612
+
613
+
614
+ # def evaluate_common_JSON(self, cfg, results, metric):
615
+ # if cfg.path == None:
616
+ # return
617
+ # with open(cfg.path) as f:
618
+ # GT_data = json.load(f)
619
+
620
+ # print(results)
621
+ # pred_path_is_json = False
622
+ # if results.endswith(".json"):
623
+ # pred_path_is_json = True
624
+ # with open(results) as f:
625
+ # pred_data = json.load(f)
626
+
627
+ # targets_numpy = list()
628
+ # detections_numpy = list()
629
+ # closests_numpy = list()
630
+
631
+ # if "labels" in GT_data.keys():
632
+ # classes = GT_data["labels"]
633
+ # else:
634
+ # assert isinstance(cfg.classes, list) or os.path.isfile(cfg.classes)
635
+ # if isinstance(cfg.classes, list):
636
+ # classes = cfg.classes
637
+
638
+ # classes = sorted(classes)
639
+ # EVENT_DICTIONARY = {x: i for i, x in enumerate(classes)}
640
+ # INVERSE_EVENT_DICTIONARY = {i: x for i, x in enumerate(classes)}
641
+
642
+ # if "videos" in GT_data.keys():
643
+ # videos = GT_data["videos"]
644
+ # else:
645
+ # videos = [GT_data]
646
+
647
+ # for game in tqdm.tqdm(videos):
648
+ # print(game.keys())
649
+ # # fetch labels
650
+ # labels = game["annotations"]
651
+ # if not pred_path_is_json:
652
+ # try:
653
+ # pred_file = os.path.join(results, os.path.splitext(game["path"])[0], "results_spotting.json")
654
+ # print(pred_file)
655
+ # with open(pred_file) as f:
656
+ # pred_data = json.load(f)
657
+ # except FileNotFoundError:
658
+ # continue
659
+ # predictions = pred_data["predictions"]
660
+ # # convert labels to dense vector
661
+ # dense_labels = label2vector(
662
+ # labels,
663
+ # num_classes=len(classes),
664
+ # EVENT_DICTIONARY=EVENT_DICTIONARY,
665
+ # framerate=(
666
+ # pred_data["fps"] if "fps" in pred_data.keys() else self.extract_fps
667
+ # ),
668
+ # )
669
+ # print(dense_labels.shape)
670
+ # # convert predictions to vector
671
+ # dense_predictions = predictions2vector(
672
+ # predictions,
673
+ # vector_size=game["num_frames"] if "num_frames" in game.keys() else None,
674
+ # framerate=(
675
+ # pred_data["fps"] if "fps" in pred_data.keys() else self.extract_fps
676
+ # ),
677
+ # num_classes=len(classes),
678
+ # EVENT_DICTIONARY=EVENT_DICTIONARY,
679
+ # )
680
+ # print(dense_predictions.shape)
681
+
682
+ # targets_numpy.append(dense_labels)
683
+ # detections_numpy.append(dense_predictions)
684
+
685
+ # closest_numpy = np.zeros(dense_labels.shape) - 1
686
+ # # Get the closest action index
687
+ # closests_numpy.append(get_closest_action_index(dense_labels, closest_numpy))
688
+
689
+ # if targets_numpy:
690
+ # return compute_performances_mAP(
691
+ # metric,
692
+ # targets_numpy,
693
+ # detections_numpy,
694
+ # closests_numpy,
695
+ # INVERSE_EVENT_DICTIONARY,
696
+ # )
697
+ # else:
698
+ # logging.warning("No predictions found for evaluation. Returning None.")
699
+ # return None
700
+
701
+
702
+
703
+ def evaluate_common_JSON(self, cfg, results, metric):
704
+
705
+ if cfg.path is None:
706
+ return
707
+
708
+ # --------------------------------------------------
709
+ # LOAD GT
710
+ # --------------------------------------------------
711
+ with open(cfg.path) as f:
712
+ GT_data = json.load(f)
713
+
714
+ # --------------------------------------------------
715
+ # LOAD PRED FILE (json / json.gz / folder)
716
+ # --------------------------------------------------
717
+ pred_data = None
718
+ pred_path_is_file = results.endswith(".json") or results.endswith(".json.gz")
719
+
720
+ if pred_path_is_file:
721
+ pred_data = load_gz_json(results) if results.endswith(".gz") else load_json(results)
722
+
723
+ # detect v2 prediction
724
+ pred_is_v2 = isinstance(pred_data, dict) and pred_data is not None and "data" in pred_data
725
+ # --------------------------------------------------
726
+ # CLASSES
727
+ # --------------------------------------------------
728
+ if isinstance(GT_data.get("labels"), dict):
729
+ classes = list(GT_data["labels"].values())[0]["labels"]
730
+ elif "labels" in GT_data:
731
+ classes = GT_data["labels"]
732
+ else:
733
+ classes = cfg.classes
734
+
735
+ classes = sorted(classes)
736
+ EVENT_DICTIONARY = {x: i for i, x in enumerate(classes)}
737
+ INVERSE_EVENT_DICTIONARY = {i: x for i, x in enumerate(classes)}
738
+
739
+ # --------------------------------------------------
740
+ # GT VIDEOS
741
+ # --------------------------------------------------
742
+ if "videos" in GT_data:
743
+ videos = GT_data["videos"]
744
+ gt_is_v2 = False
745
+ else:
746
+ videos = GT_data["data"]
747
+ gt_is_v2 = True
748
+
749
+ # --------------------------------------------------
750
+ # BUILD PRED LOOKUP IF V2
751
+ # --------------------------------------------------
752
+ pred_lookup = {}
753
+ if pred_is_v2:
754
+ for item in pred_data["data"]:
755
+ video_path = item["inputs"][0]["path"]
756
+ pred_lookup[video_path] = item
757
+
758
+ targets_numpy = []
759
+ detections_numpy = []
760
+ closests_numpy = []
761
+
762
+ # ==================================================
763
+ # LOOP
764
+ # ==================================================
765
+ for game in tqdm.tqdm(videos):
766
+
767
+ # ---------------- GT ----------------
768
+ if gt_is_v2:
769
+ video_path = game["inputs"][0]["path"]
770
+ labels = [{"label": e.get("label"),
771
+ "gameTime": e.get("gameTime"),
772
+ "position": int(e.get("position_ms")),
773
+ } for e in game.get("events", [])]
774
+ else:
775
+ video_path = game["path"]
776
+ labels = game["annotations"]
777
+
778
+ # ---------------- PRED ----------------
779
+ if pred_path_is_file:
780
+
781
+ # ===== V2 PRED =====
782
+ if pred_is_v2:
783
+ if video_path not in pred_lookup:
784
+ continue
785
+
786
+ item = pred_lookup[video_path]
787
+ fps = item["inputs"][0].get("fps", self.extract_fps)
788
+
789
+ predictions = [
790
+ {
791
+ "label": e.get("label"),
792
+ "gameTime": e.get("gameTime"),
793
+ "confidence": e.get("confidence"),
794
+ "position": int(e.get("position_ms")),
795
+ "frame": e.get("frame")
796
+ }
797
+ for e in item.get("events", [])
798
+ ]
799
+
800
+ # ===== OLD PRED =====
801
+ else:
802
+ if "predictions" not in pred_data:
803
+ continue
804
+
805
+ predictions = pred_data["predictions"]
806
+ fps = pred_data.get("fps", self.extract_fps)
807
+
808
+ else:
809
+ # ===== FOLDER MODE =====
810
+ pred_file = os.path.join(results, os.path.splitext(video_path)[0], "results_spotting.json")
811
+
812
+ if not os.path.exists(pred_file):
813
+ continue
814
+
815
+ with open(pred_file) as f:
816
+ pred_data_local = json.load(f)
817
+
818
+ if "data" in pred_data_local:
819
+ # v2 file inside folder
820
+ item = pred_data_local["data"][0]
821
+ fps = item["inputs"][0].get("fps", self.extract_fps)
822
+
823
+ predictions = [
824
+ {
825
+ "label": e.get("label"),
826
+ "gameTime": e.get("gameTime"),
827
+ "confidence": e.get("confidence"),
828
+ "position": int(e.get("position_ms")),
829
+ "frame": e.get("frame")
830
+ }
831
+ for e in item.get("events", [])
832
+ ]
833
+ else:
834
+ predictions = pred_data_local["predictions"]
835
+ fps = pred_data_local.get("fps", self.extract_fps)
836
+
837
+ # ---------------- VECTORS ----------------
838
+ dense_labels = label2vector(labels, num_classes=len(classes), EVENT_DICTIONARY=EVENT_DICTIONARY, framerate=fps)
839
+
840
+ dense_predictions = predictions2vector(
841
+ predictions,
842
+ vector_size=None,
843
+ framerate=fps,
844
+ num_classes=len(classes),
845
+ EVENT_DICTIONARY=EVENT_DICTIONARY,
846
+ )
847
+
848
+ targets_numpy.append(dense_labels)
849
+ detections_numpy.append(dense_predictions)
850
+
851
+ closest_numpy = np.zeros(dense_labels.shape) - 1
852
+ closests_numpy.append(get_closest_action_index(dense_labels, closest_numpy))
853
+
854
+ # --------------------------------------------------
855
+ # METRICS
856
+ # --------------------------------------------------
857
+ if targets_numpy:
858
+ return compute_performances_mAP(
859
+ metric,
860
+ targets_numpy,
861
+ detections_numpy,
862
+ closests_numpy,
863
+ INVERSE_EVENT_DICTIONARY,
864
+ )
865
+ else:
866
+ logging.warning("No predictions found.")
867
+ return None
868
+
869
def evaluate_pred_E2E(self, cfg, work_dir, pred_path, metric="loose"):
    """Evaluate predictions inferred with the E2E method.

    When ``pred_path`` is an existing ``.json`` / ``.gz`` file, it is loaded
    and, if it holds raw predictions (a list, or a v2 dict with a ``"data"``
    key that is converted to such a list), NMS is optionally applied and the
    per-video evaluation files are written under ``work_dir`` before the
    common JSON evaluation runs on them. Any other ``pred_path`` (for
    example a folder) is handed to the common JSON evaluation unchanged.

    Args:
        cfg: Configuration object. ``cfg.nms_window`` is read only when raw
            predictions have to be processed; the object is otherwise
            forwarded as-is to ``evaluate_common_JSON``.
        work_dir: Folder under which the processed evaluation files are
            stored.
        pred_path: Path of the predictions. It can be:
            - a folder path (that contains prediction files)
            - a file path (raw predictions that need to be processed first)
        metric (string): Metric used to evaluate.
            In ["loose","tight","at1","at2","at3","at4","at5"].
            Default: "loose".

    Returns:
        Whatever ``self.evaluate_common_JSON`` returns for the resolved
        prediction location.
    """
    results = pred_path

    if os.path.isfile(results) and results.endswith((".gz", ".json")):
        loader = load_gz_json if results.endswith(".gz") else load_json
        pred = loader(results)

        # New v2 format: a dict with a "data" list. Convert it to the
        # internal per-video list layout used by the steps below.
        if isinstance(pred, dict) and "data" in pred:
            pred = [
                {
                    "video": item["inputs"][0]["path"],
                    # Fall back to the trainer's extraction fps when the
                    # prediction file does not record one.
                    "fps": item["inputs"][0].get("fps", self.extract_fps),
                    "events": [
                        {
                            "frame": ev.get("frame"),
                            "label": ev.get("label"),
                            "confidence": ev.get("confidence"),
                            # "position_ms" is required: int(None) raises
                            # TypeError when the key is missing.
                            "position": int(ev.get("position_ms")),
                            "gameTime": ev.get("gameTime"),
                        }
                        for ev in item.get("events", [])
                    ],
                }
                for item in pred["data"]
            ]

        if isinstance(pred, list):
            nms_window = cfg.nms_window
            if nms_window > 0:
                logging.info("Applying NMS: %s", nms_window)
                pred = non_maximum_supression(pred, nms_window)

            # Strip only the trailing extension(s) (".json", ".gz" or
            # ".json.gz"). Splitting on the first occurrence would truncate
            # paths that merely contain ".json" / ".gz" in a directory name.
            stem = pred_path
            for suffix in (".gz", ".json"):
                if stem.endswith(suffix):
                    stem = stem[: -len(suffix)]

            eval_dir = os.path.join(work_dir, stem)
            only_one_file = store_eval_files_json(pred, eval_dir)
            logging.info("Done processing prediction files!")
            if only_one_file:
                results = os.path.join(eval_dir, "results_spotting.json")
            else:
                results = eval_dir

    return self.evaluate_common_JSON(cfg, results, metric)
933
+
934
+
935
def evaluate_pred_JSON(self, cfg, work_dir, pred_path, metric="loose"):
    """Evaluate predictions stored as JSON files.

    The prediction location is resolved as ``work_dir`` joined with
    ``pred_path`` and the actual work is delegated to
    ``self.evaluate_common_JSON``.

    Args:
        cfg: Configuration object, forwarded unchanged to
            ``evaluate_common_JSON``.
        work_dir: Folder under which the prediction files are stored.
        pred_path: Path of the predictions, relative to ``work_dir``.
            It can be:
            - a folder path (that contains prediction files)
            - a json file path, to evaluate a single json file.
        metric (string): Metric used to evaluate.
            In ["loose","tight","at1","at2","at3","at4","at5"].
            Default: "loose".

    Returns:
        The value returned by ``self.evaluate_common_JSON``.
    """
    prediction_location = os.path.join(work_dir, pred_path)
    return self.evaluate_common_JSON(cfg, prediction_location, metric)
951
+
952
+
953
def evaluate_pred_SN(self, cfg, work_dir, pred_path, metric="loose"):
    """Evaluate predictions against the SNV2 splits and print a summary table.

    This method should be used only for the SNV2 dataset.

    Args:
        cfg: Configuration object. ``cfg.split`` and ``cfg.data_root`` are
            read; ``cfg.version`` is used when present and defaults to 2.
        work_dir: Folder under which the prediction files are stored.
        pred_path: Path of the predictions, relative to ``work_dir``.
        metric (string): Metric used to evaluate.
            In ["loose","tight","at1","at2","at3","at4","at5"].
            Default: "loose".

    Returns:
        The results mapping produced by ``evaluate``, or ``None`` when the
        split contains "challenge" (nothing is evaluated in that case).
    """
    # Guard clause: challenge sets are to be tested on EvalAI, not here.
    if "challenge" in cfg.split:
        print("Visit eval.ai to evaluate performances on Challenge set")
        return None

    def as_percent(score):
        # Render a score as a percentage with two decimals.
        return "{:0.2f}".format(score * 100)

    predictions_root = os.path.join(work_dir, pred_path)
    results = evaluate(
        SoccerNet_path=cfg.data_root,
        Predictions_path=predictions_root,
        split=cfg.split,
        prediction_file="results_spotting.json",
        version=getattr(cfg, "version", 2),
        metric=metric,
    )

    # One row per class: any / visible / unshown average mAP.
    table = []
    for class_idx, any_score in enumerate(results["a_mAP_per_class"]):
        table.append(
            (
                INVERSE_EVENT_DICTIONARY_V2[class_idx],
                as_percent(any_score),
                as_percent(results["a_mAP_per_class_visible"][class_idx]),
                as_percent(results["a_mAP_per_class_unshown"][class_idx]),
            )
        )

    # Final row: averages over all classes.
    table.append(
        (
            "Average mAP",
            as_percent(results["a_mAP"]),
            as_percent(results["a_mAP_visible"]),
            as_percent(results["a_mAP_unshown"]),
        )
    )

    logging.info("Best Performance at end of training ")
    logging.info("Metric: " + metric)
    print(tabulate(table, headers=["", "Any", "Visible", "Unseen"]))

    return results
1008
+
1009
+