opensportslib 0.0.1.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- opensportslib/__init__.py +18 -0
- opensportslib/apis/__init__.py +21 -0
- opensportslib/apis/classification.py +361 -0
- opensportslib/apis/localization.py +228 -0
- opensportslib/config/classification.yaml +104 -0
- opensportslib/config/classification_tracking.yaml +103 -0
- opensportslib/config/graph_tracking_classification/avgpool.yaml +79 -0
- opensportslib/config/graph_tracking_classification/gin.yaml +79 -0
- opensportslib/config/graph_tracking_classification/graphconv.yaml +79 -0
- opensportslib/config/graph_tracking_classification/graphsage.yaml +79 -0
- opensportslib/config/graph_tracking_classification/maxpool.yaml +79 -0
- opensportslib/config/graph_tracking_classification/noedges.yaml +79 -0
- opensportslib/config/localization.yaml +132 -0
- opensportslib/config/sngar_frames.yaml +98 -0
- opensportslib/core/__init__.py +0 -0
- opensportslib/core/loss/__init__.py +0 -0
- opensportslib/core/loss/builder.py +40 -0
- opensportslib/core/loss/calf.py +258 -0
- opensportslib/core/loss/ce.py +23 -0
- opensportslib/core/loss/combine.py +42 -0
- opensportslib/core/loss/nll.py +25 -0
- opensportslib/core/optimizer/__init__.py +0 -0
- opensportslib/core/optimizer/builder.py +38 -0
- opensportslib/core/sampler/weighted_sampler.py +104 -0
- opensportslib/core/scheduler/__init__.py +0 -0
- opensportslib/core/scheduler/builder.py +77 -0
- opensportslib/core/trainer/__init__.py +0 -0
- opensportslib/core/trainer/classification_trainer.py +1131 -0
- opensportslib/core/trainer/localization_trainer.py +1009 -0
- opensportslib/core/utils/checkpoint.py +238 -0
- opensportslib/core/utils/config.py +199 -0
- opensportslib/core/utils/data.py +85 -0
- opensportslib/core/utils/ddp.py +77 -0
- opensportslib/core/utils/default_args.py +110 -0
- opensportslib/core/utils/load_annotations.py +485 -0
- opensportslib/core/utils/seed.py +26 -0
- opensportslib/core/utils/video_processing.py +389 -0
- opensportslib/core/utils/wandb.py +110 -0
- opensportslib/datasets/__init__.py +0 -0
- opensportslib/datasets/builder.py +42 -0
- opensportslib/datasets/classification_dataset.py +582 -0
- opensportslib/datasets/localization_dataset.py +813 -0
- opensportslib/datasets/utils/__init__.py +15 -0
- opensportslib/datasets/utils/tracking.py +615 -0
- opensportslib/metrics/classification_metric.py +176 -0
- opensportslib/metrics/localization_metric.py +1482 -0
- opensportslib/models/__init__.py +0 -0
- opensportslib/models/backbones/builder.py +590 -0
- opensportslib/models/base/e2e.py +252 -0
- opensportslib/models/base/tracking.py +73 -0
- opensportslib/models/base/vars.py +29 -0
- opensportslib/models/base/video.py +130 -0
- opensportslib/models/base/video_mae.py +60 -0
- opensportslib/models/builder.py +43 -0
- opensportslib/models/heads/builder.py +266 -0
- opensportslib/models/neck/builder.py +210 -0
- opensportslib/models/utils/common.py +176 -0
- opensportslib/models/utils/impl/__init__.py +0 -0
- opensportslib/models/utils/impl/asformer.py +390 -0
- opensportslib/models/utils/impl/calf.py +74 -0
- opensportslib/models/utils/impl/gsm.py +112 -0
- opensportslib/models/utils/impl/gtad.py +347 -0
- opensportslib/models/utils/impl/tsm.py +123 -0
- opensportslib/models/utils/litebase.py +59 -0
- opensportslib/models/utils/modules.py +120 -0
- opensportslib/models/utils/shift.py +135 -0
- opensportslib/models/utils/utils.py +276 -0
- opensportslib-0.0.1.dev2.dist-info/METADATA +566 -0
- opensportslib-0.0.1.dev2.dist-info/RECORD +73 -0
- opensportslib-0.0.1.dev2.dist-info/WHEEL +5 -0
- opensportslib-0.0.1.dev2.dist-info/licenses/LICENSE +661 -0
- opensportslib-0.0.1.dev2.dist-info/licenses/LICENSE-COMMERCIAL +5 -0
- opensportslib-0.0.1.dev2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,615 @@
|
|
|
1
|
+
# opensportslib/datasets/utils/tracking.py
|
|
2
|
+
|
|
3
|
+
"""
|
|
4
|
+
tracking utilities for player-position graph data.
|
|
5
|
+
|
|
6
|
+
provides constants, edge-building strategies, spatial augmentations,
|
|
7
|
+
and frame-level feature extraction for the SN-GAR tracking
|
|
8
|
+
modality.
|
|
9
|
+
|
|
10
|
+
feature vector layout (per object, per frame)::
|
|
11
|
+
|
|
12
|
+
Index Field
|
|
13
|
+
----- ----------------------------------
|
|
14
|
+
0 x (pitch coordinates)
|
|
15
|
+
1 y
|
|
16
|
+
2 is_ball (one-hot entity type)
|
|
17
|
+
3 is_home
|
|
18
|
+
4 is_away
|
|
19
|
+
5 dx (velocity delta, computed across consecutive frames)
|
|
20
|
+
6 dy
|
|
21
|
+
7 z (ball height; MISSING_VALUE for players)
|
|
22
|
+
"""
|
|
23
|
+
|
|
24
|
+
import json
|
|
25
|
+
import random
|
|
26
|
+
|
|
27
|
+
import numpy as np
|
|
28
|
+
import pandas as pd
|
|
29
|
+
|
|
30
|
+
# -------------------------------------------------------------------
# constants
# -------------------------------------------------------------------

