opensportslib 0.0.1.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73) hide show
  1. opensportslib/__init__.py +18 -0
  2. opensportslib/apis/__init__.py +21 -0
  3. opensportslib/apis/classification.py +361 -0
  4. opensportslib/apis/localization.py +228 -0
  5. opensportslib/config/classification.yaml +104 -0
  6. opensportslib/config/classification_tracking.yaml +103 -0
  7. opensportslib/config/graph_tracking_classification/avgpool.yaml +79 -0
  8. opensportslib/config/graph_tracking_classification/gin.yaml +79 -0
  9. opensportslib/config/graph_tracking_classification/graphconv.yaml +79 -0
  10. opensportslib/config/graph_tracking_classification/graphsage.yaml +79 -0
  11. opensportslib/config/graph_tracking_classification/maxpool.yaml +79 -0
  12. opensportslib/config/graph_tracking_classification/noedges.yaml +79 -0
  13. opensportslib/config/localization.yaml +132 -0
  14. opensportslib/config/sngar_frames.yaml +98 -0
  15. opensportslib/core/__init__.py +0 -0
  16. opensportslib/core/loss/__init__.py +0 -0
  17. opensportslib/core/loss/builder.py +40 -0
  18. opensportslib/core/loss/calf.py +258 -0
  19. opensportslib/core/loss/ce.py +23 -0
  20. opensportslib/core/loss/combine.py +42 -0
  21. opensportslib/core/loss/nll.py +25 -0
  22. opensportslib/core/optimizer/__init__.py +0 -0
  23. opensportslib/core/optimizer/builder.py +38 -0
  24. opensportslib/core/sampler/weighted_sampler.py +104 -0
  25. opensportslib/core/scheduler/__init__.py +0 -0
  26. opensportslib/core/scheduler/builder.py +77 -0
  27. opensportslib/core/trainer/__init__.py +0 -0
  28. opensportslib/core/trainer/classification_trainer.py +1131 -0
  29. opensportslib/core/trainer/localization_trainer.py +1009 -0
  30. opensportslib/core/utils/checkpoint.py +238 -0
  31. opensportslib/core/utils/config.py +199 -0
  32. opensportslib/core/utils/data.py +85 -0
  33. opensportslib/core/utils/ddp.py +77 -0
  34. opensportslib/core/utils/default_args.py +110 -0
  35. opensportslib/core/utils/load_annotations.py +485 -0
  36. opensportslib/core/utils/seed.py +26 -0
  37. opensportslib/core/utils/video_processing.py +389 -0
  38. opensportslib/core/utils/wandb.py +110 -0
  39. opensportslib/datasets/__init__.py +0 -0
  40. opensportslib/datasets/builder.py +42 -0
  41. opensportslib/datasets/classification_dataset.py +582 -0
  42. opensportslib/datasets/localization_dataset.py +813 -0
  43. opensportslib/datasets/utils/__init__.py +15 -0
  44. opensportslib/datasets/utils/tracking.py +615 -0
  45. opensportslib/metrics/classification_metric.py +176 -0
  46. opensportslib/metrics/localization_metric.py +1482 -0
  47. opensportslib/models/__init__.py +0 -0
  48. opensportslib/models/backbones/builder.py +590 -0
  49. opensportslib/models/base/e2e.py +252 -0
  50. opensportslib/models/base/tracking.py +73 -0
  51. opensportslib/models/base/vars.py +29 -0
  52. opensportslib/models/base/video.py +130 -0
  53. opensportslib/models/base/video_mae.py +60 -0
  54. opensportslib/models/builder.py +43 -0
  55. opensportslib/models/heads/builder.py +266 -0
  56. opensportslib/models/neck/builder.py +210 -0
  57. opensportslib/models/utils/common.py +176 -0
  58. opensportslib/models/utils/impl/__init__.py +0 -0
  59. opensportslib/models/utils/impl/asformer.py +390 -0
  60. opensportslib/models/utils/impl/calf.py +74 -0
  61. opensportslib/models/utils/impl/gsm.py +112 -0
  62. opensportslib/models/utils/impl/gtad.py +347 -0
  63. opensportslib/models/utils/impl/tsm.py +123 -0
  64. opensportslib/models/utils/litebase.py +59 -0
  65. opensportslib/models/utils/modules.py +120 -0
  66. opensportslib/models/utils/shift.py +135 -0
  67. opensportslib/models/utils/utils.py +276 -0
  68. opensportslib-0.0.1.dev2.dist-info/METADATA +566 -0
  69. opensportslib-0.0.1.dev2.dist-info/RECORD +73 -0
  70. opensportslib-0.0.1.dev2.dist-info/WHEEL +5 -0
  71. opensportslib-0.0.1.dev2.dist-info/licenses/LICENSE +661 -0
  72. opensportslib-0.0.1.dev2.dist-info/licenses/LICENSE-COMMERCIAL +5 -0
  73. opensportslib-0.0.1.dev2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,15 @@
