opensportslib 0.0.1.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- opensportslib/__init__.py +18 -0
- opensportslib/apis/__init__.py +21 -0
- opensportslib/apis/classification.py +361 -0
- opensportslib/apis/localization.py +228 -0
- opensportslib/config/classification.yaml +104 -0
- opensportslib/config/classification_tracking.yaml +103 -0
- opensportslib/config/graph_tracking_classification/avgpool.yaml +79 -0
- opensportslib/config/graph_tracking_classification/gin.yaml +79 -0
- opensportslib/config/graph_tracking_classification/graphconv.yaml +79 -0
- opensportslib/config/graph_tracking_classification/graphsage.yaml +79 -0
- opensportslib/config/graph_tracking_classification/maxpool.yaml +79 -0
- opensportslib/config/graph_tracking_classification/noedges.yaml +79 -0
- opensportslib/config/localization.yaml +132 -0
- opensportslib/config/sngar_frames.yaml +98 -0
- opensportslib/core/__init__.py +0 -0
- opensportslib/core/loss/__init__.py +0 -0
- opensportslib/core/loss/builder.py +40 -0
- opensportslib/core/loss/calf.py +258 -0
- opensportslib/core/loss/ce.py +23 -0
- opensportslib/core/loss/combine.py +42 -0
- opensportslib/core/loss/nll.py +25 -0
- opensportslib/core/optimizer/__init__.py +0 -0
- opensportslib/core/optimizer/builder.py +38 -0
- opensportslib/core/sampler/weighted_sampler.py +104 -0
- opensportslib/core/scheduler/__init__.py +0 -0
- opensportslib/core/scheduler/builder.py +77 -0
- opensportslib/core/trainer/__init__.py +0 -0
- opensportslib/core/trainer/classification_trainer.py +1131 -0
- opensportslib/core/trainer/localization_trainer.py +1009 -0
- opensportslib/core/utils/checkpoint.py +238 -0
- opensportslib/core/utils/config.py +199 -0
- opensportslib/core/utils/data.py +85 -0
- opensportslib/core/utils/ddp.py +77 -0
- opensportslib/core/utils/default_args.py +110 -0
- opensportslib/core/utils/load_annotations.py +485 -0
- opensportslib/core/utils/seed.py +26 -0
- opensportslib/core/utils/video_processing.py +389 -0
- opensportslib/core/utils/wandb.py +110 -0
- opensportslib/datasets/__init__.py +0 -0
- opensportslib/datasets/builder.py +42 -0
- opensportslib/datasets/classification_dataset.py +582 -0
- opensportslib/datasets/localization_dataset.py +813 -0
- opensportslib/datasets/utils/__init__.py +15 -0
- opensportslib/datasets/utils/tracking.py +615 -0
- opensportslib/metrics/classification_metric.py +176 -0
- opensportslib/metrics/localization_metric.py +1482 -0
- opensportslib/models/__init__.py +0 -0
- opensportslib/models/backbones/builder.py +590 -0
- opensportslib/models/base/e2e.py +252 -0
- opensportslib/models/base/tracking.py +73 -0
- opensportslib/models/base/vars.py +29 -0
- opensportslib/models/base/video.py +130 -0
- opensportslib/models/base/video_mae.py +60 -0
- opensportslib/models/builder.py +43 -0
- opensportslib/models/heads/builder.py +266 -0
- opensportslib/models/neck/builder.py +210 -0
- opensportslib/models/utils/common.py +176 -0
- opensportslib/models/utils/impl/__init__.py +0 -0
- opensportslib/models/utils/impl/asformer.py +390 -0
- opensportslib/models/utils/impl/calf.py +74 -0
- opensportslib/models/utils/impl/gsm.py +112 -0
- opensportslib/models/utils/impl/gtad.py +347 -0
- opensportslib/models/utils/impl/tsm.py +123 -0
- opensportslib/models/utils/litebase.py +59 -0
- opensportslib/models/utils/modules.py +120 -0
- opensportslib/models/utils/shift.py +135 -0
- opensportslib/models/utils/utils.py +276 -0
- opensportslib-0.0.1.dev2.dist-info/METADATA +566 -0
- opensportslib-0.0.1.dev2.dist-info/RECORD +73 -0
- opensportslib-0.0.1.dev2.dist-info/WHEEL +5 -0
- opensportslib-0.0.1.dev2.dist-info/licenses/LICENSE +661 -0
- opensportslib-0.0.1.dev2.dist-info/licenses/LICENSE-COMMERCIAL +5 -0
- opensportslib-0.0.1.dev2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1482 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright 2022 James Hong, Haotian Zhang, Matthew Fisher, Michael Gharbi,
|
|
3
|
+
Kayvon Fatahalian
|
|
4
|
+
|
|
5
|
+
Redistribution and use in source and binary forms, with or without modification,
|
|
6
|
+
are permitted provided that the following conditions are met:
|
|
7
|
+
|
|
8
|
+
1. Redistributions of source code must retain the above copyright notice, this
|
|
9
|
+
list of conditions and the following disclaimer.
|
|
10
|
+
|
|
11
|
+
2. Redistributions in binary form must reproduce the above copyright notice,
|
|
12
|
+
this list of conditions and the following disclaimer in the documentation and/or
|
|
13
|
+
other materials provided with the distribution.
|
|
14
|
+
|
|
15
|
+
3. Neither the name of the copyright holder nor the names of its contributors
|
|
16
|
+
may be used to endorse or promote products derived from this software without
|
|
17
|
+
specific prior written permission.
|
|
18
|
+
|
|
19
|
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
|
20
|
+
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
|
21
|
+
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
22
|
+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
|
|
23
|
+
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
|
24
|
+
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
|
25
|
+
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
|
|
26
|
+
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
|
27
|
+
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
|
28
|
+
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
29
|
+
"""
|
|
30
|
+
|
|
31
|
+
from collections import defaultdict
|
|
32
|
+
import os
|
|
33
|
+
import numpy as np
|
|
34
|
+
from tabulate import tabulate
|
|
35
|
+
from tqdm import tqdm
|
|
36
|
+
from torch.utils.data import DataLoader
|
|
37
|
+
from opensportslib.core.utils.config import load_json, store_gz_json, store_json
|
|
38
|
+
from opensportslib.core.utils.wandb import log_table_wandb
|
|
39
|
+
# from oslactionspotting.core.utils.score import compute_mAPs_E2E
|
|
40
|
+
import time
|
|
41
|
+
import json
|
|
42
|
+
import os
|
|
43
|
+
import sys
|
|
44
|
+
from collections import defaultdict
|
|
45
|
+
from tabulate import tabulate
|
|
46
|
+
import numpy as np
|
|
47
|
+
import matplotlib.pyplot as plt
|
|
48
|
+
import zipfile
|
|
49
|
+
import logging
|
|
50
|
+
import copy
|
|
51
|
+
from datetime import datetime
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def parse_ground_truth(truth):
    """Group ground-truth event frames by label, then by video path.

    Args:
        truth: iterable of per-video dicts, each with a ``"path"`` key and an
            ``"events"`` list whose items carry ``"label"`` and ``"frame"``.

    Returns:
        A nested ``defaultdict`` shaped ``{label: {path: [frame, ...]}}``.
        Frames keep the order in which they appear in ``truth``.
    """
    frames_by_label = defaultdict(lambda: defaultdict(list))
    for video_info in truth:
        for event in video_info["events"]:
            per_video = frames_by_label[event["label"]]
            per_video[video_info["path"]].append(event["frame"])
    return frames_by_label
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def get_predictions(pred, label=None):
    """Flatten predicted events into ``(video, frame, confidence)`` tuples.

    Args:
        pred: iterable of per-video dicts with a ``"video"`` key and an
            ``"events"`` list whose items carry ``"label"``, ``"frame"`` and
            ``"confidence"``.
        label: if given, keep only events with this label. ``None`` keeps
            every event.

    Returns:
        flat_pred (List): ``(video, frame, confidence)`` tuples sorted by
        confidence, highest first. Tied tuples keep their input order.
    """
    matches = [
        (video_pred["video"], event["frame"], event["confidence"])
        for video_pred in pred
        for event in video_pred["events"]
        if label is None or event["label"] == label
    ]
    return sorted(matches, key=lambda item: item[-1], reverse=True)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def compute_average_precision(
    pred,
    truth,
    tolerance=0,
    min_precision=0,
    plot_ax=None,
    plot_label=None,
    plot_raw_pr=True,
):
    """Compute the average precision of ranked predictions against ground truth.

    Each prediction is greedily matched to the closest not-yet-recalled
    ground-truth frame of the same video. It counts as a true positive when
    that frame lies within ``tolerance`` frames. Precision is recorded at each
    true positive and interpolated (made non-increasing from the right). It is
    then integrated over ``total`` recall bins, where ``total`` is the number
    of ground-truth events.

    Args:
        pred: list of ``(video, frame, score)`` tuples sorted by descending
            score, as returned by ``get_predictions``.
        truth: mapping ``{video: [ground-truth frame, ...]}``.
        tolerance (int): maximum frame distance for a match.
            Default: 0.
        min_precision (float): stop early once the precision at a true
            positive drops below this value. 0 disables early stopping.
            Default: 0.
        plot_ax: optional matplotlib axes on which to draw the PR curve.
            Default: None.
        plot_label (string): legend label for the plotted curves.
            Default: None.
        plot_raw_pr (bool): also plot the raw (non-interpolated) precision
            curve.
            Default: True.

    Returns:
        float: the average precision.

    Raises:
        AssertionError: if ``pred`` is not sorted by descending score.
    """
    total = sum(len(frames) for frames in truth.values())
    recalled = set()

    # The full precision curve has TOTAL bins. Recall grows by one bin per
    # true positive.
    pc = []
    # Start from +inf so the check only validates the descending order.
    # Scores are not required to lie in [0, 1].
    prev_score = float("inf")
    for i, (video, frame, score) in enumerate(pred, 1):
        assert score <= prev_score, "predictions must be sorted by descending score"
        prev_score = score

        # Find the closest ground-truth frame that has not been matched yet.
        gt_closest = None
        for gt_frame in truth.get(video, []):
            if (video, gt_frame) in recalled:
                continue
            if gt_closest is None or abs(frame - gt_closest) > abs(frame - gt_frame):
                gt_closest = gt_frame

        # Record precision each time a true positive is encountered.
        if gt_closest is not None and abs(frame - gt_closest) <= tolerance:
            recalled.add((video, gt_closest))
            p = len(recalled) / i
            pc.append(p)

            # Stop evaluation early if the precision is too low. This never
            # triggers when min_precision is 0.
            if p < min_precision:
                break

    # Interpolate: each point takes the max precision at any higher recall.
    interp_pc = []
    max_p = 0
    for p in reversed(pc):
        max_p = max(p, max_p)
        interp_pc.append(max_p)
    interp_pc.reverse()  # Not actually necessary for integration

    if plot_ax is not None:
        rc = np.arange(1, len(pc) + 1) / total
        if plot_raw_pr:
            plot_ax.plot(rc, pc, label=plot_label, alpha=0.8)
        plot_ax.plot(rc, interp_pc, label=plot_label, alpha=0.8)

    # Compute AUC by integrating up to TOTAL bins.
    return sum(interp_pc) / total
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def compute_mAPs_E2E(truth, pred, tolerances=None, plot_pr=False):
    """Compute per-class AP and mAP at several frame tolerances (E2E method).

    A table of per-class AP (in percent) for every tolerance is logged and
    sent to wandb. The average mAP across tolerances is logged as well.

    Args:
        truth: list of per-video ground-truth dicts with ``"path"`` and
            ``"events"`` keys.
        pred: list of per-video prediction dicts with ``"video"`` and
            ``"events"`` keys. It must cover exactly the same videos as
            ``truth``.
        tolerances (List[int]): frame tolerances to evaluate.
            Default: None, meaning [0, 1, 2, 3, 4].
        plot_pr (bool): whether to plot the precision-recall curves.
            Default: False.

    Returns:
        tuple: ``(mAPs, tolerances)``, where ``mAPs[i]`` is the mAP at
        ``tolerances[i]``.

    Raises:
        AssertionError: if the video sets of ``truth`` and ``pred`` differ.
    """
    # Use a fresh list on every call. A mutable default would be shared
    # across calls and is returned to (and could be mutated by) the caller.
    if tolerances is None:
        tolerances = [0, 1, 2, 3, 4]
    else:
        tolerances = list(tolerances)

    assert {v["path"] for v in truth} == {
        v["video"] for v in pred
    }, "Video set mismatch!"

    truth_by_label = parse_ground_truth(truth)

    fig, axes = None, None
    if plot_pr:
        # squeeze=False keeps `axes` 2-D even with a single label or a single
        # tolerance, so axes[j, i] indexing always works.
        fig, axes = plt.subplots(
            len(truth_by_label),
            len(tolerances),
            sharex=True,
            sharey=True,
            figsize=(16, 16),
            squeeze=False,
        )

    class_aps_for_tol = []
    mAPs = []
    for i, tol in enumerate(tolerances):
        class_aps = []
        for j, (label, truth_for_label) in enumerate(sorted(truth_by_label.items())):
            ap = compute_average_precision(
                get_predictions(pred, label=label),
                truth_for_label,
                tolerance=tol,
                plot_ax=axes[j, i] if axes is not None else None,
            )
            class_aps.append((label, ap))
        mAP = np.mean([x[1] for x in class_aps])
        mAPs.append(mAP)
        class_aps.append(("mAP", mAP))
        class_aps_for_tol.append(class_aps)

    # One row per class (plus the final "mAP" row), one column per tolerance.
    header = ["Class"] + [f"AP@{t}" for t in tolerances]
    rows = []
    for c, _ in class_aps_for_tol[0]:
        row = [str(c)]
        for class_aps in class_aps_for_tol:
            for c2, val in class_aps:
                if c2 == c:
                    row.append(float(val) * 100)
        rows.append(row)

    log_table_wandb(name="AP@tol", rows=rows, headers=header)
    logging.info(tabulate(rows, headers=header, floatfmt="0.2f"))
    logging.info("Avg mAP (across tolerances): {:0.2f}".format(np.mean(mAPs) * 100))

    if plot_pr:
        for i, tol in enumerate(tolerances):
            for j, label in enumerate(sorted(truth_by_label.keys())):
                ax = axes[j, i]
                ax.set_xlabel("Recall")
                ax.set_xlim(0, 1)
                ax.set_ylabel("Precision")
                ax.set_ylim(0, 1.01)
                ax.set_title("{} @ tol={}".format(label, tol))
        plt.tight_layout()
        plt.show()
        plt.close(fig)

    sys.stdout.flush()
    return mAPs, tolerances
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
class ErrorStat:
    """Accumulate a frame-level error rate over successive updates."""

    def __init__(self):
        self._total = 0  # number of frames seen so far
        self._err = 0  # number of frames where prediction != ground truth

    def update(self, true, pred):
        """Add a batch of ground-truth and predicted labels.

        Args:
            true: array of ground-truth labels. ``true.shape[0]`` is taken as
                the number of frames in the batch.
            pred: predicted labels, compared element-wise with ``true``.
        """
        self._err += np.sum(true != pred)
        self._total += true.shape[0]

    def get(self):
        """Return the fraction of mismatching frames seen so far."""
        return self._err / self._total

    def get_acc(self):
        """Return the frame-level accuracy, i.e. ``1 - get()``."""
        return 1.0 - self.get()
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
class ForegroundF1:
    """Per-class and overall F1 counters for frame-level predictions.

    Class index 0 is treated as background. The key ``None`` aggregates
    "any foreground" counts, regardless of which foreground class was
    predicted.
    """

    def __init__(self):
        self._tp = defaultdict(int)
        self._fp = defaultdict(int)
        self._fn = defaultdict(int)

    def update(self, true, pred):
        """Record one frame's ground-truth and predicted class index."""
        if pred == 0:
            # Background predicted: only a missed foreground frame counts.
            if true != 0:
                self._fn[None] += 1
                self._fn[true] += 1
            return

        # Foreground predicted: the overall counters ignore the exact class.
        if true == 0:
            self._fp[None] += 1
        else:
            self._tp[None] += 1

        if pred != true:
            self._fp[pred] += 1
            if true != 0:
                self._fn[true] += 1
        else:
            self._tp[pred] += 1

    def get(self, k):
        """Return the F1 score for class ``k`` (``None`` = any foreground)."""
        return self._f1(k)

    def tp_fp_fn(self, k):
        """Return the ``(tp, fp, fn)`` counts for class ``k``."""
        return self._tp[k], self._fp[k], self._fn[k]

    def _f1(self, k):
        tp, fp, fn = self.tp_fp_fn(k)
        # F1 = 2TP / (2TP + FP + FN), written as TP / (TP + FP/2 + FN/2).
        denom = tp + 0.5 * fp + 0.5 * fn
        if denom == 0:
            # No counts at all for k: report 0 instead of dividing by zero.
            assert tp == 0
            denom = 1
        return tp / denom
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
def build_snpro_prediction_json(pred_events, head_name="action", split=None, created_by="model"):
    """Wrap per-video predicted events into the v2 prediction JSON layout.

    Args:
        pred_events: list of per-video dicts with ``"video"``, ``"fps"`` and
            ``"events"`` keys. Each event carries ``"label"``, ``"gameTime"``,
            ``"frame"``, ``"position"`` and ``"confidence"``.
        head_name (str): value written to every event's ``"head"`` field.
        split: value stored under ``metadata["split"]``.
        created_by (str): value stored under ``metadata["created_by"]``.

    Returns:
        dict with ``version``, ``date`` (today, ``YYYY-MM-DD``), ``task``,
        ``metadata`` and ``data`` keys. Each event's ``"position"`` is
        written out as ``"position_ms"``.
    """
    def _convert_event(ev):
        return {
            "head": head_name,
            "label": ev["label"],
            "gameTime": ev["gameTime"],
            "frame": ev["frame"],
            "position_ms": ev["position"],
            "confidence": ev["confidence"],
        }

    entries = []
    for item in pred_events:
        video_input = {"type": "video", "path": item["video"], "fps": item["fps"]}
        entries.append({
            "inputs": [video_input],
            "events": [_convert_event(ev) for ev in item["events"]],
        })

    metadata = {"type": "predictions", "created_by": created_by, "split": split}
    return {
        "version": "2.0",
        "date": datetime.now().strftime("%Y-%m-%d"),
        "task": "localization",
        "metadata": metadata,
        "data": entries,
    }