# slot layout: index 0 is always the ball, 1-11 home, 12-22 away.
# the team boundaries are derived from a single per-team count so the
# slot layout and NUM_OBJECTS cannot drift apart.
_PLAYERS_PER_TEAM = 11
_BALL_SLOT = 0
_HOME_SLOT_START = _BALL_SLOT + 1
_HOME_SLOT_END = _HOME_SLOT_START + _PLAYERS_PER_TEAM  # exclusive
_AWAY_SLOT_START = _HOME_SLOT_END

# 1 ball + 11 home players + 11 away players = 23
NUM_OBJECTS = _AWAY_SLOT_START + _PLAYERS_PER_TEAM
FEATURE_DIM = 8  # [x, y, is_ball, is_home, is_away, dx, dy, z]

# normalization bounds (pitch dimension in metres).
PITCH_HALF_LENGTH = 85.0
PITCH_HALF_WIDTH = 50.0
MAX_DISPLACEMENT = 110.0
MAX_BALL_HEIGHT = 30.0

# sentinel values used to mark missing / unobserved objects.
# raw features use MISSING_VALUE; after normalization they are mapped
# to MISSING_VALUE_NORMALIZED, a small-magnitude value that lies
# outside the nominal [-1, 1] range of normalized coordinates, so the
# model sees a distinct indicator rather than something resembling a
# valid position.
MISSING_VALUE = -200.0
MISSING_VALUE_NORMALIZED = -2.0
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
# -------------------------------------------------------------------
|
|
58
|
+
# frame-level feature extraction
|
|
59
|
+
# -------------------------------------------------------------------
|
|
60
|
+
|
|
61
|
+
def parse_frame(row):
    """parse a single parquet row into per-object features and positions.

    Each row is expected to contain JSON-encoded columns balls,
    homePlayers, and awayPlayers following the SN-GAR tracking format.
    Parsing is best-effort: a column that is absent, null, empty or
    malformed leaves its slots unobserved instead of raising.

    Players without both x and y are discarded before each team is
    truncated to its slot count, so extra listed entries without
    coordinates cannot push observed players out of the frame.

    Args:
        row: A single row (pandas.Series) from a tracking parquet
            file.

    Returns:
        A tuple (features, positions) where features is a
        numpy.ndarray of shape (NUM_OBJECTS, FEATURE_DIM) and
        positions is a list of position-group strings (e.g.
        "GK", "DEF"). Unobserved slots are filled with
        MISSING_VALUE and an empty string respectively.
    """
    features = np.full(
        (NUM_OBJECTS, FEATURE_DIM), MISSING_VALUE, dtype=np.float32,
    )
    positions = [""] * NUM_OBJECTS

    # -- ball (always slot 0) --
    balls = _load_json_column(row, "balls")
    if balls:
        try:
            ball = balls[0]
            x, y = ball.get("x"), ball.get("y")
            if x is not None and y is not None:
                z = ball.get("z")
                features[_BALL_SLOT] = [
                    float(x), float(y),
                    1, 0, 0,  # one-hot: ball
                    0, 0,  # dx, dy (filled later)
                    # a null height is treated like an absent one (0)
                    # instead of discarding an otherwise valid ball.
                    0.0 if z is None else float(z),
                ]
                positions[_BALL_SLOT] = "BALL"
        except (AttributeError, KeyError, TypeError, ValueError):
            # malformed ball entry: leave slot 0 unobserved.
            pass

    # -- home players (slots 1-11) --
    _fill_team_slots(
        features,
        positions,
        _load_json_column(row, "homePlayers"),
        slot_start=_HOME_SLOT_START,
        slot_end=_HOME_SLOT_END,
        team_one_hot=(0, 1, 0),
    )

    # -- away players (slots 12-22) --
    _fill_team_slots(
        features,
        positions,
        _load_json_column(row, "awayPlayers"),
        slot_start=_AWAY_SLOT_START,
        slot_end=NUM_OBJECTS,
        team_one_hot=(0, 0, 1),
    )

    return features, positions


