opensportslib 0.0.1.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73) hide show
  1. opensportslib/__init__.py +18 -0
  2. opensportslib/apis/__init__.py +21 -0
  3. opensportslib/apis/classification.py +361 -0
  4. opensportslib/apis/localization.py +228 -0
  5. opensportslib/config/classification.yaml +104 -0
  6. opensportslib/config/classification_tracking.yaml +103 -0
  7. opensportslib/config/graph_tracking_classification/avgpool.yaml +79 -0
  8. opensportslib/config/graph_tracking_classification/gin.yaml +79 -0
  9. opensportslib/config/graph_tracking_classification/graphconv.yaml +79 -0
  10. opensportslib/config/graph_tracking_classification/graphsage.yaml +79 -0
  11. opensportslib/config/graph_tracking_classification/maxpool.yaml +79 -0
  12. opensportslib/config/graph_tracking_classification/noedges.yaml +79 -0
  13. opensportslib/config/localization.yaml +132 -0
  14. opensportslib/config/sngar_frames.yaml +98 -0
  15. opensportslib/core/__init__.py +0 -0
  16. opensportslib/core/loss/__init__.py +0 -0
  17. opensportslib/core/loss/builder.py +40 -0
  18. opensportslib/core/loss/calf.py +258 -0
  19. opensportslib/core/loss/ce.py +23 -0
  20. opensportslib/core/loss/combine.py +42 -0
  21. opensportslib/core/loss/nll.py +25 -0
  22. opensportslib/core/optimizer/__init__.py +0 -0
  23. opensportslib/core/optimizer/builder.py +38 -0
  24. opensportslib/core/sampler/weighted_sampler.py +104 -0
  25. opensportslib/core/scheduler/__init__.py +0 -0
  26. opensportslib/core/scheduler/builder.py +77 -0
  27. opensportslib/core/trainer/__init__.py +0 -0
  28. opensportslib/core/trainer/classification_trainer.py +1131 -0
  29. opensportslib/core/trainer/localization_trainer.py +1009 -0
  30. opensportslib/core/utils/checkpoint.py +238 -0
  31. opensportslib/core/utils/config.py +199 -0
  32. opensportslib/core/utils/data.py +85 -0
  33. opensportslib/core/utils/ddp.py +77 -0
  34. opensportslib/core/utils/default_args.py +110 -0
  35. opensportslib/core/utils/load_annotations.py +485 -0
  36. opensportslib/core/utils/seed.py +26 -0
  37. opensportslib/core/utils/video_processing.py +389 -0
  38. opensportslib/core/utils/wandb.py +110 -0
  39. opensportslib/datasets/__init__.py +0 -0
  40. opensportslib/datasets/builder.py +42 -0
  41. opensportslib/datasets/classification_dataset.py +582 -0
  42. opensportslib/datasets/localization_dataset.py +813 -0
  43. opensportslib/datasets/utils/__init__.py +15 -0
  44. opensportslib/datasets/utils/tracking.py +615 -0
  45. opensportslib/metrics/classification_metric.py +176 -0
  46. opensportslib/metrics/localization_metric.py +1482 -0
  47. opensportslib/models/__init__.py +0 -0
  48. opensportslib/models/backbones/builder.py +590 -0
  49. opensportslib/models/base/e2e.py +252 -0
  50. opensportslib/models/base/tracking.py +73 -0
  51. opensportslib/models/base/vars.py +29 -0
  52. opensportslib/models/base/video.py +130 -0
  53. opensportslib/models/base/video_mae.py +60 -0
  54. opensportslib/models/builder.py +43 -0
  55. opensportslib/models/heads/builder.py +266 -0
  56. opensportslib/models/neck/builder.py +210 -0
  57. opensportslib/models/utils/common.py +176 -0
  58. opensportslib/models/utils/impl/__init__.py +0 -0
  59. opensportslib/models/utils/impl/asformer.py +390 -0
  60. opensportslib/models/utils/impl/calf.py +74 -0
  61. opensportslib/models/utils/impl/gsm.py +112 -0
  62. opensportslib/models/utils/impl/gtad.py +347 -0
  63. opensportslib/models/utils/impl/tsm.py +123 -0
  64. opensportslib/models/utils/litebase.py +59 -0
  65. opensportslib/models/utils/modules.py +120 -0
  66. opensportslib/models/utils/shift.py +135 -0
  67. opensportslib/models/utils/utils.py +276 -0
  68. opensportslib-0.0.1.dev2.dist-info/METADATA +566 -0
  69. opensportslib-0.0.1.dev2.dist-info/RECORD +73 -0
  70. opensportslib-0.0.1.dev2.dist-info/WHEEL +5 -0
  71. opensportslib-0.0.1.dev2.dist-info/licenses/LICENSE +661 -0
  72. opensportslib-0.0.1.dev2.dist-info/licenses/LICENSE-COMMERCIAL +5 -0
  73. opensportslib-0.0.1.dev2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1482 @@
1
+ """
2
+ Copyright 2022 James Hong, Haotian Zhang, Matthew Fisher, Michael Gharbi,
3
+ Kayvon Fatahalian
4
+
5
+ Redistribution and use in source and binary forms, with or without modification,
6
+ are permitted provided that the following conditions are met:
7
+
8
+ 1. Redistributions of source code must retain the above copyright notice, this
9
+ list of conditions and the following disclaimer.
10
+
11
+ 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ this list of conditions and the following disclaimer in the documentation and/or
13
+ other materials provided with the distribution.
14
+
15
+ 3. Neither the name of the copyright holder nor the names of its contributors
16
+ may be used to endorse or promote products derived from this software without
17
+ specific prior written permission.
18
+
19
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
20
+ ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
21
+ WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
23
+ ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
24
+ (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
25
+ LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
26
+ ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
28
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ """
30
+
31
+ from collections import defaultdict
32
+ import os
33
+ import numpy as np
34
+ from tabulate import tabulate
35
+ from tqdm import tqdm
36
+ from torch.utils.data import DataLoader
37
+ from opensportslib.core.utils.config import load_json, store_gz_json, store_json
38
+ from opensportslib.core.utils.wandb import log_table_wandb
39
+ # from oslactionspotting.core.utils.score import compute_mAPs_E2E
40
+ import time
41
+ import json
42
+ import os
43
+ import sys
44
+ from collections import defaultdict
45
+ from tabulate import tabulate
46
+ import numpy as np
47
+ import matplotlib.pyplot as plt
48
+ import zipfile
49
+ import logging
50
+ import copy
51
+ from datetime import datetime
52
+
53
+
54
def parse_ground_truth(truth):
    """Group ground-truth event frames by label, then by video path.

    Args:
        truth: Iterable of per-video dicts. Each dict is read through its
            "path" entry and its "events" list, whose items provide "label"
            and "frame".

    Returns:
        label_dict: Nested defaultdict of the form
            {label: {video path: [frame, ...]}}.
    """
    frames_by_label = defaultdict(lambda: defaultdict(list))
    for video_entry in truth:
        for event in video_entry["events"]:
            per_video = frames_by_label[event["label"]]
            per_video[video_entry["path"]].append(event["frame"])
    return frames_by_label
69
+
70
+
71
def get_predictions(pred, label=None):
    """Flatten per-video predictions into one list sorted by confidence.

    Args:
        pred: Iterable of per-video dicts with a "video" entry and an "events"
            list whose items provide "label", "frame" and "confidence".
        label: When given, only events carrying this label are kept; when
            None, every event is kept.

    Returns:
        flat_pred (List): Tuples (video, frame, confidence), ordered from the
            highest to the lowest confidence.
    """
    selected = [
        (video_entry["video"], event["frame"], event["confidence"])
        for video_entry in pred
        for event in video_entry["events"]
        if label is None or event["label"] == label
    ]
    return sorted(selected, key=lambda item: item[-1], reverse=True)