1
+ # opensportslib/datasets/utils/__init__.py
2
+
3
+ from .tracking import (
4
+ build_edge_index,
5
+ HorizontalFlip,
6
+ VerticalFlip,
7
+ TeamFlip,
8
+ )
9
+
10
+ __all__ = [
11
+ 'build_edge_index',
12
+ 'HorizontalFlip',
13
+ 'VerticalFlip',
14
+ 'TeamFlip',
15
+ ]
@@ -0,0 +1,615 @@
1
+ # opensportslib/datasets/utils/tracking.py
2
+
3
+ """
4
+ tracking utilities for player-position graph data.
5
+
6
+ provides constants, edge-building strategies, spatial augmentations,
7
+ and frame-level feature extraction for the SN-GAR tracking
8
+ modality.
9
+
10
+ feature vector layout (per object, per frame)::
11
+
12
+ Index Field
13
+ ----- ----------------------------------
14
+ 0 x (pitch coordinates)
15
+ 1 y
16
+ 2 is_ball (one-hot entity type)
17
+ 3 is_home
18
+ 4 is_away
19
+ 5 dx (velocity delta, computed across consecutive frames)
20
+ 6 dy
21
+ 7 z (ball height; MISSING_VALUE for players)
22
+ """
23
+
24
+ import json
25
+ import random
26
+
27
+ import numpy as np
28
+ import pandas as pd
29
+
30
# -------------------------------------------------------------------
# constants
# -------------------------------------------------------------------

# object slots per frame: one ball, then two squads of eleven players.
NUM_OBJECTS = 1 + 11 + 11
# per-object channels: [x, y, is_ball, is_home, is_away, dx, dy, z]
FEATURE_DIM = 8

# normalization divisors (metres) applied by normalize_features to
# positions, per-frame displacements and ball height respectively.
PITCH_HALF_LENGTH = 85.0
PITCH_HALF_WIDTH = 50.0
MAX_DISPLACEMENT = 110.0
MAX_BALL_HEIGHT = 30.0

# sentinels for missing / unobserved entries. raw features carry
# MISSING_VALUE; normalize_features replaces it with
# MISSING_VALUE_NORMALIZED, which lies outside the roughly [-1, 1]
# range of valid normalized values so it cannot be mistaken for a
# real coordinate.
MISSING_VALUE = -200.0
MISSING_VALUE_NORMALIZED = -2.0

# slot layout: the ball sits at index 0, followed by 11 home slots
# (indices 1-11) and 11 away slots (indices 12-22).
_BALL_SLOT = 0
_HOME_SLOT_START = _BALL_SLOT + 1
_HOME_SLOT_END = _HOME_SLOT_START + 11  # exclusive
_AWAY_SLOT_START = _HOME_SLOT_END
+
56
+
57
+ # -------------------------------------------------------------------
58
+ # frame-level feature extraction
59
+ # -------------------------------------------------------------------
60
+
61
# Errors that mark a tracking column as undecodable. json.JSONDecodeError
# is a ValueError subclass; ValueError also covers int("") / float("abc")
# on malformed jersey numbers or coordinates, AttributeError covers
# entries that are not JSON objects, and KeyError covers a "balls" value
# that decodes to an object instead of a list.
_DECODE_ERRORS = (ValueError, TypeError, AttributeError, KeyError)