def _load_json_column(row, key):
    """decode the JSON column key of row.

    Returns None when the column is absent, NaN/None, the literal
    "null", empty, or not valid JSON.
    """
    raw = row.get(key)
    if not pd.notna(raw) or raw in ("null", ""):
        return None
    try:
        return json.loads(raw)
    except (TypeError, ValueError):  # JSONDecodeError is a ValueError
        return None


def _fill_team_slots(features, positions, players, slot_start, slot_end,
                     team_one_hot):
    """write one team's observed players into slots [slot_start, slot_end).

    Players lacking x or y are dropped, the rest are ordered by jersey
    number and truncated to the number of available slots. If the
    player list is malformed, the team's slots are left untouched
    (unobserved) rather than partially filled.

    Args:
        features: (NUM_OBJECTS, FEATURE_DIM) array, modified in place.
        positions: list of position strings, modified in place.
        players: decoded player list (or None).
        slot_start: first slot index for this team.
        slot_end: exclusive end slot index for this team.
        team_one_hot: (is_ball, is_home, is_away) flags for the team.
    """
    if not players:
        return
    try:
        observed = [
            player for player in players
            if player.get("x") is not None and player.get("y") is not None
        ]
        observed.sort(key=lambda player: int(player.get("jerseyNum", 0)))
        observed = observed[: slot_end - slot_start]
        rows = [
            [
                float(player["x"]), float(player["y"]),
                *team_one_hot,
                0, 0,  # dx, dy (filled later)
                MISSING_VALUE,  # z unused for players
            ]
            for player in observed
        ]
    except (AttributeError, KeyError, TypeError, ValueError):
        return

    for slot, (values, player) in enumerate(zip(rows, observed),
                                            start=slot_start):
        features[slot] = values
        positions[slot] = player.get("positionGroup") or ""
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
# -------------------------------------------------------------------
|
|
165
|
+
# temporal feature computation
|
|
166
|
+
# -------------------------------------------------------------------
|
|
167
|
+
|
|
168
|
+
def compute_deltas(all_features):
    """compute per-object velocity deltas across consecutive frames.

    For each object that is observed in both frame t and frame t - 1,
    the displacement (dx, dy) is written into feature indices 5 and 6.
    Objects missing from either frame keep whatever indices 5 and 6
    already held, and frame 0 is never written (so it retains the zero
    deltas parse_frame puts on observed objects).

    The computation is vectorized over frames and objects, so it does
    not depend on the object axis being exactly NUM_OBJECTS long.

    Args:
        all_features: numpy.ndarray of shape
            (num_frames, NUM_OBJECTS, FEATURE_DIM).

    Returns:
        the same array, modified in-place, with velocity deltas
        populated.
    """
    observed = all_features[:, :, 0] != MISSING_VALUE
    # a delta exists only where the object is visible in both frames.
    both_observed = observed[1:] & observed[:-1]
    displacement = all_features[1:, :, :2] - all_features[:-1, :, :2]
    # basic slicing yields a view, so the masked assignment below
    # writes straight into all_features.
    deltas = all_features[1:, :, 5:7]
    deltas[both_observed] = displacement[both_observed]
    return all_features
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
def normalize_features(features):
    """normalize spatial features to roughly [-1, 1].

    observed coordinates are divided by the known pitch / displacement
    bounds. unobserved slots are set to MISSING_VALUE_NORMALIZED
    so the model receives a distinct, out-of-range sentinel that
    cannot be confused with a valid normalized value.

    the height channel (index 7) has its own validity test: observed
    players carry MISSING_VALUE as their z, so it is mapped to
    MISSING_VALUE_NORMALIZED as well rather than being divided by
    MAX_BALL_HEIGHT (which would yield about -6.67).

    Args:
        features: numpy.ndarray of shape
            (num_frames, NUM_OBJECTS, FEATURE_DIM).

    Returns:
        a new array with the same shape containing normalized values.
        the input is not modified.
    """
    features_norm = features.copy()
    valid_mask = features_norm[:, :, 0] != MISSING_VALUE
    # z is only meaningful where the object is observed AND actually
    # carries a height (players hold the raw sentinel there).
    z_valid_mask = valid_mask & (features_norm[:, :, 7] != MISSING_VALUE)

    features_norm[valid_mask, 0] /= PITCH_HALF_LENGTH
    features_norm[valid_mask, 1] /= PITCH_HALF_WIDTH
    features_norm[valid_mask, 5] /= MAX_DISPLACEMENT
    features_norm[valid_mask, 6] /= MAX_DISPLACEMENT
    features_norm[z_valid_mask, 7] /= MAX_BALL_HEIGHT

    # write the normalized sentinel into every spatial channel of
    # missing objects so downstream layers can distinguish "absent"
    # from "at the origin".
    for ch in (0, 1, 5, 6):
        features_norm[~valid_mask, ch] = MISSING_VALUE_NORMALIZED
    features_norm[~z_valid_mask, 7] = MISSING_VALUE_NORMALIZED

    return features_norm
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
# -------------------------------------------------------------------
|
|
232
|
+
# edge building strategies
|
|
233
|
+
# -------------------------------------------------------------------
|
|
234
|
+
|
|
235
|
+
def build_edge_index(node_features, node_positions, edge_type, k=8, r=15.0):
    """build a COO edge index for one frame using the named strategy.

    each strategy encodes a different connectivity prior, ranging from
    no edges at all to dense, spatial (kNN / radius), ball-centred or
    formation-based graphs.

    Args:
        node_features: numpy.ndarray of shape
            (NUM_OBJECTS, FEATURE_DIM).
        node_positions: List of position-group strings (length
            NUM_OBJECTS).
        edge_type: One of "none", "full", "knn",
            "distance", "ball_knn", "ball_distance",
            or "positional".
        k: Number of neighbours for knn / ball_knn strategies.
        r: Distance threshold (metres) for distance / ball_distance
            strategies.

    Returns:
        numpy.ndarray of shape (2, num_edges) in COO format,
        compatible with PyTorch Geometric.

    Raises:
        ValueError: if edge_type is not one of the names above.
    """
    # builders are wrapped in lambdas so only the selected one runs.
    strategies = (
        ("none", lambda: np.zeros((2, 0), dtype=np.int64)),
        ("full", lambda: _build_full_edges(node_features)),
        ("knn", lambda: _build_knn_edges(node_features, k)),
        ("distance",
         lambda: _build_distance_edges(node_features, threshold=r)),
        ("ball_knn", lambda: _build_ball_knn_edges(node_features, k)),
        ("ball_distance",
         lambda: _build_ball_distance_edges(node_features, threshold=r)),
        ("positional",
         lambda: _build_positional_edges(node_features, node_positions)),
    )
    for name, build in strategies:
        if edge_type == name:
            return build()
    raise ValueError(f"Unknown edge_type: {edge_type}")
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
# -- strategy implementations (private) ----------------------------
|
|
282
|
+
|
|
283
|
+
def _build_full_edges(node_features):
    """fully connected graph, excluding self-loops and missing nodes.

    every ordered pair (i, j) with i != j of observed nodes becomes an
    edge; columns are ordered by i, then j.
    """
    observed = node_features[:, 0] != MISSING_VALUE
    adjacency = np.logical_and.outer(observed, observed)
    np.fill_diagonal(adjacency, False)  # no self-loops
    sources, targets = np.nonzero(adjacency)  # row-major order
    if sources.size == 0:
        return np.zeros((2, 0), dtype=np.int64)
    return np.stack((sources, targets)).astype(np.int64)