88
+
89
+
90
def compute_average_precision(
    pred,
    truth,
    tolerance=0,
    min_precision=0,
    plot_ax=None,
    plot_label=None,
    plot_raw_pr=True,
):
    """Compute the average precision of predictions for a frame tolerance.

    Args:
        pred: Sequence of (video, frame, score) tuples, sorted by decreasing
            score (checked with an assertion while iterating).
        truth: Mapping {video: list of ground-truth frames}.
        tolerance (int): Maximum absolute frame distance for a prediction to
            match a ground-truth frame.
            Default: 0.
        min_precision (int): Evaluation stops once the running precision
            drops below this value.
            Default: 0.
        plot_ax: Object with a ``plot`` method (e.g. a matplotlib axes) on
            which the precision/recall curves are drawn, or None.
            Default: None.
        plot_label (string): Label of the plotted curves.
            Default: None.
        plot_raw_pr (bool): Whether to also plot the raw (non-interpolated)
            precision curve.
            Default: True.

    Returns:
        The average precision. 0.0 is returned when ``truth`` holds no
        ground-truth frame (nothing can be recalled).
    """
    total = sum(len(frames) for frames in truth.values())
    if total == 0:
        # Without any ground truth the recall axis is empty; avoid dividing
        # by zero below.
        return 0.0
    recalled = set()

    # The full precision curve has TOTAL number of bins, when recall increases
    # in increments of one.
    pc = []
    _prev_score = 1
    for i, (video, frame, score) in enumerate(pred, 1):
        assert score <= _prev_score
        _prev_score = score

        # Find the not-yet-recalled ground truth frame closest to the
        # prediction.
        gt_closest = None
        for gt_frame in truth.get(video, []):
            if (video, gt_frame) in recalled:
                continue
            if gt_closest is None or (abs(frame - gt_closest) > abs(frame - gt_frame)):
                gt_closest = gt_frame

        # Record precision each time a true positive is encountered.
        if gt_closest is not None and abs(frame - gt_closest) <= tolerance:
            recalled.add((video, gt_closest))
            p = len(recalled) / i
            pc.append(p)

            # Stop evaluation early if the precision is too low.
            # Has no effect when min_precision is 0.
            if p < min_precision:
                break

    # Interpolate: each point takes the best precision reached at an equal or
    # higher recall.
    interp_pc = []
    max_p = 0
    for p in pc[::-1]:
        max_p = max(p, max_p)
        interp_pc.append(max_p)
    interp_pc.reverse()  # Not actually necessary for integration

    if plot_ax is not None:
        rc = np.arange(1, len(pc) + 1) / total
        if plot_raw_pr:
            plot_ax.plot(rc, pc, label=plot_label, alpha=0.8)
        plot_ax.plot(rc, interp_pc, label=plot_label, alpha=0.8)

    # Compute AUC by integrating up to TOTAL bins.
    return sum(interp_pc) / total
162
+
163
+
164
def compute_mAPs_E2E(truth, pred, tolerances=None, plot_pr=False):
    """Compute per-class AP and mAP at several frame tolerances (E2E method).

    The per-class table is logged through ``logging`` and sent to
    ``log_table_wandb``.

    Args:
        truth : Iterable of per-video ground-truth dicts ("path", "events").
        pred : Iterable of per-video prediction dicts ("video", "events").
        tolerances (List[int]): List of tolerance values, in frames.
            Default: None, meaning [0, 1, 2, 3, 4].
        plot_pr (bool): Whether to plot or not the precision recall curves.
            Default: False.

    Returns:
        mAPs (list): One mAP value per tolerance.
        tolerances: The tolerances that were evaluated.

    Raises:
        AssertionError: If the videos in ``truth`` and ``pred`` differ.
    """
    if tolerances is None:
        # Build a fresh list per call: the list is returned to the caller and
        # must not be shared between calls.
        tolerances = [0, 1, 2, 3, 4]

    assert {v["path"] for v in truth} == {
        v["video"] for v in pred
    }, "Video set mismatch!"

    truth_by_label = parse_ground_truth(truth)

    fig, axes = None, None
    if plot_pr:
        # squeeze=False keeps `axes` two-dimensional even with a single label
        # or a single tolerance, so axes[j, i] is always valid.
        fig, axes = plt.subplots(
            len(truth_by_label),
            len(tolerances),
            sharex=True,
            sharey=True,
            figsize=(16, 16),
            squeeze=False,
        )

    class_aps_for_tol = []
    mAPs = []
    for i, tol in enumerate(tolerances):
        class_aps = []
        for j, (label, truth_for_label) in enumerate(sorted(truth_by_label.items())):
            ap = compute_average_precision(
                get_predictions(pred, label=label),
                truth_for_label,
                tolerance=tol,
                plot_ax=axes[j, i] if axes is not None else None,
            )
            class_aps.append((label, ap))
        mAP = np.mean([x[1] for x in class_aps])
        mAPs.append(mAP)
        class_aps.append(("mAP", mAP))
        class_aps_for_tol.append(class_aps)

    # One table row per class (plus the final "mAP" row), one column per
    # tolerance, values in percent.
    header = ["Class"] + [f"AP@{t}" for t in tolerances]
    rows = []
    for c, _ in class_aps_for_tol[0]:
        row = [str(c)]
        for class_aps in class_aps_for_tol:
            for c2, val in class_aps:
                if c2 == c:
                    row.append(float(val) * 100)
        rows.append(row)

    log_table_wandb(name="AP@tol", rows=rows, headers=header)
    logging.info(tabulate(rows, headers=header, floatfmt="0.2f"))
    logging.info("Avg mAP (across tolerances): {:0.2f}".format(np.mean(mAPs) * 100))

    if plot_pr:
        for i, tol in enumerate(tolerances):
            for j, label in enumerate(sorted(truth_by_label.keys())):
                ax = axes[j, i]
                ax.set_xlabel("Recall")
                ax.set_xlim(0, 1)
                ax.set_ylabel("Precision")
                ax.set_ylim(0, 1.01)
                ax.set_title("{} @ tol={}".format(label, tol))
        plt.tight_layout()
        plt.show()
        plt.close(fig)

    sys.stdout.flush()
    return mAPs, tolerances
237
+
238
+
239
class ErrorStat:
    """Accumulate an error rate over successive (true, pred) batches."""

    def __init__(self):
        self._total = 0  # number of elements seen so far
        self._err = 0  # number of elements where true != pred

    def update(self, true, pred):
        """Add one batch of labels and predictions to the statistics.

        Args:
            true: Array of ground-truth values; its first dimension gives the
                number of elements added to the total.
            pred: Array of predictions compared element-wise to ``true``.
        """
        self._err += np.sum(true != pred)
        self._total += true.shape[0]

    def get(self):
        """Return the error rate (mismatches divided by elements seen)."""
        return self._err / self._total

    def get_acc(self):
        """Return the accuracy, i.e. one minus the error rate."""
        # The original called the non-existent `self._get()`.
        return 1.0 - self.get()