def parse_frame(row):
    """Parse a single parquet row into per-object features and positions.

    Each row is expected to contain JSON-encoded columns ``balls``,
    ``homePlayers`` and ``awayPlayers`` following the SN-GAR tracking
    format.

    Slot layout: index 0 is the ball, 1-11 the home players and 12-22 the
    away players. Within each team, players are sorted by jersey number
    and at most 11 are kept. Players without x/y coordinates are skipped,
    and the remaining ones are packed into consecutive slots.

    Args:
        row: A single row (pandas.Series, or any mapping with ``.get``)
            from a tracking parquet file.

    Returns:
        A tuple (features, positions):

        - features: float32 numpy.ndarray of shape
          (NUM_OBJECTS, FEATURE_DIM).
        - positions: list of position-group strings (e.g. "GK", "DEF";
          "BALL" for the ball slot).

        Unobserved slots are filled with MISSING_VALUE and an empty string
        respectively. A column that cannot be decoded is treated as
        unobserved; anything already written from it before the failure
        is kept.
    """
    features = np.full(
        (NUM_OBJECTS, FEATURE_DIM), MISSING_VALUE, dtype=np.float32,
    )
    positions = [""] * NUM_OBJECTS

    _parse_ball(row.get("balls", "null"), features, positions)
    _parse_team(
        row.get("homePlayers", "[]"), features, positions,
        start=_HOME_SLOT_START, one_hot=(0, 1, 0),
    )
    _parse_team(
        row.get("awayPlayers", "[]"), features, positions,
        start=_AWAY_SLOT_START, one_hot=(0, 0, 1),
    )
    return features, positions


def _is_present(raw):
    """Return True when a raw JSON column holds something worth decoding."""
    return pd.notna(raw) and raw not in ("null", "")


def _parse_ball(raw, features, positions):
    """Fill the ball slot from the raw ``balls`` column (best effort)."""
    if not _is_present(raw):
        return
    try:
        ball_list = json.loads(raw)
        if not ball_list:
            return
        ball = ball_list[0]
        x, y = ball.get("x"), ball.get("y")
        if x is None or y is None:
            return
        # A JSON null height is treated like a missing key (0). Previously
        # float(None) raised and the whole ball was discarded.
        z = ball.get("z")
        values = [
            float(x), float(y),
            1, 0, 0,  # one-hot: ball
            0, 0,     # dx, dy (filled later by compute_deltas)
            0.0 if z is None else float(z),
        ]
    except _DECODE_ERRORS:
        return
    features[_BALL_SLOT] = values
    positions[_BALL_SLOT] = "BALL"


def _parse_team(raw, features, positions, start, one_hot):
    """Fill up to 11 consecutive player slots beginning at ``start``."""
    if not _is_present(raw):
        return
    try:
        players = sorted(
            json.loads(raw),
            key=lambda p: int(p.get("jerseyNum", 0)),
        )[:11]
        slot = start
        for player in players:
            x, y = player.get("x"), player.get("y")
            if x is None or y is None:
                # skipped players do not consume a slot
                continue
            features[slot] = [
                float(x), float(y),
                *one_hot,       # entity-type one-hot for this team
                0, 0,           # dx, dy
                MISSING_VALUE,  # z unused for players
            ]
            positions[slot] = player.get("positionGroup", "")
            slot += 1
    except _DECODE_ERRORS:
        # Best effort: keep any players written before the failure.
        pass