|
|
339
|
+
|
|
340
|
+
def process_frame_predictions(
    dali, dataset, classes, pred_dict, high_recall_score_threshold=0.01
):
    """Turn accumulated per-frame scores into statistics and event lists.

    Scores in ``pred_dict`` are divided in place by their support (how many
    clips covered each frame), and the arg-max class is taken per frame.
    Frame-level error and foreground F1 are accumulated against
    ``dataset.get_labels(video)``.

    Args:
        dali (bool): Whether frames were decoded with DALI. If so, row 0 of
            every score matrix is skipped and evaluation starts at frame 1.
        dataset (Dataset or DaliGenericIterator): exposes ``videos``
            (iterated as ``(video, _, fps)`` triples) and ``get_labels``.
        classes (Dict): Mapping from class name to class index.
        pred_dict (Dict): Mapping from video to a ``(scores, support)`` tuple.
            ``scores`` is modified in place.
        high_recall_score_threshold (float): minimum score for a
            (frame, class) pair to appear in the high-recall events.
            Default: 0.01.

    Returns:
        err (ErrorStat).
        f1 (ForegroundF1).
        pred_events (List[dict]): per video: video, events, fps. There is at
            most one event per frame, for its arg-max class when that class
            is not 0.
        pred_events_high_recall (List[dict]): per video: video, events, fps.
            There is one event per (frame, class) whose score reaches the
            threshold.
        pred_scores (dict): video -> normalized scores as nested lists.
    """
    idx_to_label = {idx: name for name, idx in classes.items()}
    fps_by_video = {video: fps for video, _, fps in dataset.videos}

    err = ErrorStat()
    f1 = ForegroundF1()

    pred_events = []
    pred_events_high_recall = []
    pred_scores = {}

    # Row i of `pred` corresponds to frame i + offset (the DALI path skips
    # frame 0).
    offset = 1 if dali else 0

    def _timing(frame_idx, fps):
        """Return (position in ms, "MM:SS" game time) for a frame index."""
        whole_seconds = frame_idx // fps
        game_time = f"{int(whole_seconds // 60):02.0f}:{int(whole_seconds % 60):02.0f}"
        return int((frame_idx * 1000) / fps), game_time

    for video, (scores, support) in sorted(pred_dict.items()):
        frame_labels = dataset.get_labels(video)

        # Average the summed clip scores by the number of clips covering
        # each frame.
        covered = support[offset:]
        assert np.min(covered) > 0, (video, covered.tolist())
        scores[offset:] /= covered[:, None]
        pred = np.argmax(scores[offset:], axis=1)
        err.update(frame_labels[offset:], pred)

        pred_scores[video] = scores.tolist()
        fps = fps_by_video[video]

        events = []
        events_high_recall = []
        for i, cls in enumerate(pred):
            frame = i + offset
            f1.update(frame_labels[frame], cls)

            if cls != 0:
                position, game_time = _timing(frame, fps)
                events.append(
                    {
                        "label": idx_to_label[cls],
                        "position": position,
                        "gameTime": game_time,
                        "frame": frame,
                        "confidence": scores[frame, cls].item(),
                    }
                )

            for j in idx_to_label:
                if scores[frame, j] >= high_recall_score_threshold:
                    position, game_time = _timing(frame, fps)
                    events_high_recall.append(
                        {
                            "label": idx_to_label[j],
                            "position": position,
                            "gameTime": game_time,
                            "frame": frame,
                            "confidence": scores[frame, j].item(),
                        }
                    )

        pred_events.append({"video": video, "events": events, "fps": fps})
        pred_events_high_recall.append(
            {"video": video, "events": events_high_recall, "fps": fps}
        )

    return err, f1, pred_events, pred_events_high_recall, pred_scores