255
+
256
+
257
class ForegroundF1:
    """Per-class and class-agnostic F1 statistics over foreground labels.

    Label 0 is treated as background. Counts stored under the key ``None``
    ignore which foreground class was involved.
    """

    def __init__(self):
        self._tp = defaultdict(int)
        self._fp = defaultdict(int)
        self._fn = defaultdict(int)

    def update(self, true, pred):
        """Register one (ground truth, prediction) pair."""
        if pred == 0:
            # Background predicted: only a missed foreground label counts.
            if true != 0:
                self._fn[None] += 1
                self._fn[true] += 1
            return

        # Class-agnostic bookkeeping for a foreground prediction.
        if true == 0:
            self._fp[None] += 1
        else:
            self._tp[None] += 1

        # Per-class bookkeeping.
        if pred == true:
            self._tp[pred] += 1
            return
        self._fp[pred] += 1
        if true != 0:
            self._fn[true] += 1

    def get(self, k):
        """Return the F1 score for key ``k`` (None for class-agnostic)."""
        return self._f1(k)

    def tp_fp_fn(self, k):
        """Return the (true positive, false positive, false negative) counts."""
        return self._tp[k], self._fp[k], self._fn[k]

    def _f1(self, k):
        true_pos = self._tp[k]
        denominator = true_pos + 0.5 * self._fp[k] + 0.5 * self._fn[k]
        if denominator == 0:
            # No sample for this key: report 0 instead of dividing by zero.
            assert self._tp[k] == 0
            denominator = 1
        return self._tp[k] / denominator
294
+
295
+
296
def build_snpro_prediction_json(pred_events, head_name="action", split=None, created_by="model"):
    """Wrap per-video predictions into a version "2.0" prediction document.

    Args:
        pred_events: Iterable of dicts with "video", "fps" and "events"; each
            event provides "label", "gameTime", "frame", "position" and
            "confidence".
        head_name (string): Value stored under "head" for every event.
            Default: "action".
        split: Value stored in the metadata under "split".
            Default: None.
        created_by (string): Value stored in the metadata under "created_by".
            Default: "model".

    Returns:
        dict with "version", "date" (today, YYYY-MM-DD), "task", "metadata"
        and "data" (one entry per input video).
    """

    def _convert_event(event):
        # The internal "position" field is exported as "position_ms".
        return {
            "head": head_name,
            "label": event["label"],
            "gameTime": event["gameTime"],
            "frame": event["frame"],
            "position_ms": event["position"],
            "confidence": event["confidence"],
        }

    entries = [
        {
            "inputs": [{
                "type": "video",
                "path": item["video"],
                "fps": item["fps"],
            }],
            "events": [_convert_event(event) for event in item["events"]],
        }
        for item in pred_events
    ]

    metadata = {
        "type": "predictions",
        "created_by": created_by,
        "split": split,
    }
    return {
        "version": "2.0",
        "date": datetime.now().strftime("%Y-%m-%d"),
        "task": "localization",
        "metadata": metadata,
        "data": entries,
    }