|
|
298
|
+
|
|
299
|
+
|
|
300
|
+
def _build_knn_edges(node_features, k):
    """k-nearest-neighbour edges based on Euclidean pitch distance.

    every observed node is linked, in both directions, to its k closest
    observed nodes (equal distances keep the lower node index first).
    the result is de-duplicated and sorted column-wise.
    """
    observed = [
        idx for idx in range(node_features.shape[0])
        if node_features[idx, 0] != MISSING_VALUE
    ]

    edges = []
    for src in observed:
        ranked = sorted(
            (
                (dst, np.linalg.norm(
                    node_features[src, :2] - node_features[dst, :2],
                ))
                for dst in observed
                if dst != src
            ),
            key=lambda item: item[1],
        )
        for dst, _ in ranked[: min(k, len(ranked))]:
            edges.extend(([src, dst], [dst, src]))

    if not edges:
        return np.zeros((2, 0), dtype=np.int64)
    # symmetric insertion creates duplicates; np.unique drops them.
    return np.unique(np.array(edges, dtype=np.int64).T, axis=1)
|
|
331
|
+
|
|
332
|
+
|
|
333
|
+
def _build_distance_edges(node_features, threshold=15.0):
    """edges between all node pairs within a distance threshold.

    each unordered pair of observed nodes whose pitch distance is at
    most threshold contributes both directed edges, (a, b) then (b, a).
    """
    observed = [
        idx for idx in range(node_features.shape[0])
        if node_features[idx, 0] != MISSING_VALUE
    ]

    edges = []
    for offset, a in enumerate(observed):
        for b in observed[offset + 1:]:
            gap = np.linalg.norm(node_features[a, :2] - node_features[b, :2])
            if gap <= threshold:
                edges.extend(([a, b], [b, a]))

    if not edges:
        return np.zeros((2, 0), dtype=np.int64)
    return np.array(edges, dtype=np.int64).T