|
|
445
|
+
|
|
446
|
+
|
|
447
|
+
def infer_and_process_predictions_e2e(
    model,
    dali,
    dataset,
    split,
    classes,
    save_pred,
    calc_stats=True,
    dataloader_params=None,
    return_pred=False,
):
    """Run the model over every clip, aggregate per-frame scores and evaluate.

    Clip-level scores are summed into per-video ``(scores, support)`` buffers,
    where ``support`` counts how many predictions covered each frame. The
    buffers are then passed to ``process_frame_predictions``. When
    ``calc_stats`` is set, the function logs:

    - the frame-level error,
    - an exact-frame F1/TP/FP/FN table,
    - the mAP over tolerances.

    Args:
        model: object whose ``predict(frames)`` returns a pair. The second
            element holds the per-frame class scores.
        dali (bool): Whether DALI was used (``dataset`` is iterated directly)
            or OpenCV (a ``DataLoader`` is built over ``dataset``).
        dataset (Dataset or DaliGenericIterator): must expose ``videos``,
            ``labels``, ``task_name`` and ``get_labels``.
        split (string): Split name, used in logs and in the output metadata.
        classes (dict): Mapping from class name to class index.
        save_pred (str or None): Path prefix for the saved prediction files
            (``.json``, ``.high_recall.json``, ``.recall.json.gz``). ``None``
            disables saving.
        calc_stats (bool): Compute and log statistics.
            Default: True.
        dataloader_params: object with ``batch_size``, ``num_workers`` and
            ``pin_memory`` attributes.
            Default: None.
            NOTE(review): ``batch_size`` is read unconditionally, so passing
            None raises AttributeError. Callers must always supply it.
        return_pred (bool): Return the high-recall predictions instead of the
            average mAP.
            Default: False

    Returns:
        If ``return_pred`` is set: the high-recall predictions in the v2
        layout built by ``build_snpro_prediction_json``.
        Otherwise: avg_mAP, the mean of the per-tolerance mAPs excluding the
        first tolerance, or None when ``calc_stats`` is False.
    """
    # print(dataset.)
    # One (scores, support) buffer per video. Column 0 is in addition to the
    # len(classes) class columns.
    pred_dict = {}
    for video, video_len, _ in dataset.videos:
        pred_dict[video] = (
            np.zeros((video_len, len(classes) + 1), np.float32),
            np.zeros(video_len, np.int32),
        )

    batch_size = dataloader_params.batch_size

    for clip in tqdm(
        dataset
        if dali
        else DataLoader(
            dataset,
            num_workers=dataloader_params.num_workers,
            pin_memory=dataloader_params.pin_memory,
            batch_size=batch_size,
        )
    ):
        if batch_size > 1:
            # Batched by dataloader
            _, batch_pred_scores = model.predict(clip["frame"])
            for i in range(clip["frame"].shape[0]):
                video = clip["video"][i]
                scores, support = pred_dict[video]
                pred_scores = batch_pred_scores[i]
                start = clip["start"][i].item()
                # Drop predictions for frames before the start of the video.
                if start < 0:
                    pred_scores = pred_scores[-start:, :]
                    start = 0
                end = start + pred_scores.shape[0]
                # Trim predictions that run past the end of the video.
                if end >= scores.shape[0]:
                    end = scores.shape[0]
                    pred_scores = pred_scores[: end - start, :]
                scores[start:end, :] += pred_scores
                support[start:end] += 1

        else:
            # Batched by dataset
            scores, support = pred_dict[clip["video"][0]]

            start = clip["start"][0].item()
            # start=start-1
            # pred_scores is 3-D here. Frames are on axis 1, and axis 0 holds
            # several predictions for the same window. They are summed below
            # and counted in `support`.
            _, pred_scores = model.predict(clip["frame"][0])
            if start < 0:
                pred_scores = pred_scores[:, -start:, :]
                start = 0
            end = start + pred_scores.shape[1]
            if end >= scores.shape[0]:
                end = scores.shape[0]
                pred_scores = pred_scores[:, : end - start, :]

            scores[start:end, :] += np.sum(pred_scores, axis=0)
            support[start:end] += pred_scores.shape[0]

    err, f1, pred_events, pred_events_high_recall, pred_scores = (
        process_frame_predictions(dali, dataset, classes, pred_dict)
    )

    avg_mAP = None
    if calc_stats:
        logging.info(f"=== Results on {split} (w/o NMS) ===")
        logging.info(f"Error (frame-level): {err.get() * 100:.2f}")

        def get_f1_tab_row(str_k):
            # "any" maps to the aggregate (None) key of ForegroundF1.
            k = classes[str_k] if str_k != "any" else None
            return [str_k, f1.get(k) * 100, *f1.tp_fp_fn(k)]

        rows = [get_f1_tab_row("any")]
        for c in sorted(classes):
            rows.append(get_f1_tab_row(c))
        header = ["Exact frame", "F1", "TP", "FP", "FN"]
        logging.info(
            tabulate(
                rows, headers=header, floatfmt="0.2f"
            )
        )
        log_table_wandb(name="Confusion matrix table", rows=rows, headers=header)

        mAPs, _ = compute_mAPs_E2E(dataset.labels, pred_events_high_recall)
        # Skip the first tolerance (tolerance 0 with the default list).
        avg_mAP = np.mean(mAPs[1:])

    pred_events = build_snpro_prediction_json(pred_events, head_name=dataset.task_name, split=split, created_by="model")
    pred_events_high_recall = build_snpro_prediction_json(pred_events_high_recall, head_name=dataset.task_name, split=split, created_by="model")
    if save_pred is not None:
        store_json(save_pred + ".json", pred_events, pretty=True)
        store_json(save_pred + ".high_recall.json", pred_events_high_recall, pretty=True)
        store_gz_json(save_pred + ".recall.json.gz", pred_events_high_recall)
        # if save_scores:
        #     store_gz_json(save_pred + '.score.json.gz', pred_scores)
    if return_pred:
        return pred_events_high_recall

    logging.info(f"avg_mAP: {avg_mAP}")
    return avg_mAP