339
+
340
def process_frame_predictions(
    dali, dataset, classes, pred_dict, high_recall_score_threshold=0.01
):
    """Turn accumulated per-frame scores into statistics and event lists.

    The score arrays stored in ``pred_dict`` are normalised in place by their
    support counts.

    Args:
        dali (bool): If data processing with dali or opencv. When true, the
            first frame of every video is left out of the normalisation, the
            statistics and the events.
        dataset (Dataset or DaliGenericIterator): Must expose ``videos``
            (iterable of (video, _, fps)) and ``get_labels(video)``.
        classes (Dict): Classes associated with indexes.
        pred_dict (Dict): Mapping between a video and a tuple (scores,
            support).
        high_recall_score_threshold (float): Minimum score for a class to be
            reported in the high-recall events.
            Default: 0.01.

    Returns:
        err (ErrorStat).
        f1 (ForegroundF1).
        pred_events (List[dict]): List of dictionaries with video, events,
            fps. Only one class maximum per frame.
        pred_events_high_recall (List[dict]): List of dictionaries with video,
            events, fps. Several possible classes per frame.
        pred_scores (dict): Mapping between videos and associated scores.
    """
    index_to_class = {index: name for name, index in classes.items()}
    fps_by_video = {name: rate for name, _, rate in dataset.videos}

    # Index of the first frame that is evaluated.
    first = 1 if dali else 0

    def _make_event(video, frame, class_index, scores):
        # Build the dictionary describing one predicted event.
        seconds = int((frame // fps_by_video[video]) % 60)
        minutes = int((frame // fps_by_video[video]) // 60)
        return {
            "label": index_to_class[class_index],
            "position": int((frame * 1000) / fps_by_video[video]),
            "gameTime": f"{minutes:02.0f}:{seconds:02.0f}",
            "frame": frame,
            "confidence": scores[frame, class_index].item(),
        }

    err = ErrorStat()
    f1 = ForegroundF1()

    pred_events = []
    pred_events_high_recall = []
    pred_scores = {}
    for video, (scores, support) in sorted(pred_dict.items()):
        label = dataset.get_labels(video)

        # Every evaluated frame must have been covered by at least one clip.
        assert np.min(support[first:]) > 0, (video, support[first:].tolist())
        scores[first:] /= support[first:, None]
        pred = np.argmax(scores[first:], axis=1)
        err.update(label[first:], pred)

        pred_scores[video] = scores.tolist()

        events = []
        events_high_recall = []
        for i in range(pred.shape[0]):
            # `pred` is indexed from the first evaluated frame, `label` and
            # `scores` from the start of the video.
            frame = i + first
            f1.update(label[frame], pred[i])

            if pred[i] != 0:
                events.append(_make_event(video, frame, pred[i], scores))

            for j in index_to_class:
                if scores[frame, j] >= high_recall_score_threshold:
                    events_high_recall.append(_make_event(video, frame, j, scores))

        pred_events.append({"video": video, "events": events, "fps": fps_by_video[video]})
        pred_events_high_recall.append(
            {"video": video, "events": events_high_recall, "fps": fps_by_video[video]}
        )

    return err, f1, pred_events, pred_events_high_recall, pred_scores
445
+
446
+
447
def infer_and_process_predictions_e2e(
    model,
    dali,
    dataset,
    split,
    classes,
    save_pred,
    calc_stats=True,
    dataloader_params=None,
    return_pred=False,
):
    """Infer prediction of actions from clips, process these predictions.

    Args:
        model: Object exposing ``predict(frames)`` that returns a pair whose
            second element holds the per-frame class scores.
        dali (bool): Whether dali has been used or opencv to process videos.
            With dali the dataset is iterated directly, otherwise it is
            wrapped in a ``DataLoader``.
        dataset (Dataset or DaliGenericIterator).
        split (string): Split of the data.
        classes (dict) : Classes associated with indexes.
        save_pred (string or None): Path prefix of the prediction files to
            write (".json", ".high_recall.json" and ".recall.json.gz" are
            appended); nothing is written when None.
        calc_stats (bool) : Compute and log the statistics or not.
            Default: True.
        dataloader_params: Object with ``batch_size``, ``num_workers`` and
            ``pin_memory`` attributes. Required despite the None default.
            Default: None.
        return_pred (bool): Return the predictions instead of the mAP.
            Default: False

    Returns:
        The high-recall predictions in v2 format (see
        ``build_snpro_prediction_json``) when ``return_pred`` is true,
        otherwise avg_mAP: the mean of the mAPs over all tolerances but the
        first one, or None when ``calc_stats`` is false.

    Raises:
        ValueError: If ``dataloader_params`` is None.
    """
    if dataloader_params is None:
        # The batch size decides how model outputs are unpacked below, so it
        # cannot be guessed.
        raise ValueError(
            "dataloader_params is required (batch_size, num_workers, pin_memory)"
        )

    # Per video: accumulated scores per frame and number of clips covering
    # each frame.
    pred_dict = {}
    for video, video_len, _ in dataset.videos:
        pred_dict[video] = (
            np.zeros((video_len, len(classes) + 1), np.float32),
            np.zeros(video_len, np.int32),
        )

    batch_size = dataloader_params.batch_size

    for clip in tqdm(
        dataset
        if dali
        else DataLoader(
            dataset,
            num_workers=dataloader_params.num_workers,
            pin_memory=dataloader_params.pin_memory,
            batch_size=batch_size,
        )
    ):
        if batch_size > 1:
            # Batched by dataloader
            _, batch_pred_scores = model.predict(clip["frame"])
            for i in range(clip["frame"].shape[0]):
                video = clip["video"][i]
                scores, support = pred_dict[video]
                pred_scores = batch_pred_scores[i]
                start = clip["start"][i].item()
                if start < 0:
                    # Clip begins before the video: drop the leading frames.
                    pred_scores = pred_scores[-start:, :]
                    start = 0
                end = start + pred_scores.shape[0]
                if end >= scores.shape[0]:
                    # Clip runs past the video: drop the trailing frames.
                    end = scores.shape[0]
                    pred_scores = pred_scores[: end - start, :]
                scores[start:end, :] += pred_scores
                support[start:end] += 1

        else:
            # Batched by dataset
            scores, support = pred_dict[clip["video"][0]]

            start = clip["start"][0].item()
            _, pred_scores = model.predict(clip["frame"][0])
            if start < 0:
                pred_scores = pred_scores[:, -start:, :]
                start = 0
            end = start + pred_scores.shape[1]
            if end >= scores.shape[0]:
                end = scores.shape[0]
                pred_scores = pred_scores[:, : end - start, :]

            # The leading axis of pred_scores is summed and counted as that
            # many supporting predictions.
            scores[start:end, :] += np.sum(pred_scores, axis=0)
            support[start:end] += pred_scores.shape[0]

    err, f1, pred_events, pred_events_high_recall, pred_scores = (
        process_frame_predictions(dali, dataset, classes, pred_dict)
    )

    avg_mAP = None
    if calc_stats:
        logging.info(f"=== Results on {split} (w/o NMS) ===")
        logging.info(f"Error (frame-level): {err.get() * 100:.2f}")

        def get_f1_tab_row(str_k):
            k = classes[str_k] if str_k != "any" else None
            return [str_k, f1.get(k) * 100, *f1.tp_fp_fn(k)]

        rows = [get_f1_tab_row("any")]
        for c in sorted(classes):
            rows.append(get_f1_tab_row(c))
        header = ["Exact frame", "F1", "TP", "FP", "FN"]
        logging.info(
            tabulate(
                rows, headers=header, floatfmt="0.2f"
            )
        )
        log_table_wandb(name="Confusion matrix table", rows=rows, headers=header)

        mAPs, _ = compute_mAPs_E2E(dataset.labels, pred_events_high_recall)
        # The first tolerance is left out of the average.
        avg_mAP = np.mean(mAPs[1:])

    pred_events = build_snpro_prediction_json(pred_events, head_name=dataset.task_name, split=split, created_by="model")
    pred_events_high_recall = build_snpro_prediction_json(pred_events_high_recall, head_name=dataset.task_name, split=split, created_by="model")
    if save_pred is not None:
        store_json(save_pred + ".json", pred_events, pretty=True)
        store_json(save_pred + ".high_recall.json", pred_events_high_recall, pretty=True)
        store_gz_json(save_pred + ".recall.json.gz", pred_events_high_recall)
    if return_pred:
        return pred_events_high_recall

    logging.info(f"avg_mAP: {avg_mAP}")
    return avg_mAP
574
+
575
+
576
def search_best_epoch(work_dir):
    """Find the best epoch recorded in ``<work_dir>/loss.json``.

    Args:
        work_dir (string): Path in which there is the json file that contains
            one record per epoch with "epoch", "valid" and "valid_mAP".

    Returns:
        epoch/epoch_mAP (int): The first epoch with the highest "valid_mAP"
            when at least one record has a value above 0, otherwise the first
            epoch with the lowest "valid" value (-1 if there is no record).
    """
    history = load_json(os.path.join(work_dir, "loss.json"))

    best_map, best_map_epoch = 0, -1
    lowest_valid, lowest_valid_epoch = float("inf"), -1
    for record in history:
        if record["valid_mAP"] > best_map:
            best_map = record["valid_mAP"]
            best_map_epoch = record["epoch"]
        if record["valid"] < lowest_valid:
            lowest_valid = record["valid"]
            lowest_valid_epoch = record["epoch"]

    # Prefer the mAP criterion whenever it selected an epoch.
    return best_map_epoch if best_map_epoch != -1 else lowest_valid_epoch
600
+
601
+
602
def non_maximum_supression(pred, window):
    """Non maximum suppression of predictions, applied per video and label.

    An event is dropped when another event with the same label, at a
    different frame no further than ``window`` frames away, has a strictly
    higher confidence.

    Args:
        pred (List[dict]): List of dictionaries that contain predictions per
            video ("events" with "label", "frame" and "confidence").
        window (int): The window size between frames.

    Returns:
        new_pred (List[dict]) : Deep copies of the input entries whose
            "events" hold the surviving events sorted by frame, with
            "num_events" set to their count.
    """

    def _is_dominated(candidate, rivals):
        return any(
            candidate["frame"] != other["frame"]
            and abs(candidate["frame"] - other["frame"]) <= window
            and candidate["confidence"] < other["confidence"]
            for other in rivals
        )

    filtered = []
    for video_pred in pred:
        same_label_events = defaultdict(list)
        for event in video_pred["events"]:
            same_label_events[event["label"]].append(event)

        survivors = [
            event
            for group in same_label_events.values()
            for event in group
            if not _is_dominated(event, group)
        ]
        survivors.sort(key=lambda event: event["frame"])

        video_result = copy.deepcopy(video_pred)
        video_result["events"] = survivors
        video_result["num_events"] = len(survivors)
        filtered.append(video_result)
    return filtered
638
+
639
+
640
+ # def store_eval_files_json(raw_pred, eval_dir):
641
+ # """
642
+ # Store predictions.
643
+
644
+ # If eval_dir ends with .json → store NEW v2 unified file
645
+ # Else → store OLD per-video SN files (original behavior)
646
+ # """
647
+
648
+ # # ==========================================================
649
+ # # 🟢 NEW V2 FORMAT (single json file)
650
+ # # ==========================================================
651
+ # if eval_dir.endswith(".json"):
652
+ # from datetime import datetime
653
+
654
+ # data = []
655
+
656
+ # for obj in raw_pred:
657
+ # video = obj["video"]
658
+ # fps = obj["fps"]
659
+
660
+ # events = []
661
+ # for ev in obj["events"]:
662
+ # events.append({
663
+ # "head": "action", # generic
664
+ # "label": ev["label"],
665
+ # "position_ms": ev["position"],
666
+ # "confidence": ev.get("confidence", 1.0),
667
+ # })
668
+
669
+ # data.append({
670
+ # "inputs": [{
671
+ # "type": "video",
672
+ # "path": video,
673
+ # "fps": fps,
674
+ # }],
675
+ # "events": events,
676
+ # })
677
+
678
+ # out = {
679
+ # "version": "2.0",
680
+ # "date": datetime.now().strftime("%Y-%m-%d"),
681
+ # "task": "action_spotting",
682
+ # "metadata": {
683
+ # "type": "predictions",
684
+ # "notes": "Generated by model"
685
+ # },
686
+ # "data": data,
687
+ # }
688
+
689
+ # os.makedirs(os.path.dirname(eval_dir), exist_ok=True)
690
+ # with open(eval_dir, "w") as f:
691
+ # json.dump(out, f, indent=2)
692
+
693
+ # logging.info(f"Stored V2 predictions → {eval_dir}")
694
+ # return True
695
+
696
+ # # ==========================================================
697
+ # # 🔵 OLD FORMAT (original code untouched)
698
+ # # ==========================================================
699
+ # only_one_file = False
700
+ # video_pred = defaultdict(list)
701
+ # video_fps = defaultdict(list)
702
+
703
+ # for obj in raw_pred:
704
+ # video = obj["video"]
705
+ # video_fps[video] = obj["fps"]
706
+
707
+ # for event in obj["events"]:
708
+ # video_pred[video].append(
709
+ # {
710
+ # "frame": event["frame"],
711
+ # "label": event["label"],
712
+ # "confidence": event["confidence"],
713
+ # "position": event["position"],
714
+ # "gameTime": event["gameTime"],
715
+ # }
716
+ # )
717
+
718
+ # for video, pred in video_pred.items():
719
+ # if len(raw_pred) == 1:
720
+ # video_out_dir = eval_dir
721
+ # only_one_file = True
722
+ # else:
723
+ # video_out_dir = os.path.join(eval_dir, os.path.splitext(video)[0])
724
+
725
+ # os.makedirs(video_out_dir, exist_ok=True)
726
+
727
+ # store_json(
728
+ # os.path.join(video_out_dir, "results_spotting.json"),
729
+ # {"Url": video, "predictions": pred, "fps": video_fps[video]},
730
+ # pretty=True,
731
+ # )
732
+
733
+ # logging.info(f"Stored OLD format → {eval_dir}")
734
+ # return only_one_file
735
+
736
def _as_internal_predictions(raw_pred):
    """Return predictions as a list of {"video", "fps", "events"} dicts."""
    if not isinstance(raw_pred, dict):
        # Already in the internal list format.
        return raw_pred

    converted = []
    for item in raw_pred["data"]:
        source = item["inputs"][0]
        events = []
        for ev in item["events"]:
            ev = dict(ev)
            # The v2 format names the position "position_ms".
            if "position" not in ev and "position_ms" in ev:
                ev["position"] = ev["position_ms"]
            events.append(ev)
        converted.append(
            {"video": source["path"], "fps": source["fps"], "events": events}
        )
    return converted


def store_eval_files_json(raw_pred, eval_dir, save_v2=True):
    """Store predictions as one ``results_spotting.json`` file per video.

    A single video is written directly into ``eval_dir``; several videos are
    written into one sub-directory each, named after the video path without
    its extension.

    Args:
        raw_pred: Predictions, either in the internal format (list of dicts
            with "video", "fps" and "events") or in the v2 format (dict with
            a "data" list, as built by ``build_snpro_prediction_json``).
        eval_dir (string): Output directory.
        save_v2 (bool): True stores the unified v2 json, False stores the
            original SN format (videos without events are then skipped).
            Default: True.

    Returns:
        only_one_file (bool): True when a single video was written directly
            into ``eval_dir``.
    """
    raw_pred = _as_internal_predictions(raw_pred)

    # ==========================================================
    # NEW V2 FORMAT
    # ==========================================================
    only_one_file = False

    if save_v2:
        from datetime import datetime

        for obj in raw_pred:
            video = obj["video"]
            fps = obj["fps"]

            # Same directory logic as the old format.
            if len(raw_pred) == 1:
                video_out_dir = eval_dir
                only_one_file = True
            else:
                video_out_dir = os.path.join(eval_dir, os.path.splitext(video)[0])

            os.makedirs(video_out_dir, exist_ok=True)

            events = []
            for ev in obj["events"]:
                events.append({
                    "head": ev.get("head", "action"),
                    "label": ev.get("label"),
                    "frame": ev.get("frame"),
                    "position_ms": int(ev.get("position")),
                    "confidence": float(ev.get("confidence")),
                    "gameTime": ev.get("gameTime")
                })

            out = {
                "version": "2.0",
                "date": datetime.now().strftime("%Y-%m-%d"),
                "task": "action_spotting",
                "metadata": {"type": "predictions"},
                "data": [{
                    "inputs": [{
                        "type": "video",
                        "path": video,
                        "fps": fps,
                    }],
                    "events": events,
                }],
            }

            out_path = os.path.join(video_out_dir, "results_spotting.json")
            with open(out_path, "w", encoding="utf-8") as f:
                json.dump(out, f, indent=2)

        logging.info(f"Stored V2 predictions → {eval_dir}")
        return only_one_file

    # ==========================================================
    # OLD FORMAT
    # ==========================================================
    video_pred = defaultdict(list)
    video_fps = defaultdict(list)

    for obj in raw_pred:
        video = obj["video"]
        video_fps[video] = obj["fps"]

        for event in obj["events"]:
            video_pred[video].append(
                {
                    "frame": event["frame"],
                    "label": event["label"],
                    "confidence": event["confidence"],
                    "position": event["position"],
                    "gameTime": event["gameTime"],
                }
            )

    # Only videos with at least one event appear in video_pred.
    for video, pred in video_pred.items():
        if len(raw_pred) == 1:
            video_out_dir = eval_dir
            only_one_file = True
        else:
            video_out_dir = os.path.join(eval_dir, os.path.splitext(video)[0])

        os.makedirs(video_out_dir, exist_ok=True)

        store_json(
            os.path.join(video_out_dir, "results_spotting.json"),
            {"Url": video, "predictions": pred, "fps": video_fps[video]},
            pretty=True,
        )

    logging.info(f"Stored OLD format → {eval_dir}")
    return only_one_file
845
+
846
+
847
+
848
+
849
def label2vector(
    labels, num_classes=17, framerate=2, EVENT_DICTIONARY=None, vector_size=None
):
    """Transform list of dict containing ground truth labels into a 2D array.

    The frame of an annotation is taken, in order of preference, from its
    "frame" entry, from its "position" entry (milliseconds) or from its
    "gameTime" entry (second precision; minutes and seconds are read from the
    last five characters, "MM:SS").

    Args:
        labels (List[dict]): List of groundtruth labels for a video.
        num_classes (int): Number of classes.
            Default: 17.
        framerate (int): Rate at which the frames have been processed in a
            video.
            Default: 2.
        EVENT_DICTIONARY (dict): Mapping between classes and indexes.
            Default: None, meaning an empty mapping.
        vector_size (int): Size of the returned vector.
            Default: None, meaning 90 * 60 * framerate.

    Returns:
        dense_labels (Numpy array): Array of shape (vector_size, num_classes).
            The first index is the frame number and the second one is the
            class. The value is -1 if the event is annotated as "not shown",
            1 if the event is visible and 0 otherwise, meaning that the event
            does not occur for this frame.

    Raises:
        KeyError: If an annotation label is missing from EVENT_DICTIONARY.
    """
    if EVENT_DICTIONARY is None:
        EVENT_DICTIONARY = {}
    vector_size = int(90 * 60 * framerate if vector_size is None else vector_size)

    dense_labels = np.zeros((vector_size, num_classes))

    for annotation in labels:
        event = annotation["label"]
        if "frame" in annotation:
            frame = int(annotation["frame"])
        elif "position" in annotation:
            # annotation at millisecond precision
            frame = int(framerate * (int(annotation["position"]) / 1000))
        else:
            # annotation at second precision
            game_time = annotation["gameTime"]
            minutes = int(game_time[-5:-3])
            seconds = int(game_time[-2:])
            # int() keeps the index valid for a non-integer framerate.
            frame = int(framerate * (seconds + 60 * minutes))

        label = EVENT_DICTIONARY[event]

        value = -1 if annotation.get("visibility") == "not shown" else 1

        # Events past the end of the vector are stored on its last frame.
        frame = min(frame, vector_size - 1)
        dense_labels[frame][label] = value

    return dense_labels
904
+
905
+
906
def predictions2vector(
    predictions, num_classes=17, framerate=2, EVENT_DICTIONARY={}, vector_size=None
):
    """Scatter a list of prediction dicts into a dense score matrix.

    Args:
        predictions (List[dict]): Predictions for one video. Each entry needs
            ``"label"``, ``"confidence"`` and either ``"frame"`` or
            ``"position"`` (milliseconds).
        num_classes (int): Number of classes.
            Default: 17.
        framerate (int): Rate at which the frames have been processed in a video.
            Default: 2.
        EVENT_DICTIONARY (dict): Mapping between classes and indexes.
            Default: {}.
        vector_size (int): Number of rows of the returned array; when None,
            ``90 * 60 * framerate`` is used.
            Default: None.

    Returns:
        dense_predictions (Numpy array): Array of shape (vector_size, num_classes).
            Rows are frames, columns are classes. A cell holds the confidence of
            the prediction written there, or -1 where nothing was predicted.
    """
    if vector_size is None:
        vector_size = 90 * 60 * framerate
    vector_size = int(vector_size)

    # -1 marks "no prediction" for every (frame, class) cell.
    dense_predictions = np.full((vector_size, num_classes), -1.0)
    last_row = vector_size - 1

    for entry in predictions:
        class_name = entry["label"]

        if "frame" in entry:
            row = int(entry["frame"])
        else:
            millis = int(entry["position"])
            row = int(framerate * (millis / 1000))

        column = EVENT_DICTIONARY[class_name]

        # Predictions beyond the end are folded onto the last row.
        row = min(row, last_row)
        dense_predictions[row][column] = entry["confidence"]

    return dense_predictions
949
+
950
+
951
# Tell NumPy to ignore divide-by-zero and invalid-value floating-point errors
# instead of warning: the precision/recall code further down divides by counts
# that can be zero and then cleans the result with np.nan_to_num.
+ np.seterr(divide="ignore", invalid="ignore")
952
+
953
+
954
class AverageMeter(object):
    """Tracks the latest value and the running weighted average of a quantity."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the latest value, the average, the weighted sum and the count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the average.

        Args:
            val: Newly observed value.
            n (int): Weight (number of observations) of ``val``.
                Default: 1.
        """
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
971
+
972
+
973
def compute_class_scores(target, closest, detection, delta):
    """Compute the scores for a single class.

    Every ground-truth event (non-zero entry of ``target``) is matched with at
    most one prediction: the highest-scoring, not yet matched prediction whose
    index lies within ``delta / 2`` frames of the event.

    Args:
        target (np.array): Ground-truth targets of shape (number of frames).
            Positive entries are visible events, negative entries unshown events.
        closest (np.array): Label of the closest action per frame. It is indexed
            at every prediction index, so it must be at least as long as the
            highest prediction index requires.
        detection (np.array): Predictions of shape (number of frames). Entries
            >= 0 are prediction scores, negative entries mean "no prediction".
        delta (int): Tolerance.

    Returns:
        game_detections (np.array(number of predictions, 3)): One row per
            prediction; the first column is the prediction score, the second one
            is 1 if the prediction was matched to a ground-truth event (else 0),
            the third one is the value of ``closest`` at the prediction index.
        int: number of visible events.
        int: number of unshown events.
    """
    # Retrieving the important variables
    gt_indexes = np.where(target != 0)[0]
    n_visible = len(np.where(target > 0)[0])
    n_unshown = len(np.where(target < 0)[0])
    pred_indexes = np.where(detection >= 0)[0]
    pred_scores = detection[pred_indexes]

    # Each row is [pred_score, {1 or 0}, closest[pred_index]]
    game_detections = np.zeros((len(pred_indexes), 3))
    game_detections[:, 0] = np.copy(pred_scores)
    game_detections[:, 2] = np.copy(closest[pred_indexes])

    # Frame indexes of predictions already matched to a ground-truth event;
    # a set keeps the membership test O(1) inside the nested loop.
    matched_indexes = set()

    for gt_index in gt_indexes:
        best_score = -1
        best_index = None
        best_row = 0

        # pred_indexes is sorted, so rows before the window are skipped and the
        # scan stops at the first prediction past the window.
        for row, (pred_index, pred_score) in enumerate(zip(pred_indexes, pred_scores)):
            if pred_index < gt_index - delta:
                continue
            if pred_index > gt_index + delta:
                break

            if (
                abs(pred_index - gt_index) <= delta / 2
                and pred_score > best_score
                and pred_index not in matched_indexes
            ):
                best_score = pred_score
                best_index = pred_index
                best_row = row

        if best_index is not None:
            game_detections[best_row, 1] = 1
            matched_indexes.add(best_index)

    return game_detections, n_visible, n_unshown
1035
+
1036
+
1037
def _sort_curve_by_recall(precision, recall, num_classes):
    """Reorder both arrays in place, class by class, by ascending recall."""
    for i in range(num_classes):
        index_sort = np.argsort(recall[:, i])
        precision[:, i] = precision[index_sort, i]
        recall[:, i] = recall[index_sort, i]


def compute_precision_recall_curve(targets, closests, detections, delta):
    """Compute precision recall curve.

    Args:
        targets (List(np.array(vector_size,num_classes)): List of ground truth targets of shape (number of videos, number of frames,number of classes).
        closests (List(np.array)): List of closest action labels, one array per video, indexed per frame and class.
        detections (List(np.array(vector_size,num_classes)): List of predictions of shape (number of videos, number of frames,number of classes).
        delta (int): Tolerance.

    Returns:
        precision: Array of precision points, shape (number of thresholds, number of classes).
        recall: Array of recall points, same shape.
        precision_visible: Precision points for only the visible events.
        recall_visible: Recall points for only the visible events.
        precision_unshown: Precision points for only the unshown events.
        recall_unshown: Recall points for only the unshown events.
        Each precision/recall pair is sorted by ascending recall, class per class.
    """
    # Store the number of classes
    num_classes = targets[0].shape[-1]

    # 200 confidence thresholds between [0,1]
    thresholds = np.linspace(0, 1, 200)

    # Store the precision and recall points
    precision = list()
    recall = list()
    precision_visible = list()
    recall_visible = list()
    precision_unshown = list()
    recall_unshown = list()

    # Precompute the predictions scores and their correspondence {TP, FP} for each class
    for c in np.arange(num_classes):
        # Placeholder first row with score -1: it never passes any threshold.
        placeholder = np.zeros((1, 3))
        placeholder[0, 0] = -1
        per_game_detections = [placeholder]
        n_gt_labels_visible = 0
        n_gt_labels_unshown = 0

        # Get the confidence scores and their corresponding TP or FP characteristics for each game
        for target, closest, detection in zip(targets, closests, detections):
            tmp_detections, tmp_n_gt_labels_visible, tmp_n_gt_labels_unshown = (
                compute_class_scores(
                    target[:, c], closest[:, c], detection[:, c], delta
                )
            )
            per_game_detections.append(tmp_detections)
            n_gt_labels_visible = n_gt_labels_visible + tmp_n_gt_labels_visible
            n_gt_labels_unshown = n_gt_labels_unshown + tmp_n_gt_labels_unshown

        # One concatenation instead of growing the array game by game.
        total_detections = np.concatenate(per_game_detections, axis=0)

        precision.append(list())
        recall.append(list())
        precision_visible.append(list())
        recall_visible.append(list())
        precision_unshown.append(list())
        recall_unshown.append(list())

        # Get only the visible or unshown actions: detections whose closest
        # action is of the other kind get score -1 so no threshold selects them.
        total_detections_visible = np.copy(total_detections)
        total_detections_unshown = np.copy(total_detections)
        total_detections_visible[total_detections_visible[:, 2] <= 0.5, 0] = -1
        total_detections_unshown[total_detections_unshown[:, 2] >= -0.5, 0] = -1

        n_gt_labels = n_gt_labels_visible + n_gt_labels_unshown

        # Divisions by zero are expected (no prediction / no ground truth);
        # silence them locally and let nan_to_num clean up the result.
        with np.errstate(divide="ignore", invalid="ignore"):
            # Get the precision and recall for each confidence threshold
            for threshold in thresholds:
                selected = total_detections[:, 0] >= threshold
                selected_visible = total_detections_visible[:, 0] >= threshold
                selected_unshown = total_detections_unshown[:, 0] >= threshold

                TP = np.sum(total_detections[selected, 1])
                TP_visible = np.sum(total_detections[selected_visible, 1])
                TP_unshown = np.sum(total_detections[selected_unshown, 1])

                n_pred = int(np.count_nonzero(selected))
                n_pred_visible = int(np.count_nonzero(selected_visible))
                n_pred_unshown = int(np.count_nonzero(selected_unshown))

                precision[-1].append(np.nan_to_num(TP / n_pred))
                recall[-1].append(np.nan_to_num(TP / n_gt_labels))
                precision_visible[-1].append(
                    np.nan_to_num(TP_visible / n_pred_visible)
                )
                recall_visible[-1].append(
                    np.nan_to_num(TP_visible / n_gt_labels_visible)
                )
                precision_unshown[-1].append(
                    np.nan_to_num(TP_unshown / n_pred_unshown)
                )
                recall_unshown[-1].append(
                    np.nan_to_num(TP_unshown / n_gt_labels_unshown)
                )

    precision = np.array(precision).transpose()
    recall = np.array(recall).transpose()
    precision_visible = np.array(precision_visible).transpose()
    recall_visible = np.array(recall_visible).transpose()
    precision_unshown = np.array(precision_unshown).transpose()
    recall_unshown = np.array(recall_unshown).transpose()

    # Sort the points based on the recall, class per class
    _sort_curve_by_recall(precision, recall, num_classes)
    _sort_curve_by_recall(precision_visible, recall_visible, num_classes)
    _sort_curve_by_recall(precision_unshown, recall_unshown, num_classes)

    return (
        precision,
        recall,
        precision_visible,
        recall_visible,
        precision_unshown,
        recall_unshown,
    )
1166
+
1167
+
1168
def compute_mAP(precision, recall):
    """Compute the 11-point interpolated average precision of every class.

    For each recall level 0.0, 0.1, ..., 1.0 the highest precision among the
    points whose recall reaches that level is taken (0 when no point does);
    the class AP is the mean of these 11 values.

    Args:
        precision: Array of precision points, one column per class.
        recall: Array of recall points, one column per class.

    Returns:
        np.mean(mAP_per_class): mean of the AP over all classes.
        mAP_per_class: Array of AP for each class.
    """
    num_classes = precision.shape[-1]

    # Accumulated interpolated precision per class
    AP = np.zeros(num_classes)

    # The 11 recall levels of the interpolation
    recall_levels = np.arange(11) / 10

    for cls in range(num_classes):
        class_recall = recall[:, cls]
        for level in recall_levels:
            reachable = precision[class_recall >= level, cls]
            if reachable.shape[0] != 0:
                AP[cls] += np.max(reachable)
            else:
                AP[cls] += 0

    mAP_per_class = AP / 11

    return np.mean(mAP_per_class), mAP_per_class
1202
+
1203
+
1204
+ # Tight: (SNv3): np.arange(5)*1 + 1
1205
+ # Loose: (SNv1/v2): np.arange(12)*5 + 5
1206
+
1207
+
1208
def delta_curve(targets, closests, detections, framerate, deltas=np.arange(5) * 1 + 1):
    """Compute the mAP obtained at every tolerance.

    Each tolerance of ``deltas`` is multiplied by ``framerate`` before being
    handed to ``compute_precision_recall_curve``.

    Args:
        targets (List(np.array(vector_size,num_classes)): List of ground truth targets of shape (number of videos, number of frames,number of classes).
        closests (List(np.array)): List of closest action labels, one array per video.
        detections (List(np.array(vector_size,num_classes)): List of predictions of shape (number of videos, number of frames,number of classes).
        framerate: Factor applied to every tolerance.
        deltas (np.array): Tolerances.

    Returns:
        mAP: List of mean mAP over all the classes for each tolerance.
        mAP_per_class: List of list of mAP for all classes for each tolerance.
        mAP_visible: Same as mAP, only for the visible events.
        mAP_per_class_visible: Same as mAP_per_class, only for the visible events.
        mAP_unshown: Same as mAP, only for the unshown events.
        mAP_per_class_unshown: Same as mAP_per_class, only for the unshown events.
    """
    # Slots 0/1: any events, 2/3: visible events, 4/5: unshown events.
    # In each pair the first list holds the mean and the second the per-class values.
    collected = ([], [], [], [], [], [])

    for delta in tqdm(deltas * framerate):
        curves = compute_precision_recall_curve(targets, closests, detections, delta)

        # curves is (precision, recall) repeated for any / visible / unshown.
        # TODO: compute visible/undshown from another JSON file containing only the visible/unshown annotations
        for group in range(3):
            group_precision = curves[2 * group]
            group_recall = curves[2 * group + 1]
            mean_value, per_class_values = compute_mAP(group_precision, group_recall)
            collected[2 * group].append(mean_value)
            collected[2 * group + 1].append(per_class_values)

    return collected
1267
+
1268
+
1269
def average_mAP(
    targets, detections, closests, framerate=2, deltas=np.arange(5) * 1 + 1
):
    """Average the mAP over the tolerances.

    With a single tolerance the values of that tolerance are returned as is.
    Otherwise every series is averaged with the trapezoidal rule over
    consecutive tolerances.

    Args:
        targets (List(np.array(vector_size,num_classes)): List of ground truth targets of shape (number of videos, number of frames,number of classes).
        detections (List(np.array(vector_size,num_classes)): List of predictions of shape (number of videos, number of frames,number of classes).
        closests (List(np.array)): List of closest action labels, one array per video.
        framerate (int).
            Default: 2.
        deltas (np.array): Tolerances.

    Returns:
        a_mAP: average mAP.
        a_mAP_per_class: List of average mAP for all classes.
        a_mAP_visible: average mAP only for the visible events.
        a_mAP_per_class_visible: List of average mAP only for the visible events.
        a_mAP_unshown: average mAP only for the unshown events.
        a_mAP_per_class_unshown: List of average mAP only for the unshown events.
    """

    def trapezoid_mean(series):
        # Mean of the midpoints of consecutive values.
        total = 0.0
        for left, right in zip(series[:-1], series[1:]):
            total += (left + right) / 2
        return total / ((len(series) - 1))

    def trapezoid_mean_per_class(table):
        # Same average, computed independently for every class column.
        return [
            trapezoid_mean([row[c] for row in table])
            for c in range(len(table[0]))
        ]

    curves = delta_curve(targets, closests, detections, framerate, deltas)
    (
        mAP,
        mAP_per_class,
        mAP_visible,
        mAP_per_class_visible,
        mAP_unshown,
        mAP_per_class_unshown,
    ) = curves

    # Single tolerance: nothing to average.
    if len(mAP) == 1:
        return tuple(series[0] for series in curves)

    a_mAP = trapezoid_mean(mAP)
    a_mAP_visible = trapezoid_mean(mAP_visible)

    # NOTE(review): hard-coded 17/13 rescaling of the unshown average, which is
    # not applied in the single-tolerance branch above. Presumably tied to a
    # 17-class dataset where only 13 classes have unshown events - confirm.
    a_mAP_unshown = trapezoid_mean(mAP_unshown) * 17 / 13

    a_mAP_per_class = trapezoid_mean_per_class(mAP_per_class)
    a_mAP_per_class_visible = trapezoid_mean_per_class(mAP_per_class_visible)
    a_mAP_per_class_unshown = trapezoid_mean_per_class(mAP_per_class_unshown)

    return (
        a_mAP,
        a_mAP_per_class,
        a_mAP_visible,
        a_mAP_per_class_visible,
        a_mAP_unshown,
        a_mAP_per_class_unshown,
    )
1364
+
1365
+
1366
def LoadJsonFromZip(zippedFile, JsonPath):
    """Read one member of a zip archive and parse it as UTF-8 encoded JSON.

    Args:
        zippedFile: Archive to open (anything ``zipfile.ZipFile`` accepts).
        JsonPath: Name of the JSON member inside the archive.

    Returns:
        The decoded JSON content.
    """
    with zipfile.ZipFile(zippedFile, "r") as archive:
        raw_bytes = archive.read(JsonPath)
        parsed = json.loads(raw_bytes.decode("utf-8"))

    return parsed
1374
+
1375
+
1376
def get_closest_action_index(dense_labels, closest_numpy):
    """Spread every action label over the frames that are closest to it.

    For each class, the span given to an action runs from the midpoint with the
    previous action (inclusive) to the midpoint with the next action
    (exclusive), clipped to the bounds of ``closest_numpy``.
    Example:
        If an action occurs at index 10, the previous action at index 6 and the
        next one at index 12, the label is written to indexes 8, 9 and 10.

    Args:
        dense_labels: Array of ground-truth labels, one column per class.
        closest_numpy: Array that receives the labels; modified in place.

    Returns:
        closest_numpy: The same array, filled with the closest labels.
    """
    n_rows = closest_numpy.shape[0]

    for c in range(dense_labels.shape[-1]):
        column = dense_labels[:, c]
        events = np.flatnonzero(column != 0).tolist()
        if not events:
            continue

        # Virtual neighbours: a mirrored event before the first one and an
        # event far past the end, so the outer spans reach both borders.
        padded = [-events[0]] + events + [2 * n_rows]

        for before, current, after in zip(padded, padded[1:], padded[2:]):
            start = max(0, (before + current) // 2)
            stop = min(n_rows, (current + after) // 2)
            closest_numpy[start:stop, c] = dense_labels[current, c]

    return closest_numpy
1400
+
1401
+
1402
def compute_performances_mAP(
    metric,
    targets_numpy,
    detections_numpy,
    closests_numpy,
    INVERSE_EVENT_DICTIONARY,
    framerate=2,
):
    """Compute the different mAP, log them and return them as a table.

    Args:
        metric (string): The metric that will be used to compute the mAP. One of
            "loose", "tight", "at1", "at2", "at3", "at4", "at5".
        targets_numpy (List(np.array(vector_size,num_classes)): List of ground truth targets of shape (number of videos, number of frames,number of classes).
        detections_numpy (List(np.array(vector_size,num_classes)): List of predictions of shape (number of videos, number of frames,number of classes).
        closests_numpy (List(np.array)): List of closest action labels, one array per video.
        INVERSE_EVENT_DICTIONARY (dict): mapping between indexes and class name.
        framerate (int): Factor applied to the tolerances, forwarded to ``average_mAP``.
            Default: 2.

    Returns:
        result (str): Table, formatted with ``tabulate``, with one row per class
            plus an "Average mAP" row, and the columns Any / Visible / Unseen
            expressed in percent.

    Raises:
        ValueError: If ``metric`` is not one of the supported names.
    """
    if metric == "loose":
        deltas = np.arange(12) * 5 + 5
    elif metric == "tight":
        deltas = np.arange(5) * 1 + 1
    elif metric in ("at1", "at2", "at3", "at4", "at5"):
        # "atN" evaluates the single tolerance N.
        deltas = np.array([int(metric[2:])])
    else:
        # Fail with a clear message instead of an unbound `deltas` further down.
        raise ValueError(
            "Unknown metric {!r}; expected one of 'loose', 'tight', "
            "'at1', 'at2', 'at3', 'at4', 'at5'.".format(metric)
        )

    # Compute the performances
    (
        a_mAP,
        a_mAP_per_class,
        a_mAP_visible,
        a_mAP_per_class_visible,
        a_mAP_unshown,
        a_mAP_per_class_unshown,
    ) = average_mAP(
        targets_numpy,
        detections_numpy,
        closests_numpy,
        framerate=framerate,
        deltas=deltas,
    )

    results = {
        "a_mAP": a_mAP,
        "a_mAP_per_class": a_mAP_per_class,
        "a_mAP_visible": a_mAP_visible,
        "a_mAP_per_class_visible": a_mAP_per_class_visible,
        "a_mAP_unshown": a_mAP_unshown,
        "a_mAP_per_class_unshown": a_mAP_per_class_unshown,
    }

    # One row per class, values shown as percentages with two decimals.
    rows = []
    for i in range(len(results["a_mAP_per_class"])):
        label = INVERSE_EVENT_DICTIONARY[i]
        rows.append(
            (
                label,
                "{:0.2f}".format(results["a_mAP_per_class"][i] * 100),
                "{:0.2f}".format(results["a_mAP_per_class_visible"][i] * 100),
                "{:0.2f}".format(results["a_mAP_per_class_unshown"][i] * 100),
            )
        )
    rows.append(
        (
            "Average mAP",
            "{:0.2f}".format(results["a_mAP"] * 100),
            "{:0.2f}".format(results["a_mAP_visible"] * 100),
            "{:0.2f}".format(results["a_mAP_unshown"] * 100),
        )
    )

    header = ["", "Any", "Visible", "Unseen"]
    result = tabulate(rows, headers=header)
    log_table_wandb(name=f"Final Scores (Metric : {metric})", rows=rows, headers=header)
    logging.info("Best Performance at end of training ")
    logging.info("Metric: " + metric)
    logging.info("\n" + result)
    return result
1481
+
1482
# Repeats the np.seterr call made earlier in this file: divide-by-zero and
# invalid-value floating-point errors are ignored by NumPy.
+ np.seterr(divide="ignore", invalid="ignore")