|
|
354
|
+
|
|
355
|
+
|
|
356
|
+
def _build_positional_edges(node_features, node_positions):
    """tactical-structure edges following formation lines.

    within each team, adjacent lines are fully linked
    (GK <-> DEF <-> MID <-> FWD), and players sharing the DEF, MID or
    FWD line are linked to each other (the GK line gets no intra-line
    links). the node labelled "BALL" is linked to every observed node.
    all edges are bidirectional; the result is de-duplicated and
    sorted column-wise.
    """
    pairs = set()

    def link(a, b):
        pairs.add((a, b))
        pairs.add((b, a))

    home, away = {}, {}
    ball_idx = None

    for idx in range(node_features.shape[0]):
        if node_features[idx, 0] == MISSING_VALUE or not node_positions[idx]:
            continue
        label = node_positions[idx]
        if label == "BALL":
            ball_idx = idx
        elif node_features[idx, 3] == 1.0:  # home flag
            home.setdefault(label, []).append(idx)
        elif node_features[idx, 4] == 1.0:  # away flag
            away.setdefault(label, []).append(idx)

    for team in (home, away):
        lines = [team.get(name, []) for name in ("GK", "DEF", "MID", "FWD")]

        # adjacent formation lines: GK-DEF, DEF-MID, MID-FWD.
        for back, front in zip(lines, lines[1:]):
            for a in back:
                for b in front:
                    link(a, b)

        # teammates on the same DEF / MID / FWD line.
        for line in lines[1:]:
            for offset, a in enumerate(line):
                for b in line[offset + 1:]:
                    link(a, b)

    if ball_idx is not None:
        for idx in range(node_features.shape[0]):
            if idx != ball_idx and node_features[idx, 0] != MISSING_VALUE:
                link(ball_idx, idx)

    if not pairs:
        return np.zeros((2, 0), dtype=np.int64)
    return np.unique(np.array(sorted(pairs), dtype=np.int64).T, axis=1)