|
|
574
|
+
|
|
575
|
+
|
|
576
|
+
def search_best_epoch(work_dir):
    """Pick the best epoch recorded in ``<work_dir>/loss.json``.

    The epoch with the highest ``valid_mAP`` wins. If no epoch has a
    ``valid_mAP`` above 0, the epoch with the lowest ``valid`` loss is used
    instead. Ties keep the earliest entry.

    Args:
        work_dir (string): Directory containing ``loss.json``, a list of
            per-epoch records with ``epoch``, ``valid`` and ``valid_mAP`` keys.

    Returns:
        The selected epoch number, or -1 when ``loss.json`` is empty.
    """
    history = load_json(os.path.join(work_dir, "loss.json"))

    best_map, best_map_epoch = 0, -1
    best_loss, best_loss_epoch = float("inf"), -1
    for record in history:
        if record["valid_mAP"] > best_map:
            best_map, best_map_epoch = record["valid_mAP"], record["epoch"]
        if record["valid"] < best_loss:
            best_loss, best_loss_epoch = record["valid"], record["epoch"]

    return best_map_epoch if best_map_epoch != -1 else best_loss_epoch
|
|
600
|
+
|
|
601
|
+
|
|
602
|
+
def non_maximum_supression(pred, window):
    """Drop events that have a stronger same-label neighbour within a window.

    An event is removed when another event with the same label lies at a
    different frame, at most ``window`` frames away, and has a strictly
    higher confidence.

    Args:
        pred (List[dict]): Per-video prediction dicts, each with an ``"events"``
            list whose items carry ``"label"``, ``"frame"`` and ``"confidence"``.
        window (int): Maximum frame distance at which events compete.

    Returns:
        new_pred (List[dict]): Deep copies of the input video dicts. In each,
        ``"events"`` is replaced by the surviving events (the original event
        objects, sorted by frame) and ``"num_events"`` is set to their count.
        The input list and dicts are not modified.
    """
    filtered = []
    for video_pred in pred:
        grouped = defaultdict(list)
        for event in video_pred["events"]:
            grouped[event["label"]].append(event)

        survivors = []
        for same_label in grouped.values():
            for candidate in same_label:
                has_stronger_neighbour = any(
                    other["frame"] != candidate["frame"]
                    and abs(candidate["frame"] - other["frame"]) <= window
                    and candidate["confidence"] < other["confidence"]
                    for other in same_label
                )
                if not has_stronger_neighbour:
                    survivors.append(candidate)
        survivors.sort(key=lambda event: event["frame"])

        result = copy.deepcopy(video_pred)
        result["events"] = survivors
        result["num_events"] = len(survivors)
        filtered.append(result)
    return filtered