162
+
163
+
164
+ # -------------------------------------------------------------------
165
+ # temporal feature computation
166
+ # -------------------------------------------------------------------
167
+
168
def compute_deltas(all_features):
    """Compute per-object velocity deltas across consecutive frames.

    For every object observed (x != MISSING_VALUE) in both frame t and
    frame t - 1, the displacement (x_t - x_{t-1}, y_t - y_{t-1}) is
    written into feature indices 5 and 6 of frame t. Every other delta
    entry is left untouched: frame 0, and objects missing in either frame.

    Args:
        all_features: numpy.ndarray of shape
            (num_frames, num_objects, FEATURE_DIM). num_objects is normally
            NUM_OBJECTS, but any object count is accepted.

    Returns:
        The same array, modified in place, with velocity deltas populated.
    """
    # objects observed in both the current and the previous frame
    observed = all_features[:, :, 0] != MISSING_VALUE
    both = observed[1:] & observed[:-1]

    # (x, y) displacement between consecutive frames. It is computed before
    # any write, and only reads the position channels, which are never
    # modified here.
    step = all_features[1:, :, 0:2] - all_features[:-1, :, 0:2]

    # Basic slicing returns a view, so the masked assignment writes into
    # the caller's array just like the former per-element loop did.
    deltas = all_features[1:, :, 5:7]
    deltas[both] = step[both]
    return all_features
195
+
196
+
197
def normalize_features(features):
    """Normalize spatial features to roughly [-1, 1].

    For observed objects (x != MISSING_VALUE), the following channels are
    divided by the module bounds:

    - x by PITCH_HALF_LENGTH and y by PITCH_HALF_WIDTH;
    - dx and dy by MAX_DISPLACEMENT;
    - z by MAX_BALL_HEIGHT.

    Every entry that held the raw MISSING_VALUE sentinel is mapped to
    MISSING_VALUE_NORMALIZED. This covers all channels of unobserved
    slots and the z channel of players. The model therefore receives a
    distinct, out-of-range value that cannot be confused with a valid
    normalized value. The x, y, dx, dy and z channels of unobserved slots
    are always set to the sentinel.

    Args:
        features: numpy.ndarray of shape
            (num_frames, NUM_OBJECTS, FEATURE_DIM).

    Returns:
        A new array with the same shape containing normalized values.
        The input is not modified.
    """
    features_norm = features.copy()
    # element-wise raw sentinel mask, captured before any division
    raw_missing = features_norm == MISSING_VALUE
    valid_mask = ~raw_missing[:, :, 0]

    features_norm[valid_mask, 0] /= PITCH_HALF_LENGTH
    features_norm[valid_mask, 1] /= PITCH_HALF_WIDTH
    features_norm[valid_mask, 5] /= MAX_DISPLACEMENT
    features_norm[valid_mask, 6] /= MAX_DISPLACEMENT
    features_norm[valid_mask, 7] /= MAX_BALL_HEIGHT

    # Spatial channels of missing objects always receive the sentinel, so
    # downstream layers can distinguish "absent" from "at the origin".
    for ch in (0, 1, 5, 6, 7):
        features_norm[~valid_mask, ch] = MISSING_VALUE_NORMALIZED

    # Any remaining raw sentinel previously leaked through unmapped:
    # player z became MISSING_VALUE / MAX_BALL_HEIGHT, and the one-hot
    # flags of missing slots stayed at MISSING_VALUE.
    features_norm[raw_missing] = MISSING_VALUE_NORMALIZED

    return features_norm
229
+
230
+
231
+ # -------------------------------------------------------------------
232
+ # edge building strategies
233
+ # -------------------------------------------------------------------
234
+
235
def build_edge_index(node_features, node_positions, edge_type, k=8, r=15.0):
    """Build a COO edge index for the nodes of a single frame.

    Available connectivity strategies:

    - "none": no edges.
    - "full": all observed pairs.
    - "knn": k nearest observed neighbours, symmetrized.
    - "distance": pairs within r metres.
    - "ball_knn": ball to its k nearest players, plus same-team links.
    - "ball_distance": ball to players within r, plus same-team links.
    - "positional": formation-line links, plus ball links.

    Args:
        node_features: numpy.ndarray of shape (NUM_OBJECTS, FEATURE_DIM).
        node_positions: position-group strings, one per node.
        edge_type: name of the strategy, one of the keys listed above.
        k: neighbour count used by "knn" and "ball_knn".
        r: distance threshold (metres) used by "distance" and
            "ball_distance".

    Returns:
        numpy.ndarray of shape (2, num_edges) and dtype int64, compatible
        with PyTorch Geometric.

    Raises:
        ValueError: if edge_type is not one of the supported names.
    """
    if edge_type == "none":
        return np.zeros((2, 0), dtype=np.int64)

    # Checked in order with ``==``. Builders are wrapped in lambdas so only
    # the selected strategy actually runs.
    strategies = (
        ("full", lambda: _build_full_edges(node_features)),
        ("knn", lambda: _build_knn_edges(node_features, k)),
        ("distance", lambda: _build_distance_edges(node_features, threshold=r)),
        ("ball_knn", lambda: _build_ball_knn_edges(node_features, k)),
        (
            "ball_distance",
            lambda: _build_ball_distance_edges(node_features, threshold=r),
        ),
        (
            "positional",
            lambda: _build_positional_edges(node_features, node_positions),
        ),
    )
    for name, build in strategies:
        if edge_type == name:
            return build()

    raise ValueError(f"Unknown edge_type: {edge_type}")