|
|
425
|
+
|
|
426
|
+
|
|
427
|
+
def _build_ball_knn_edges(node_features, k):
    """k nearest players to the ball, plus same-team interconnections.

    the ball is the first node whose is_ball flag equals 1. it is
    linked (both directions) to its k closest observed players, nearest
    first, and those players that share a team flag (dot product of
    the one-hot team columns > 0) are then linked to each other. an
    absent or unobserved ball yields an empty (2, 0) array.
    """
    ball_slots = np.flatnonzero(node_features[:, 2] == 1.0)
    if ball_slots.size == 0:
        return np.zeros((2, 0), dtype=np.int64)

    ball = ball_slots[0]
    ball_xy = node_features[ball, :2]
    if ball_xy[0] == MISSING_VALUE or ball_xy[1] == MISSING_VALUE:
        return np.zeros((2, 0), dtype=np.int64)

    candidates = []
    for idx in range(node_features.shape[0]):
        if idx == ball or node_features[idx, 0] == MISSING_VALUE:
            continue
        if node_features[idx, 3] != 1.0 and node_features[idx, 4] != 1.0:
            continue  # not a player
        candidates.append(
            (idx, np.linalg.norm(node_features[idx, :2] - ball_xy)),
        )

    nearest = sorted(candidates, key=lambda item: item[1])
    nearest_ids = [idx for idx, _ in nearest[: min(k, len(nearest))]]

    edges = []
    for idx in nearest_ids:
        edges.extend(([ball, idx], [idx, ball]))
    for offset, a in enumerate(nearest_ids):
        team_a = node_features[a, 3:5]
        for b in nearest_ids[offset + 1:]:
            if np.dot(team_a, node_features[b, 3:5]) > 0:
                edges.extend(([a, b], [b, a]))

    if not edges:
        return np.zeros((2, 0), dtype=np.int64)
    return np.array(edges, dtype=np.int64).T
|
|
476
|
+
|
|
477
|
+
|
|
478
|
+
def _build_ball_distance_edges(node_features, threshold=20.0):
    """players within a distance threshold of the ball, plus same-team edges.

    the ball is the first node whose is_ball flag equals 1. every
    observed player within threshold of it gets a bidirectional edge to
    the ball (emitted in node-index order), and the nearby players that
    share a team flag are then linked to each other. an absent or
    unobserved ball yields an empty (2, 0) array.
    """
    ball_slots = np.flatnonzero(node_features[:, 2] == 1.0)
    if ball_slots.size == 0:
        return np.zeros((2, 0), dtype=np.int64)

    ball = ball_slots[0]
    ball_xy = node_features[ball, :2]
    if ball_xy[0] == MISSING_VALUE or ball_xy[1] == MISSING_VALUE:
        return np.zeros((2, 0), dtype=np.int64)

    nearby = []
    for idx in range(node_features.shape[0]):
        if idx == ball or node_features[idx, 0] == MISSING_VALUE:
            continue
        if node_features[idx, 3] != 1.0 and node_features[idx, 4] != 1.0:
            continue  # not a player
        if np.linalg.norm(node_features[idx, :2] - ball_xy) <= threshold:
            nearby.append(idx)

    edges = []
    for idx in nearby:
        edges.extend(([ball, idx], [idx, ball]))
    for offset, a in enumerate(nearby):
        team_a = node_features[a, 3:5]
        for b in nearby[offset + 1:]:
            if np.dot(team_a, node_features[b, 3:5]) > 0:
                edges.extend(([a, b], [b, a]))

    if not edges:
        return np.zeros((2, 0), dtype=np.int64)
    return np.array(edges, dtype=np.int64).T