|
|
638
|
+
|
|
639
|
+
|
|
640
|
+
# def store_eval_files_json(raw_pred, eval_dir):
|
|
641
|
+
# """
|
|
642
|
+
# Store predictions.
|
|
643
|
+
|
|
644
|
+
# If eval_dir ends with .json → store NEW v2 unified file
|
|
645
|
+
# Else → store OLD per-video SN files (original behavior)
|
|
646
|
+
# """
|
|
647
|
+
|
|
648
|
+
# # ==========================================================
|
|
649
|
+
# # 🟢 NEW V2 FORMAT (single json file)
|
|
650
|
+
# # ==========================================================
|
|
651
|
+
# if eval_dir.endswith(".json"):
|
|
652
|
+
# from datetime import datetime
|
|
653
|
+
|
|
654
|
+
# data = []
|
|
655
|
+
|
|
656
|
+
# for obj in raw_pred:
|
|
657
|
+
# video = obj["video"]
|
|
658
|
+
# fps = obj["fps"]
|
|
659
|
+
|
|
660
|
+
# events = []
|
|
661
|
+
# for ev in obj["events"]:
|
|
662
|
+
# events.append({
|
|
663
|
+
# "head": "action", # generic
|
|
664
|
+
# "label": ev["label"],
|
|
665
|
+
# "position_ms": ev["position"],
|
|
666
|
+
# "confidence": ev.get("confidence", 1.0),
|
|
667
|
+
# })
|
|
668
|
+
|
|
669
|
+
# data.append({
|
|
670
|
+
# "inputs": [{
|
|
671
|
+
# "type": "video",
|
|
672
|
+
# "path": video,
|
|
673
|
+
# "fps": fps,
|
|
674
|
+
# }],
|
|
675
|
+
# "events": events,
|
|
676
|
+
# })
|
|
677
|
+
|
|
678
|
+
# out = {
|
|
679
|
+
# "version": "2.0",
|
|
680
|
+
# "date": datetime.now().strftime("%Y-%m-%d"),
|
|
681
|
+
# "task": "action_spotting",
|
|
682
|
+
# "metadata": {
|
|
683
|
+
# "type": "predictions",
|
|
684
|
+
# "notes": "Generated by model"
|
|
685
|
+
# },
|
|
686
|
+
# "data": data,
|
|
687
|
+
# }
|
|
688
|
+
|
|
689
|
+
# os.makedirs(os.path.dirname(eval_dir), exist_ok=True)
|
|
690
|
+
# with open(eval_dir, "w") as f:
|
|
691
|
+
# json.dump(out, f, indent=2)
|
|
692
|
+
|
|
693
|
+
# logging.info(f"Stored V2 predictions → {eval_dir}")
|
|
694
|
+
# return True
|
|
695
|
+
|
|
696
|
+
# # ==========================================================
|
|
697
|
+
# # 🔵 OLD FORMAT (original code untouched)
|
|
698
|
+
# # ==========================================================
|
|
699
|
+
# only_one_file = False
|
|
700
|
+
# video_pred = defaultdict(list)
|
|
701
|
+
# video_fps = defaultdict(list)
|
|
702
|
+
|
|
703
|
+
# for obj in raw_pred:
|
|
704
|
+
# video = obj["video"]
|
|
705
|
+
# video_fps[video] = obj["fps"]
|
|
706
|
+
|
|
707
|
+
# for event in obj["events"]:
|
|
708
|
+
# video_pred[video].append(
|
|
709
|
+
# {
|
|
710
|
+
# "frame": event["frame"],
|
|
711
|
+
# "label": event["label"],
|
|
712
|
+
# "confidence": event["confidence"],
|
|
713
|
+
# "position": event["position"],
|
|
714
|
+
# "gameTime": event["gameTime"],
|
|
715
|
+
# }
|
|
716
|
+
# )
|
|
717
|
+
|
|
718
|
+
# for video, pred in video_pred.items():
|
|
719
|
+
# if len(raw_pred) == 1:
|
|
720
|
+
# video_out_dir = eval_dir
|
|
721
|
+
# only_one_file = True
|
|
722
|
+
# else:
|
|
723
|
+
# video_out_dir = os.path.join(eval_dir, os.path.splitext(video)[0])
|
|
724
|
+
|
|
725
|
+
# os.makedirs(video_out_dir, exist_ok=True)
|
|
726
|
+
|
|
727
|
+
# store_json(
|
|
728
|
+
# os.path.join(video_out_dir, "results_spotting.json"),
|
|
729
|
+
# {"Url": video, "predictions": pred, "fps": video_fps[video]},
|
|
730
|
+
# pretty=True,
|
|
731
|
+
# )
|
|
732
|
+
|
|
733
|
+
# logging.info(f"Stored OLD format → {eval_dir}")
|
|
734
|
+
# return only_one_file
|
|
735
|
+
|
|
736
|
+
def store_eval_files_json(raw_pred, eval_dir, save_v2=True):
    """Store spotting predictions as ``results_spotting.json`` file(s).

    ``raw_pred`` is a list of per-video prediction dicts with the keys
    ``"video"``, ``"fps"`` and ``"events"``. Each event provides ``"frame"``,
    ``"label"``, ``"confidence"``, ``"position"`` and ``"gameTime"``.

    Output location (identical for both formats): when ``raw_pred`` holds
    exactly one entry, the file is written directly into ``eval_dir``;
    otherwise every video gets its own sub-directory
    ``eval_dir/<video name without extension>``. Entries that share the same
    ``"video"`` are merged into a single file, and a video without any event
    still gets a file.

    Args:
        raw_pred (List[dict]): Per-video predictions.
        eval_dir (str): Output directory.
        save_v2 (bool): True writes the unified v2 layout
            (``version`` / ``date`` / ``task`` / ``metadata`` / ``data``);
            False writes the original SoccerNet layout
            (``Url`` / ``predictions`` / ``fps``) through ``store_json``.

    Returns:
        bool: True if the single output file was written directly into
        ``eval_dir``, False otherwise.
    """
    if save_v2:
        return _store_eval_files_v2(raw_pred, eval_dir)
    return _store_eval_files_sn(raw_pred, eval_dir)


def _eval_video_out_dir(raw_pred, eval_dir, video):
    """Return ``(directory, is_eval_dir)`` for the results file of ``video``."""
    if len(raw_pred) == 1:
        return eval_dir, True
    return os.path.join(eval_dir, os.path.splitext(video)[0]), False


def _optional_cast(value, cast):
    """Apply ``cast`` to ``value`` unless the value is missing (None)."""
    return None if value is None else cast(value)


def _store_eval_files_v2(raw_pred, eval_dir):
    """Write predictions in the unified v2 JSON layout (one file per video)."""
    from datetime import datetime

    # Group converted events per video (first-seen order); the last fps wins.
    video_events = defaultdict(list)
    video_fps = {}
    for obj in raw_pred:
        video = obj["video"]
        video_fps[video] = obj["fps"]
        # Touch the key so a video without events still produces a file.
        events = video_events[video]
        for ev in obj["events"]:
            events.append(
                {
                    "head": "action",
                    "label": ev.get("label"),
                    "frame": ev.get("frame"),
                    # Missing optional fields become null instead of crashing.
                    "position_ms": _optional_cast(ev.get("position"), int),
                    "confidence": _optional_cast(ev.get("confidence"), float),
                    "gameTime": ev.get("gameTime"),
                }
            )

    only_one_file = False
    for video, events in video_events.items():
        video_out_dir, is_eval_dir = _eval_video_out_dir(raw_pred, eval_dir, video)
        if is_eval_dir:
            only_one_file = True
        os.makedirs(video_out_dir, exist_ok=True)

        out = {
            "version": "2.0",
            "date": datetime.now().strftime("%Y-%m-%d"),
            "task": "action_spotting",
            "metadata": {"type": "predictions"},
            "data": [
                {
                    "inputs": [
                        {
                            "type": "video",
                            "path": video,
                            "fps": video_fps[video],
                        }
                    ],
                    "events": events,
                }
            ],
        }

        out_path = os.path.join(video_out_dir, "results_spotting.json")
        with open(out_path, "w", encoding="utf-8") as f:
            json.dump(out, f, indent=2)

    logging.info(f"Stored V2 predictions → {eval_dir}")
    return only_one_file


def _store_eval_files_sn(raw_pred, eval_dir):
    """Write predictions in the original SoccerNet layout via ``store_json``."""
    video_pred = defaultdict(list)
    video_fps = {}

    for obj in raw_pred:
        video = obj["video"]
        video_fps[video] = obj["fps"]
        # Touch the key so a video without events still produces a file.
        preds = video_pred[video]
        for event in obj["events"]:
            preds.append(
                {
                    "frame": event["frame"],
                    "label": event["label"],
                    "confidence": event["confidence"],
                    "position": event["position"],
                    "gameTime": event["gameTime"],
                }
            )

    only_one_file = False
    for video, pred in video_pred.items():
        video_out_dir, is_eval_dir = _eval_video_out_dir(raw_pred, eval_dir, video)
        if is_eval_dir:
            only_one_file = True
        os.makedirs(video_out_dir, exist_ok=True)

        store_json(
            os.path.join(video_out_dir, "results_spotting.json"),
            {"Url": video, "predictions": pred, "fps": video_fps[video]},
            pretty=True,
        )

    logging.info(f"Stored OLD format → {eval_dir}")
    return only_one_file