279
+
280
+
281
+ # -- strategy implementations (private) ----------------------------
282
+
283
def _build_full_edges(node_features):
    """Connect every ordered pair of distinct observed nodes.

    Nodes whose x channel equals MISSING_VALUE are excluded, and no
    self-loops are produced. Edges are ordered by source index, then by
    target index.
    """
    observed = np.flatnonzero(node_features[:, 0] != MISSING_VALUE)
    count = observed.size
    if count < 2:
        return np.zeros((2, 0), dtype=np.int64)

    # Cartesian product in row-major order: (o0, o0), (o0, o1), ...
    sources = np.repeat(observed, count)
    targets = np.tile(observed, count)
    distinct = sources != targets
    return np.stack([sources[distinct], targets[distinct]]).astype(np.int64)
298
+
299
+
300
def _build_knn_edges(node_features, k):
    """Link each observed node with its k closest observed nodes.

    Distance is the Euclidean norm over the (x, y) channels. Ties keep the
    lower node index first. Each selected pair is stored in both
    directions, and the result holds unique columns sorted by source, then
    target.
    """
    observed = [
        idx for idx in range(node_features.shape[0])
        if node_features[idx, 0] != MISSING_VALUE
    ]

    links = set()
    for src in observed:
        candidates = [
            (
                dst,
                np.linalg.norm(node_features[src, :2] - node_features[dst, :2]),
            )
            for dst in observed
            if dst != src
        ]
        # stable sort by distance keeps ascending-index order on ties
        candidates.sort(key=lambda pair: pair[1])
        for dst, _ in candidates[: min(k, len(candidates))]:
            links.add((src, dst))
            links.add((dst, src))

    if not links:
        return np.zeros((2, 0), dtype=np.int64)
    # sorted tuples == lexicographic column order produced by np.unique(axis=1)
    return np.array(sorted(links), dtype=np.int64).T
331
+
332
+
333
def _build_distance_edges(node_features, threshold=15.0):
    """Link every pair of observed nodes at most threshold apart.

    Distance is the Euclidean norm over the (x, y) channels. Each
    qualifying pair (a, b) with a < b contributes edges a->b and b->a, in
    that order.
    """
    observed = [
        idx for idx in range(node_features.shape[0])
        if node_features[idx, 0] != MISSING_VALUE
    ]

    edges = []
    for rank, a in enumerate(observed):
        for b in observed[rank + 1:]:
            gap = np.linalg.norm(node_features[a, :2] - node_features[b, :2])
            if gap <= threshold:
                edges += [[a, b], [b, a]]

    if not edges:
        return np.zeros((2, 0), dtype=np.int64)
    return np.array(edges, dtype=np.int64).T