|
|
521
|
+
|
|
522
|
+
|
|
523
|
+
# -------------------------------------------------------------------
|
|
524
|
+
# augmentations
|
|
525
|
+
# -------------------------------------------------------------------
|
|
526
|
+
|
|
527
|
+
class HorizontalFlip:
    """randomly negate x-coordinates and dx-velocities (pitch length axis).

    only objects whose x is not MISSING_VALUE are touched; all other
    channels are left as they are.

    NOTE(review): validity is tested against the raw MISSING_VALUE
    sentinel, so this transform presumably has to run before
    normalize_features -- after normalization, missing slots hold
    MISSING_VALUE_NORMALIZED in x and would be negated too. Confirm
    against the dataset pipeline.

    Args:
        probability: chance of applying the flip per call.
    """

    def __init__(self, probability=0.5):
        self.probability = probability

    def __call__(self, features):
        """apply the transform.

        Args:
            features: numpy.ndarray of shape
                (num_frames, NUM_OBJECTS, FEATURE_DIM).

        Returns:
            the (possibly flipped) feature array. a copy is made when
            the flip is applied; the input is never modified.
        """
        apply_flip = random.random() < self.probability
        if not apply_flip:
            return features
        mirrored = features.copy()
        observed = mirrored[:, :, 0] != MISSING_VALUE
        for channel in (0, 5):  # x, dx
            mirrored[observed, channel] *= -1
        return mirrored
|
|
555
|
+
|
|
556
|
+
|
|
557
|
+
class VerticalFlip:
    """randomly negate y-coordinates and dy-velocities (pitch width axis).

    only objects whose x is not MISSING_VALUE are touched; all other
    channels are left as they are.

    NOTE(review): as with the horizontal flip, validity is tested
    against the raw MISSING_VALUE sentinel, so this presumably must run
    before normalize_features -- confirm against the dataset pipeline.

    Args:
        probability: chance of applying the flip per call.
    """

    def __init__(self, probability=0.5):
        self.probability = probability

    def __call__(self, features):
        """apply the transform.

        Args:
            features: numpy.ndarray of shape
                (num_frames, NUM_OBJECTS, FEATURE_DIM).

        Returns:
            the (possibly flipped) feature array; the input itself is
            never modified.
        """
        apply_flip = random.random() < self.probability
        if not apply_flip:
            return features
        mirrored = features.copy()
        observed = mirrored[:, :, 0] != MISSING_VALUE
        for channel in (1, 6):  # y, dy
            mirrored[observed, channel] *= -1
        return mirrored
|
|
584
|
+
|
|
585
|
+
|
|
586
|
+
class TeamFlip:
    """randomly swap the home and away one-hot team flags.

    this is a label-preserving augmentation: swapping team identity
    does not change the group activity class. only feature columns 3
    (is_home) and 4 (is_away) are exchanged; coordinates are untouched.

    Args:
        probability: chance of applying the swap per call.
    """

    def __init__(self, probability=0.5):
        self.probability = probability

    def __call__(self, features):
        """apply the transform.

        Args:
            features: numpy.ndarray of shape
                (num_frames, NUM_OBJECTS, FEATURE_DIM).

        Returns:
            the (possibly swapped) feature array. the swap is written
            into a copy; the input is never modified.
        """
        apply_swap = random.random() < self.probability
        if not apply_swap:
            return features
        swapped = features.copy()
        # fancy indexing on the right-hand side reads from the untouched
        # input, so both columns can be exchanged in one assignment.
        swapped[:, :, [3, 4]] = features[:, :, [4, 3]]
        return swapped
|