|
|
845
|
+
|
|
846
|
+
|
|
847
|
+
|
|
848
|
+
|
|
849
|
+
def label2vector(
    labels, num_classes=17, framerate=2, EVENT_DICTIONARY=None, vector_size=None
):
    """Transform a list of ground-truth annotation dicts into a dense 2D array.

    Each annotation needs a ``"label"`` and one of: ``"frame"`` (used as is),
    ``"position"`` (milliseconds) or ``"gameTime"`` (whose last five
    characters are parsed as ``MM:SS``), checked in that order. An optional
    ``"visibility"`` equal to ``"not shown"`` marks the event as unshown.

    Args:
        labels (List[dict]): List of ground-truth annotations for a video.
        num_classes (int): Number of classes.
            Default: 17.
        framerate (int | float): Rate at which the frames have been
            processed in a video.
            Default: 2.
        EVENT_DICTIONARY (dict): Mapping from class name to column index.
            Default: None (treated as an empty mapping).
        vector_size (int): Number of rows of the returned array.
            Default: None, meaning ``90 * 60 * framerate``.

    Returns:
        dense_labels (np.ndarray): Array of shape (vector_size, num_classes);
        the first index is the frame and the second one the class. A cell is
        -1 if the event is unshown, 1 if it is visible, and 0 where no event
        occurs. Frame indices are clamped into ``[0, vector_size - 1]``.
    """
    if EVENT_DICTIONARY is None:
        EVENT_DICTIONARY = {}
    vector_size = int(90 * 60 * framerate if vector_size is None else vector_size)

    dense_labels = np.zeros((vector_size, num_classes))

    for annotation in labels:
        event = annotation["label"]

        if "frame" in annotation:
            frame = int(annotation["frame"])
        elif "position" in annotation:
            # Annotation at millisecond precision.
            frame = int(framerate * (int(annotation["position"]) / 1000))
        else:
            # Annotation at second precision.
            game_time = annotation["gameTime"]
            minutes = int(game_time[-5:-3])
            seconds = int(game_time[-2:])
            # int() so that a float framerate still yields a valid index.
            frame = int(framerate * (seconds + 60 * minutes))

        label = EVENT_DICTIONARY[event]
        value = -1 if annotation.get("visibility") == "not shown" else 1

        # Clamp both ends: a negative index would silently wrap to the end.
        frame = min(max(frame, 0), vector_size - 1)
        dense_labels[frame, label] = value

    return dense_labels
|
|
904
|
+
|
|
905
|
+
|
|
906
|
+
def predictions2vector(
    predictions, num_classes=17, framerate=2, EVENT_DICTIONARY=None, vector_size=None
):
    """Transform a list of prediction dicts into a dense 2D score array.

    Each prediction needs ``"label"``, ``"confidence"`` and either ``"frame"``
    or ``"position"`` (milliseconds); ``"frame"`` takes precedence.

    Args:
        predictions (List[dict]): List of predictions for a video.
        num_classes (int): Number of classes.
            Default: 17.
        framerate (int | float): Rate at which the frames have been
            processed in a video.
            Default: 2.
        EVENT_DICTIONARY (dict): Mapping from class name to column index.
            Default: None (treated as an empty mapping).
        vector_size (int): Number of rows of the returned array.
            Default: None, meaning ``90 * 60 * framerate``.

    Returns:
        dense_predictions (np.ndarray): Array of shape (vector_size,
        num_classes); the first index is the frame and the second one the
        class. Cells hold the prediction confidence, or -1 where there is no
        prediction. Frame indices are clamped into ``[0, vector_size - 1]``.
    """
    if EVENT_DICTIONARY is None:
        EVENT_DICTIONARY = {}
    vector_size = int(90 * 60 * framerate if vector_size is None else vector_size)

    # -1 marks "no prediction".
    dense_predictions = np.zeros((vector_size, num_classes)) - 1

    for annotation in predictions:
        event = annotation["label"]

        if "frame" in annotation:
            frame = int(annotation["frame"])
        else:
            # "position" is expressed in milliseconds.
            position_ms = int(annotation["position"])
            frame = int(framerate * (position_ms / 1000))

        label = EVENT_DICTIONARY[event]

        # Clamp both ends: a negative index would silently wrap to the end.
        frame = min(max(frame, 0), vector_size - 1)
        dense_predictions[frame, label] = annotation["confidence"]

    return dense_predictions
|
|
949
|
+
|
|
950
|
+
|
|
951
|
+
# Module import side effect: tell NumPy to ignore divide-by-zero and invalid
# (e.g. 0/0) floating-point operations instead of warning. The precision /
# recall code below divides by counts that can be zero and cleans the
# resulting nan/inf values with np.nan_to_num.
np.seterr(divide="ignore", invalid="ignore")
|
|
952
|
+
|
|
953
|
+
|
|
954
|
+
class AverageMeter(object):
    """Keep the latest value of a metric together with its running mean.

    Each ``update(val, n)`` call counts ``val`` as ``n`` observations, so
    ``avg`` is the sample-weighted mean of everything recorded since the
    last ``reset()``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget every recorded value (last value, sum, count and mean)."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and recompute the mean."""
        self.val = val
        weighted = val * n
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count
|
|
971
|
+
|
|
972
|
+
|
|
973
|
+
def compute_class_scores(target, closest, detection, delta):
    """Match predictions to ground-truth events for a single class.

    Every ground-truth frame (non-zero ``target`` value) is matched to at
    most one prediction: the highest-scoring prediction lying within
    ``delta / 2`` frames of it that has not already been matched to an
    earlier ground truth. The scan for each ground truth is limited to
    ``[gt - delta, gt + delta]``.

    Args:
        target (np.ndarray): 1D ground truth per frame; > 0 visible,
            < 0 unshown, 0 no event.
        closest (np.ndarray): 1D closest-action value per frame; it is
            read at the prediction frames.
        detection (np.ndarray): 1D prediction score per frame; negative
            values mean "no prediction".
        delta (int | float): Tolerance, in frames.

    Returns:
        game_detections (np.ndarray): Array of shape (number of predictions,
            3). Column 0 is the prediction score, column 1 is 1 if the
            prediction was matched to a ground truth (0 otherwise), column 2
            is the closest-action value at the prediction frame.
        int: Number of visible ground-truth events.
        int: Number of unshown ground-truth events.
    """
    # Retrieving the important variables
    gt_indexes = np.where(target != 0)[0]
    gt_indexes_visible = np.where(target > 0)[0]
    gt_indexes_unshown = np.where(target < 0)[0]
    pred_indexes = np.where(detection >= 0)[0]
    pred_scores = detection[pred_indexes]

    # Each row is [pred_score, {1 or 0}, closest[pred_index]]
    game_detections = np.zeros((len(pred_indexes), 3))
    game_detections[:, 0] = np.copy(pred_scores)
    game_detections[:, 2] = np.copy(closest[pred_indexes])

    # Predictions already matched to a ground truth. A set keeps the
    # membership test O(1) instead of scanning a growing list.
    matched_pred_indexes = set()

    for gt_index in gt_indexes:
        max_score = -1
        max_index = None
        selected_game_index = 0

        # pred_indexes is sorted (np.where), so the scan can stop early.
        for game_index, (pred_index, pred_score) in enumerate(
            zip(pred_indexes, pred_scores)
        ):
            if pred_index < gt_index - delta:
                continue
            if pred_index > gt_index + delta:
                break

            if (
                abs(pred_index - gt_index) <= delta / 2
                and pred_score > max_score
                and pred_index not in matched_pred_indexes
            ):
                max_score = pred_score
                max_index = pred_index
                selected_game_index = game_index

        if max_index is not None:
            game_detections[selected_game_index, 1] = 1
            matched_pred_indexes.add(max_index)

    return game_detections, len(gt_indexes_visible), len(gt_indexes_unshown)