354
+
355
+
356
def _build_positional_edges(node_features, node_positions):
    """Tactical-structure edges following formation lines.

    Observed nodes with a non-empty position label are grouped per team
    (home/away one-hot flag) and per line. Within each team:

    - players of adjacent lines are linked (GK-DEF, DEF-MID, MID-FWD);
    - players within the DEF, MID and FWD lines are linked to each other.

    The node labelled "BALL" is linked to every other observed node. All
    links are bidirectional. The result holds unique columns sorted by
    source, then target.
    """
    home_lines, away_lines = {}, {}
    ball_idx = None

    for node in range(node_features.shape[0]):
        if node_features[node, 0] == MISSING_VALUE or not node_positions[node]:
            continue
        label = node_positions[node]
        if label == "BALL":
            ball_idx = node
        elif node_features[node, 3] == 1.0:  # home flag
            home_lines.setdefault(label, []).append(node)
        elif node_features[node, 4] == 1.0:  # away flag
            away_lines.setdefault(label, []).append(node)

    links = set()

    def link(a, b):
        links.add((a, b))
        links.add((b, a))

    for lines in (home_lines, away_lines):
        # adjacent formation lines
        for back, front in (("GK", "DEF"), ("DEF", "MID"), ("MID", "FWD")):
            for a in lines.get(back, []):
                for b in lines.get(front, []):
                    link(a, b)
        # players sharing a line (goalkeepers are not linked to each other)
        for name in ("DEF", "MID", "FWD"):
            members = lines.get(name, [])
            for rank, a in enumerate(members):
                for b in members[rank + 1:]:
                    link(a, b)

    if ball_idx is not None:
        for node in range(node_features.shape[0]):
            if node != ball_idx and node_features[node, 0] != MISSING_VALUE:
                link(ball_idx, node)

    if not links:
        return np.zeros((2, 0), dtype=np.int64)
    # sorted tuples == lexicographic column order produced by np.unique(axis=1)
    return np.array(sorted(links), dtype=np.int64).T
425
+
426
+
427
def _build_ball_knn_edges(node_features, k):
    """Link the ball to its k nearest players, plus same-team links.

    The ball is the first node whose is_ball channel equals 1.0. An empty
    index is returned if there is no such node or its position is missing.
    Among the selected players, pairs whose team one-hot vectors have a
    positive dot product are also linked. Ties in distance keep the lower
    node index first.
    """
    ball_candidates = np.flatnonzero(node_features[:, 2] == 1.0)
    if ball_candidates.size == 0:
        return np.zeros((2, 0), dtype=np.int64)

    ball_idx = ball_candidates[0]
    ball_xy = node_features[ball_idx, :2]
    if (ball_xy == MISSING_VALUE).any():
        return np.zeros((2, 0), dtype=np.int64)

    players = [
        idx for idx in range(node_features.shape[0])
        if idx != ball_idx
        and node_features[idx, 0] != MISSING_VALUE
        and (node_features[idx, 3] == 1.0 or node_features[idx, 4] == 1.0)
    ]
    # stable sort: equal distances keep ascending index order
    by_distance = sorted(
        players,
        key=lambda idx: np.linalg.norm(node_features[idx, :2] - ball_xy),
    )
    nearest = by_distance[: min(k, len(by_distance))]

    edges = []
    for player in nearest:
        edges += [[ball_idx, player], [player, ball_idx]]

    for rank, a in enumerate(nearest):
        team_a = node_features[a, 3:5]
        for b in nearest[rank + 1:]:
            if np.dot(team_a, node_features[b, 3:5]) > 0:
                edges += [[a, b], [b, a]]

    if not edges:
        return np.zeros((2, 0), dtype=np.int64)
    return np.array(edges, dtype=np.int64).T