|
|
1035
|
+
|
|
1036
|
+
|
|
1037
|
+
def compute_precision_recall_curve(targets, closests, detections, delta):
    """Compute per-class precision/recall curves over 200 confidence thresholds.

    Three families of curves are produced: over all events, over visible
    events only and over unshown events only. A detection contributes to the
    visible (resp. unshown) curve when its closest-action value is > 0.5
    (resp. < -0.5).

    Args:
        targets (List(np.array(vector_size, num_classes))): Ground-truth
            targets, one array per video.
        closests (List(np.array)): Closest-action arrays, one per video,
            indexed like ``detections``.
        detections (List(np.array(vector_size, num_classes))): Prediction
            scores, one array per video.
        delta (int | float): Tolerance in frames, forwarded to
            ``compute_class_scores``.

    Returns:
        precision, recall, precision_visible, recall_visible,
        precision_unshown, recall_unshown: arrays of shape
        (200, num_classes). Within each class column the points are sorted
        by increasing recall.
    """
    # Store the number of classes
    num_classes = targets[0].shape[-1]

    # 200 confidence thresholds between [0,1]
    thresholds = np.linspace(0, 1, 200)

    # Store the precision and recall points
    precision = list()
    recall = list()
    precision_visible = list()
    recall_visible = list()
    precision_unshown = list()
    recall_unshown = list()

    # Precompute the predictions scores and their correspondence {TP, FP} for each class
    for c in np.arange(num_classes):
        # Sentinel row with score -1: it never passes a threshold in [0, 1].
        sentinel = np.zeros((1, 3))
        sentinel[0, 0] = -1
        detection_chunks = [sentinel]
        n_gt_labels_visible = 0
        n_gt_labels_unshown = 0

        # Get the confidence scores and their corresponding TP or FP characteristics for each game
        for target, closest, detection in zip(targets, closests, detections):
            tmp_detections, tmp_n_gt_labels_visible, tmp_n_gt_labels_unshown = (
                compute_class_scores(
                    target[:, c], closest[:, c], detection[:, c], delta
                )
            )
            detection_chunks.append(tmp_detections)
            n_gt_labels_visible += tmp_n_gt_labels_visible
            n_gt_labels_unshown += tmp_n_gt_labels_unshown

        # Concatenate once instead of np.append per video, which would copy
        # the whole accumulated array on every iteration.
        total_detections = np.concatenate(detection_chunks, axis=0)

        precision.append(list())
        recall.append(list())
        precision_visible.append(list())
        recall_visible.append(list())
        precision_unshown.append(list())
        recall_unshown.append(list())

        # Get only the visible or unshown actions: disable (score -1) the
        # detections whose closest-action value does not match.
        total_detections_visible = np.copy(total_detections)
        total_detections_unshown = np.copy(total_detections)
        total_detections_visible[
            np.where(total_detections_visible[:, 2] <= 0.5)[0], 0
        ] = -1
        total_detections_unshown[
            np.where(total_detections_unshown[:, 2] >= -0.5)[0], 0
        ] = -1

        # Get the precision and recall for each confidence threshold.
        # Divisions by zero yield nan/inf, mapped to numbers by nan_to_num.
        for threshold in thresholds:
            pred_indexes = np.where(total_detections[:, 0] >= threshold)[0]
            pred_indexes_visible = np.where(
                total_detections_visible[:, 0] >= threshold
            )[0]
            pred_indexes_unshown = np.where(
                total_detections_unshown[:, 0] >= threshold
            )[0]
            TP = np.sum(total_detections[pred_indexes, 1])
            TP_visible = np.sum(total_detections[pred_indexes_visible, 1])
            TP_unshown = np.sum(total_detections[pred_indexes_unshown, 1])
            p = np.nan_to_num(TP / len(pred_indexes))
            r = np.nan_to_num(TP / (n_gt_labels_visible + n_gt_labels_unshown))
            p_visible = np.nan_to_num(TP_visible / len(pred_indexes_visible))
            r_visible = np.nan_to_num(TP_visible / n_gt_labels_visible)
            p_unshown = np.nan_to_num(TP_unshown / len(pred_indexes_unshown))
            r_unshown = np.nan_to_num(TP_unshown / n_gt_labels_unshown)
            precision[-1].append(p)
            recall[-1].append(r)
            precision_visible[-1].append(p_visible)
            recall_visible[-1].append(r_visible)
            precision_unshown[-1].append(p_unshown)
            recall_unshown[-1].append(r_unshown)

    precision = np.array(precision).transpose()
    recall = np.array(recall).transpose()
    precision_visible = np.array(precision_visible).transpose()
    recall_visible = np.array(recall_visible).transpose()
    precision_unshown = np.array(precision_unshown).transpose()
    recall_unshown = np.array(recall_unshown).transpose()

    # Sort the points based on the recall, class per class (in place).
    for prec_curve, rec_curve in (
        (precision, recall),
        (precision_visible, recall_visible),
        (precision_unshown, recall_unshown),
    ):
        for i in np.arange(num_classes):
            index_sort = np.argsort(rec_curve[:, i])
            prec_curve[:, i] = prec_curve[index_sort, i]
            rec_curve[:, i] = rec_curve[index_sort, i]

    return (
        precision,
        recall,
        precision_visible,
        recall_visible,
        precision_unshown,
        recall_unshown,
    )
|
|
1166
|
+
|
|
1167
|
+
|
|
1168
|
+
def compute_mAP(precision, recall):
    """Compute the 11-point interpolated AP of every class and their mean.

    For each recall level r in {0.0, 0.1, ..., 1.0}, the interpolated
    precision is the largest precision among the points whose recall is
    >= r, or 0 when no point reaches r. A class's AP is the mean of those
    eleven values.

    Args:
        precision: Array of precision points, one column per class.
        recall: Array of recall points, one column per class.

    Returns:
        np.mean(mAP_per_class): mean AP over all classes.
        mAP_per_class: array with the AP of each class.
    """
    recall_levels = np.arange(11) / 10
    n_classes = precision.shape[-1]
    AP = np.zeros(n_classes)

    for cls in np.arange(n_classes):
        cls_precision = precision[:, cls]
        cls_recall = recall[:, cls]
        for level in recall_levels:
            reachable = cls_precision[cls_recall >= level]
            AP[cls] += reachable.max() if reachable.size else 0

    mAP_per_class = AP / 11
    return np.mean(mAP_per_class), mAP_per_class
|
|
1202
|
+
|
|
1203
|
+
|
|
1204
|
+
# Tight: (SNv3): np.arange(5)*1 + 1
|
|
1205
|
+
# Loose: (SNv1/v2): np.arange(12)*5 + 5
|
|
1206
|
+
|
|
1207
|
+
|
|
1208
|
+
def delta_curve(targets, closests, detections, framerate, deltas=np.arange(5) * 1 + 1):
    """Compute the mAP (all / visible-only / unshown-only) at each tolerance.

    Every value of ``deltas`` is multiplied by ``framerate`` before being
    passed as the tolerance to ``compute_precision_recall_curve``.

    Args:
        targets (List(np.array(vector_size,num_classes)): Ground-truth targets, one array per video.
        closests (List(np.array)): Closest-action arrays, one per video.
        detections (List(np.array(vector_size,num_classes)): Predictions, one array per video.
        framerate (int | float): Multiplier applied to ``deltas``.
        deltas (np.array): Tolerances.

    Returns:
        mAP: mean mAP over all classes, one entry per tolerance.
        mAP_per_class: per-class mAP, one array per tolerance.
        mAP_visible: same as mAP, visible events only.
        mAP_per_class_visible: same as mAP_per_class, visible events only.
        mAP_unshown: same as mAP, unshown events only.
        mAP_per_class_unshown: same as mAP_per_class, unshown events only.
    """
    mAP, mAP_per_class = [], []
    mAP_visible, mAP_per_class_visible = [], []
    mAP_unshown, mAP_per_class_unshown = [], []

    # The curves come back as (precision, recall) pairs in this order:
    # all events, visible only, unshown only.
    # TODO: compute visible/unshown from another JSON file containing only the visible/unshown annotations
    sinks = (
        (mAP, mAP_per_class),
        (mAP_visible, mAP_per_class_visible),
        (mAP_unshown, mAP_per_class_unshown),
    )

    for tolerance in tqdm(deltas * framerate):
        curves = compute_precision_recall_curve(
            targets, closests, detections, tolerance
        )
        for k, (mean_sink, per_class_sink) in enumerate(sinks):
            mean_value, per_class_value = compute_mAP(curves[2 * k], curves[2 * k + 1])
            mean_sink.append(mean_value)
            per_class_sink.append(per_class_value)

    return (
        mAP,
        mAP_per_class,
        mAP_visible,
        mAP_per_class_visible,
        mAP_unshown,
        mAP_per_class_unshown,
    )
|
|
1267
|
+
|
|
1268
|
+
|
|
1269
|
+
def _trapezoid_average(curve):
    """Mean of ``curve`` under the trapezoidal rule with unit spacing."""
    integral = 0.0
    for left, right in zip(curve[:-1], curve[1:]):
        integral += (left + right) / 2
    return integral / (len(curve) - 1)