476
+
477
+
478
def _build_ball_distance_edges(node_features, threshold=20.0):
    """Link the ball to every player within threshold, plus same-team links.

    The ball is the first node whose is_ball channel equals 1.0. An empty
    index is returned if there is no such node or its position is missing.
    Players are visited in index order. Among the nearby players, pairs
    whose team one-hot vectors have a positive dot product are also
    linked.
    """
    ball_candidates = np.flatnonzero(node_features[:, 2] == 1.0)
    if ball_candidates.size == 0:
        return np.zeros((2, 0), dtype=np.int64)

    ball_idx = ball_candidates[0]
    ball_xy = node_features[ball_idx, :2]
    if (ball_xy == MISSING_VALUE).any():
        return np.zeros((2, 0), dtype=np.int64)

    edges = []
    nearby = []
    for idx in range(node_features.shape[0]):
        if idx == ball_idx or node_features[idx, 0] == MISSING_VALUE:
            continue
        if not (node_features[idx, 3] == 1.0 or node_features[idx, 4] == 1.0):
            continue
        gap = np.linalg.norm(node_features[idx, :2] - ball_xy)
        # written as "not <=" so a NaN distance is skipped, as before
        if not gap <= threshold:
            continue
        edges += [[ball_idx, idx], [idx, ball_idx]]
        nearby.append(idx)

    for rank, a in enumerate(nearby):
        team_a = node_features[a, 3:5]
        for b in nearby[rank + 1:]:
            if np.dot(team_a, node_features[b, 3:5]) > 0:
                edges += [[a, b], [b, a]]

    if not edges:
        return np.zeros((2, 0), dtype=np.int64)
    return np.array(edges, dtype=np.int64).T
521
+
522
+
523
+ # -------------------------------------------------------------------
524
+ # augmentations
525
+ # -------------------------------------------------------------------
526
+
527
class HorizontalFlip:
    """Randomly mirror features along the pitch length axis.

    When the flip fires, the x coordinate (channel 0) and dx delta
    (channel 5) of every observed object (x != MISSING_VALUE) are
    negated. Unobserved slots are left as they are.

    Args:
        probability: chance of applying the flip per call.
    """

    def __init__(self, probability=0.5):
        self.probability = probability

    def __call__(self, features):
        """Possibly flip a clip of features.

        Args:
            features: numpy.ndarray of shape
                (num_frames, NUM_OBJECTS, FEATURE_DIM).

        Returns:
            A flipped copy when the flip fires, otherwise ``features``
            itself. The input array is never modified.
        """
        if not random.random() < self.probability:
            return features
        mirrored = features.copy()
        observed = mirrored[:, :, 0] != MISSING_VALUE
        for channel in (0, 5):  # x, dx
            mirrored[observed, channel] *= -1
        return mirrored
555
+
556
+
557
class VerticalFlip:
    """Randomly mirror features along the pitch width axis.

    When the flip fires, the y coordinate (channel 1) and dy delta
    (channel 6) of every observed object (x != MISSING_VALUE) are
    negated. Unobserved slots are left as they are.

    Args:
        probability: chance of applying the flip per call.
    """

    def __init__(self, probability=0.5):
        self.probability = probability

    def __call__(self, features):
        """Possibly flip a clip of features.

        Args:
            features: numpy.ndarray of shape
                (num_frames, NUM_OBJECTS, FEATURE_DIM).

        Returns:
            A flipped copy when the flip fires, otherwise ``features``
            itself. The input array is never modified.
        """
        if not random.random() < self.probability:
            return features
        mirrored = features.copy()
        observed = mirrored[:, :, 0] != MISSING_VALUE
        for channel in (1, 6):  # y, dy
            mirrored[observed, channel] *= -1
        return mirrored
584
+
585
+
586
class TeamFlip:
    """Randomly exchange the home and away one-hot flags.

    When the swap fires, channels 3 (is_home) and 4 (is_away) are
    exchanged for every object in every frame. Swapping team identity is
    treated as label-preserving for the group activity class.

    Args:
        probability: chance of applying the swap per call.
    """

    def __init__(self, probability=0.5):
        self.probability = probability

    def __call__(self, features):
        """Possibly swap team flags in a clip of features.

        Args:
            features: numpy.ndarray of shape
                (num_frames, NUM_OBJECTS, FEATURE_DIM).

        Returns:
            A swapped copy when the swap fires, otherwise ``features``
            itself. The input array is never modified.
        """
        if not random.random() < self.probability:
            return features
        swapped = features.copy()
        # Fancy indexing on the right-hand side produces a copy, so both
        # columns are exchanged without aliasing.
        swapped[:, :, [3, 4]] = swapped[:, :, [4, 3]]
        return swapped