def average_mAP(
    targets, detections, closests, framerate=2, deltas=np.arange(5) * 1 + 1
):
    """Average the mAP curves over the tolerances in ``deltas``.

    With a single tolerance the mAP values at that tolerance are returned
    unchanged; otherwise each curve is averaged with the trapezoidal rule.

    Args:
        targets (List(np.array(vector_size,num_classes)): Ground-truth targets, one array per video.
        detections (List(np.array(vector_size,num_classes)): Predictions, one array per video.
        closests (List(np.array)): Closest-action arrays, one per video.
        framerate (int).
            Default: 2.
        deltas (np.array): Tolerances.

    Returns:
        a_mAP: average mAP.
        a_mAP_per_class: List of average mAP for all classes.
        a_mAP_visible: average mAP only for the visible events.
        a_mAP_per_class_visible: List of average mAP only for the visible events.
        a_mAP_unshown: average mAP only for the unshown events.
        a_mAP_per_class_unshown: List of average mAP only for the unshown events.
    """
    curves = delta_curve(targets, closests, detections, framerate, deltas)
    (
        mAP,
        mAP_per_class,
        mAP_visible,
        mAP_per_class_visible,
        mAP_unshown,
        mAP_per_class_unshown,
    ) = curves

    if len(mAP) == 1:
        return tuple(curve[0] for curve in curves)

    def per_class_average(per_class_curves):
        # Average each class column independently across tolerances.
        return [
            _trapezoid_average([curve[c] for curve in per_class_curves])
            for c in np.arange(len(per_class_curves[0]))
        ]

    a_mAP = _trapezoid_average(mAP)
    a_mAP_visible = _trapezoid_average(mAP_visible)
    # NOTE(review): hard-coded 17/13 rescaling, applied only on this
    # multi-tolerance path; the single-tolerance return above leaves the
    # unshown mAP unscaled. Confirm which behaviour is intended.
    a_mAP_unshown = _trapezoid_average(mAP_unshown) * 17 / 13

    return (
        a_mAP,
        per_class_average(mAP_per_class),
        a_mAP_visible,
        per_class_average(mAP_per_class_visible),
        a_mAP_unshown,
        per_class_average(mAP_per_class_unshown),
    )
|
|
1364
|
+
|
|
1365
|
+
|
|
1366
|
+
def LoadJsonFromZip(zippedFile, JsonPath):
    """Return the parsed JSON stored at member ``JsonPath`` of zip ``zippedFile``.

    The member's bytes are decoded as UTF-8 before parsing.
    """
    with zipfile.ZipFile(zippedFile, "r") as archive:
        raw_bytes = archive.read(JsonPath)
    text = raw_bytes.decode("utf-8")
    return json.loads(text)
|
|
1374
|
+
|
|
1375
|
+
|
|
1376
|
+
def get_closest_action_index(dense_labels, closest_numpy):
    """Spread each annotated event of every class over its neighbourhood.

    For class ``c``, every non-zero frame of ``dense_labels[:, c]`` claims the
    frames from the midpoint with the previous event up to, but excluding,
    the midpoint with the next event. The first event extends back to frame
    0 and the last one to the end of ``closest_numpy``. Claimed frames
    receive that event's value.

    Example:
        With events at frames 6, 10 and 12, the event at frame 10 is written
        to frames 8, 9 and 10 (slice ``8:11``).

    Args:
        dense_labels: 2D array of ground-truth values (frames, classes).
        closest_numpy: 2D array filled in place with the spread values.

    Returns:
        closest_numpy, modified in place.
    """
    n_rows = closest_numpy.shape[0]
    for cls in np.arange(dense_labels.shape[-1]):
        hits = np.where(dense_labels[:, cls] != 0)[0].tolist()
        if not hits:
            continue
        # Mirror the first hit around 0 so its region starts at frame 0, and
        # put a far sentinel after the last hit so its region reaches the end.
        bounds = [-hits[0]] + hits + [2 * n_rows]
        for prev_hit, hit, next_hit in zip(bounds, bounds[1:], bounds[2:]):
            lo = max(0, (prev_hit + hit) // 2)
            hi = min(n_rows, (hit + next_hit) // 2)
            closest_numpy[lo:hi, cls] = dense_labels[hit, cls]
    return closest_numpy
|
|
1400
|
+
|
|
1401
|
+
|
|
1402
|
+
def compute_performances_mAP(
    metric, targets_numpy, detections_numpy, closests_numpy, INVERSE_EVENT_DICTIONARY
):
    """Compute the average-mAP scores for ``metric``, log them and return the table.

    Args:
        metric (string): Tolerance set to evaluate: "loose" (5, 10, ..., 60),
            "tight" (1, 2, ..., 5) or "at1" ... "at5" (a single tolerance).
            The tolerances are multiplied by the framerate in ``delta_curve``.
        targets_numpy (List(np.array(vector_size,num_classes))): Ground-truth
            targets of shape (number of frames, number of classes), one per video.
        detections_numpy (List(np.array(vector_size,num_classes))): Prediction
            scores of shape (number of frames, number of classes), one per video.
        closests_numpy (List(np.array)): Closest-action arrays, one per video.
        INVERSE_EVENT_DICTIONARY (dict): Mapping between indexes and class name.

    Returns:
        str: The tabulated per-class and average mAP (columns Any / Visible /
        Unseen), as logged.

    Raises:
        ValueError: If ``metric`` is not one of the supported names.
    """
    metric_deltas = {
        "loose": np.arange(12) * 5 + 5,
        "tight": np.arange(5) * 1 + 1,
        "at1": np.array([1]),
        "at2": np.array([2]),
        "at3": np.array([3]),
        "at4": np.array([4]),
        "at5": np.array([5]),
    }
    if metric not in metric_deltas:
        # Previously an unknown metric left `deltas` unbound (UnboundLocalError).
        raise ValueError(
            f"Unknown metric {metric!r}; expected one of {sorted(metric_deltas)}"
        )
    deltas = metric_deltas[metric]

    # Compute the performances
    # NOTE(review): framerate is fixed at 2 here regardless of the data; confirm.
    (
        a_mAP,
        a_mAP_per_class,
        a_mAP_visible,
        a_mAP_per_class_visible,
        a_mAP_unshown,
        a_mAP_per_class_unshown,
    ) = average_mAP(
        targets_numpy, detections_numpy, closests_numpy, framerate=2, deltas=deltas
    )

    results = {
        "a_mAP": a_mAP,
        "a_mAP_per_class": a_mAP_per_class,
        "a_mAP_visible": a_mAP_visible,
        "a_mAP_per_class_visible": a_mAP_per_class_visible,
        "a_mAP_unshown": a_mAP_unshown,
        "a_mAP_per_class_unshown": a_mAP_per_class_unshown,
    }

    rows = []
    for i in range(len(results["a_mAP_per_class"])):
        label = INVERSE_EVENT_DICTIONARY[i]
        rows.append(
            (
                label,
                "{:0.2f}".format(results["a_mAP_per_class"][i] * 100),
                "{:0.2f}".format(results["a_mAP_per_class_visible"][i] * 100),
                "{:0.2f}".format(results["a_mAP_per_class_unshown"][i] * 100),
            )
        )
    rows.append(
        (
            "Average mAP",
            "{:0.2f}".format(results["a_mAP"] * 100),
            "{:0.2f}".format(results["a_mAP_visible"] * 100),
            "{:0.2f}".format(results["a_mAP_unshown"] * 100),
        )
    )

    header = ["", "Any", "Visible", "Unseen"]
    result = tabulate(rows, headers=header)
    log_table_wandb(name=f"Final Scores (Metric : {metric})", rows=rows, headers=header)
    logging.info("Best Performance at end of training ")
    logging.info("Metric: " + metric)
    logging.info("\n" + result)
    return result
|
|
1481
|
+
|
|
1482
|
+
# Module import side effect (the same call also appears earlier in this
# module): ignore NumPy divide-by-zero and invalid floating-point operations
# instead of emitting warnings.
np.seterr(divide="ignore", invalid="ignore")
|