rustat-python-api 0.5.8__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rustat_python_api/pitch_control.py +424 -48
- rustat_python_api/processing.py +17 -10
- {rustat_python_api-0.5.8.dist-info → rustat_python_api-0.6.0.dist-info}/METADATA +2 -1
- {rustat_python_api-0.5.8.dist-info → rustat_python_api-0.6.0.dist-info}/RECORD +7 -7
- {rustat_python_api-0.5.8.dist-info → rustat_python_api-0.6.0.dist-info}/LICENSE +0 -0
- {rustat_python_api-0.5.8.dist-info → rustat_python_api-0.6.0.dist-info}/WHEEL +0 -0
- {rustat_python_api-0.5.8.dist-info → rustat_python_api-0.6.0.dist-info}/top_level.txt +0 -0
|
@@ -4,6 +4,8 @@ from scipy.stats import multivariate_normal as mvn
|
|
|
4
4
|
import matplotlib.pyplot as plt
|
|
5
5
|
import matplotlib.animation as animation
|
|
6
6
|
import matplotsoccer as mpl
|
|
7
|
+
import torch
|
|
8
|
+
from tqdm import tqdm
|
|
7
9
|
|
|
8
10
|
|
|
9
11
|
class PitchControl:
|
|
@@ -20,6 +22,8 @@ class PitchControl:
|
|
|
20
22
|
}
|
|
21
23
|
}
|
|
22
24
|
|
|
25
|
+
self._grid_cache: dict[tuple[int, str, torch.dtype], tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = {}
|
|
26
|
+
|
|
23
27
|
self.locs_home, self.locs_away, self.locs_ball, self.t = self.get_locs(
|
|
24
28
|
tracking,
|
|
25
29
|
events,
|
|
@@ -32,12 +36,7 @@ class PitchControl:
|
|
|
32
36
|
'half', 'second', 'pos_x', 'pos_y'
|
|
33
37
|
]]
|
|
34
38
|
|
|
35
|
-
events
|
|
36
|
-
lambda x: self.swap_coords(x, 'x'), axis=1
|
|
37
|
-
)
|
|
38
|
-
events.loc[:, 'pos_y'] = events.apply(
|
|
39
|
-
lambda x: self.swap_coords(x, 'y'), axis=1
|
|
40
|
-
)
|
|
39
|
+
events = self.swap_coords_batch(events)
|
|
41
40
|
|
|
42
41
|
if ball_data is None:
|
|
43
42
|
ball_data = self.interpolate_ball_data(
|
|
@@ -45,21 +44,7 @@ class PitchControl:
|
|
|
45
44
|
tracking
|
|
46
45
|
)
|
|
47
46
|
|
|
48
|
-
locs_home =
|
|
49
|
-
half: {
|
|
50
|
-
player_id: self.get_player_data(player_id, half, tracking)
|
|
51
|
-
for player_id in tracking[tracking['side_1h'] == 'left']['player_id'].unique()
|
|
52
|
-
}
|
|
53
|
-
for half in tracking['half'].unique()
|
|
54
|
-
}
|
|
55
|
-
|
|
56
|
-
locs_away = {
|
|
57
|
-
half: {
|
|
58
|
-
player_id: self.get_player_data(player_id, half, tracking)
|
|
59
|
-
for player_id in tracking[tracking['side_1h'] == 'right']['player_id'].unique()
|
|
60
|
-
}
|
|
61
|
-
for half in tracking['half'].unique()
|
|
62
|
-
}
|
|
47
|
+
locs_home, locs_away = self.build_player_locs(tracking)
|
|
63
48
|
|
|
64
49
|
locs_ball = {
|
|
65
50
|
half: ball_data[ball_data['half'] == half][['pos_x', 'pos_y']].values
|
|
@@ -73,28 +58,73 @@ class PitchControl:
|
|
|
73
58
|
|
|
74
59
|
return locs_home, locs_away, locs_ball, t
|
|
75
60
|
|
|
61
|
+
# def swap_coords(self, row, how: str = 'x'):
|
|
62
|
+
# half = row['half']
|
|
63
|
+
# team_id = row['team_id']
|
|
64
|
+
# possession_team_id = row['possession_team_id']
|
|
65
|
+
# x = row['pos_x']
|
|
66
|
+
# y = row['pos_y']
|
|
67
|
+
|
|
68
|
+
# if isinstance(possession_team_id, list):
|
|
69
|
+
# current_side = 'left' if team_id in possession_team_id else 'right'
|
|
70
|
+
# real_side = self.side_by_half[half][str(int(team_id))]
|
|
71
|
+
# else:
|
|
72
|
+
# current_side = 'left' if team_id == possession_team_id else 'right'
|
|
73
|
+
# real_side = self.side_by_half[half][str(int(team_id))]
|
|
74
|
+
|
|
75
|
+
# if current_side != real_side:
|
|
76
|
+
# if how == 'x':
|
|
77
|
+
# x = 105 - x
|
|
78
|
+
# else:
|
|
79
|
+
# y = 68 - y
|
|
80
|
+
|
|
81
|
+
# return x if how == 'x' else y
|
|
82
|
+
|
|
83
|
+
def swap_coords_batch(self, events: pd.DataFrame) -> pd.DataFrame:
|
|
84
|
+
"""Vectorised replacement for per-row `swap_coords`.
|
|
85
|
+
|
|
86
|
+
Modifies *events* in-place: flips coordinates for rows where the
|
|
87
|
+
current attacking direction (left/right) does not match the
|
|
88
|
+
canonical side stored in ``self.side_by_half``.
|
|
89
|
+
Returns the same DataFrame for chaining.
|
|
90
|
+
"""
|
|
91
|
+
|
|
92
|
+
side_by_half = self.side_by_half # local alias for speed
|
|
76
93
|
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
x = row['pos_x']
|
|
82
|
-
y = row['pos_y']
|
|
94
|
+
def needs_swap(row):
|
|
95
|
+
team_id = row['team_id']
|
|
96
|
+
poss = row['possession_team_id']
|
|
97
|
+
half = row['half']
|
|
83
98
|
|
|
84
|
-
|
|
85
|
-
current_side = 'left' if
|
|
86
|
-
real_side =
|
|
87
|
-
|
|
88
|
-
current_side = 'left' if team_id == possession_team_id else 'right'
|
|
89
|
-
real_side = self.side_by_half[half][str(int(team_id))]
|
|
99
|
+
current_left = team_id in poss if isinstance(poss, list) else team_id == poss
|
|
100
|
+
current_side = 'left' if current_left else 'right'
|
|
101
|
+
real_side = side_by_half[half][str(int(team_id))]
|
|
102
|
+
return current_side != real_side
|
|
90
103
|
|
|
91
|
-
|
|
92
|
-
if how == 'x':
|
|
93
|
-
x = 105 - x
|
|
94
|
-
else:
|
|
95
|
-
y = 68 - y
|
|
104
|
+
mask = events.apply(needs_swap, axis=1)
|
|
96
105
|
|
|
97
|
-
|
|
106
|
+
# flip coords in bulk
|
|
107
|
+
events.loc[mask, 'pos_x'] = 105 - events.loc[mask, 'pos_x']
|
|
108
|
+
events.loc[mask, 'pos_y'] = 68 - events.loc[mask, 'pos_y']
|
|
109
|
+
return events
|
|
110
|
+
|
|
111
|
+
def build_player_locs(self, tracking: pd.DataFrame):
|
|
112
|
+
"""Vectorised construction of player location dictionaries.
|
|
113
|
+
|
|
114
|
+
Returns (locs_home, locs_away) where each is
|
|
115
|
+
{half: {player_id: np.ndarray(T,2)}}.
|
|
116
|
+
"""
|
|
117
|
+
locs_home = {1: {}, 2: {}}
|
|
118
|
+
locs_away = {1: {}, 2: {}}
|
|
119
|
+
|
|
120
|
+
# Work per half to keep order and avoid extra boolean checks.
|
|
121
|
+
for half in (1, 2):
|
|
122
|
+
half_df = tracking[tracking['half'] == half]
|
|
123
|
+
for side, locs_out in [('left', locs_home), ('right', locs_away)]:
|
|
124
|
+
side_df = half_df[half_df['side_1h'] == side]
|
|
125
|
+
for pid, grp in side_df.groupby('player_id'):
|
|
126
|
+
locs_out[half][pid] = grp[['pos_x', 'pos_y']].values
|
|
127
|
+
return locs_home, locs_away
|
|
98
128
|
|
|
99
129
|
@staticmethod
|
|
100
130
|
def interpolate_ball_data(
|
|
@@ -137,12 +167,7 @@ class PitchControl:
|
|
|
137
167
|
|
|
138
168
|
return player_data_full[['pos_x', 'pos_y']].values
|
|
139
169
|
|
|
140
|
-
|
|
141
|
-
# (tracking['player_id'] == player_id)
|
|
142
|
-
# & (tracking['half'] == half)
|
|
143
|
-
# ][['pos_x', 'pos_y']].values
|
|
144
|
-
|
|
145
|
-
def influence_function(
|
|
170
|
+
def influence_np(
|
|
146
171
|
self,
|
|
147
172
|
player_index: str,
|
|
148
173
|
location: np.ndarray,
|
|
@@ -199,7 +224,297 @@ class PitchControl:
|
|
|
199
224
|
out = np.zeros(location.shape[0])
|
|
200
225
|
return out
|
|
201
226
|
|
|
202
|
-
def
|
|
227
|
+
def _batch_influence_pt(
|
|
228
|
+
self,
|
|
229
|
+
player_dict: dict,
|
|
230
|
+
locs: torch.Tensor,
|
|
231
|
+
time_index: int,
|
|
232
|
+
half: int,
|
|
233
|
+
device: str,
|
|
234
|
+
dtype: torch.dtype,
|
|
235
|
+
) -> torch.Tensor:
|
|
236
|
+
"""Compute cumulative influence of *many* players at once.
|
|
237
|
+
|
|
238
|
+
Parameters
|
|
239
|
+
----------
|
|
240
|
+
player_dict : dict[player_id -> np.ndarray(shape=(T,2))]
|
|
241
|
+
Pre-loaded trajectory arrays for one team.
|
|
242
|
+
locs : torch.Tensor shape (N,2)
|
|
243
|
+
Grid locations (already on correct device / dtype).
|
|
244
|
+
time_index : int
|
|
245
|
+
Frame index t.
|
|
246
|
+
half : int
|
|
247
|
+
Half number.
|
|
248
|
+
device, dtype : torch configuration.
|
|
249
|
+
|
|
250
|
+
Returns
|
|
251
|
+
-------
|
|
252
|
+
torch.Tensor shape (N,)
|
|
253
|
+
Sum of influences from all valid players in *player_dict*.
|
|
254
|
+
"""
|
|
255
|
+
|
|
256
|
+
pos_t_list, pos_tp1_list = [], []
|
|
257
|
+
for arr in player_dict.values():
|
|
258
|
+
# Ensure we have t and t+1 and no NaNs at those rows
|
|
259
|
+
if (
|
|
260
|
+
time_index + 1 < arr.shape[0]
|
|
261
|
+
and np.isfinite(arr[[time_index, time_index + 1], :]).all()
|
|
262
|
+
):
|
|
263
|
+
pos_t_list.append(arr[time_index])
|
|
264
|
+
pos_tp1_list.append(arr[time_index + 1])
|
|
265
|
+
|
|
266
|
+
if not pos_t_list:
|
|
267
|
+
return torch.zeros(locs.shape[0], device=device, dtype=dtype)
|
|
268
|
+
|
|
269
|
+
pos_t = torch.tensor(np.asarray(pos_t_list), device=device, dtype=dtype) # (P,2)
|
|
270
|
+
pos_tp1 = torch.tensor(np.asarray(pos_tp1_list), device=device, dtype=dtype) # (P,2)
|
|
271
|
+
|
|
272
|
+
# Velocity, speed, rotation -------------------------
|
|
273
|
+
dt_sec = float(self.t[half][time_index + 1] - self.t[half][time_index])
|
|
274
|
+
sxy = (pos_tp1 - pos_t) / dt_sec # (P,2)
|
|
275
|
+
|
|
276
|
+
speed = torch.linalg.norm(sxy, dim=1) # (P,)
|
|
277
|
+
norm_sxy = speed.clamp(min=1e-6)
|
|
278
|
+
|
|
279
|
+
theta = torch.acos(torch.clamp(sxy[:, 0] / norm_sxy, -1 + 1e-6, 1 - 1e-6)) # (P,)
|
|
280
|
+
cos_t, sin_t = torch.cos(theta), torch.sin(theta)
|
|
281
|
+
|
|
282
|
+
R = torch.stack(
|
|
283
|
+
[
|
|
284
|
+
torch.stack([cos_t, -sin_t], dim=1),
|
|
285
|
+
torch.stack([sin_t, cos_t], dim=1),
|
|
286
|
+
],
|
|
287
|
+
dim=1,
|
|
288
|
+
) # (P,2,2)
|
|
289
|
+
|
|
290
|
+
# Shape parameters ----------------------------------
|
|
291
|
+
Srat = (speed / 13) ** 2 # (P,)
|
|
292
|
+
|
|
293
|
+
ball_pos = torch.tensor(
|
|
294
|
+
self.locs_ball[half][time_index], device=device, dtype=dtype
|
|
295
|
+
) # (2,)
|
|
296
|
+
Ri = torch.linalg.norm(ball_pos - pos_t, dim=1) # (P,)
|
|
297
|
+
Ri = torch.minimum(4 + Ri ** 3 / (18 ** 3 / 6), torch.tensor(10.0, device=device, dtype=dtype))
|
|
298
|
+
|
|
299
|
+
S11 = (1 + Srat) * Ri / 2
|
|
300
|
+
S22 = (1 - Srat) * Ri / 2
|
|
301
|
+
|
|
302
|
+
S = torch.zeros((pos_t.shape[0], 2, 2), device=device, dtype=dtype)
|
|
303
|
+
S[:, 0, 0] = S11
|
|
304
|
+
S[:, 1, 1] = S22
|
|
305
|
+
|
|
306
|
+
Sigma = R @ S @ S @ R.transpose(1, 2) # (P,2,2)
|
|
307
|
+
|
|
308
|
+
eye = torch.eye(2, device=device, dtype=dtype) * 1e-6
|
|
309
|
+
eye = eye.expand(pos_t.shape[0], 2, 2) # broadcast to (P,2,2)
|
|
310
|
+
|
|
311
|
+
if dtype == torch.float16:
|
|
312
|
+
Sigma_inv = torch.linalg.inv((Sigma + eye).float()).to(dtype)
|
|
313
|
+
else:
|
|
314
|
+
Sigma_inv = torch.linalg.inv(Sigma + eye)
|
|
315
|
+
|
|
316
|
+
# Mean ----------------------------------------------
|
|
317
|
+
mu = pos_t + 0.5 * sxy # (P,2)
|
|
318
|
+
|
|
319
|
+
# Grid diff & Mahalanobis ----------------------------
|
|
320
|
+
diff = locs.view(1, -1, 2) # (1,N,2)
|
|
321
|
+
diff = diff - mu.unsqueeze(1) # (P,N,2)
|
|
322
|
+
|
|
323
|
+
maha = torch.einsum('pni,pij,pnj->pn', diff, Sigma_inv, diff) # (P,N)
|
|
324
|
+
|
|
325
|
+
# Replace NaNs that arise from invalid player positions with large value
|
|
326
|
+
# so their influence tends to zero after exponent, then eliminate residual NaNs.
|
|
327
|
+
maha = torch.nan_to_num(maha, nan=1e9, posinf=1e9, neginf=1e9)
|
|
328
|
+
|
|
329
|
+
out = torch.exp(-0.5 * maha) # (P,N)
|
|
330
|
+
|
|
331
|
+
return out.sum(dim=0) # sum over players
|
|
332
|
+
|
|
333
|
+
def _batch_team_influence_frames_pt(
|
|
334
|
+
self,
|
|
335
|
+
pos_t: torch.Tensor, # (F,P,2)
|
|
336
|
+
pos_tp1: torch.Tensor, # (F,P,2)
|
|
337
|
+
ball_pos: torch.Tensor, # (F,2)
|
|
338
|
+
dt_secs: torch.Tensor, # (F,)
|
|
339
|
+
locs: torch.Tensor, # (N,2)
|
|
340
|
+
dtype: torch.dtype,
|
|
341
|
+
) -> torch.Tensor: # returns (F,N)
|
|
342
|
+
"""Vectorised influence for many frames & players of ОДНОЙ команды."""
|
|
343
|
+
|
|
344
|
+
device = locs.device
|
|
345
|
+
|
|
346
|
+
sxy = (pos_tp1 - pos_t) / dt_secs[:, None, None] # (F,P,2)
|
|
347
|
+
speed = torch.linalg.norm(sxy, dim=-1) # (F,P)
|
|
348
|
+
norm_sxy = speed.clamp(min=1e-6)
|
|
349
|
+
|
|
350
|
+
theta = torch.acos(torch.clamp(sxy[..., 0] / norm_sxy, -1 + 1e-6, 1 - 1e-6))
|
|
351
|
+
cos_t, sin_t = torch.cos(theta), torch.sin(theta) # (F,P)
|
|
352
|
+
|
|
353
|
+
R = torch.stack(
|
|
354
|
+
[
|
|
355
|
+
torch.stack([cos_t, -sin_t], dim=-1),
|
|
356
|
+
torch.stack([sin_t, cos_t], dim=-1),
|
|
357
|
+
],
|
|
358
|
+
dim=-2,
|
|
359
|
+
) # (F,P,2,2)
|
|
360
|
+
|
|
361
|
+
Srat = (speed / 13) ** 2 # (F,P)
|
|
362
|
+
|
|
363
|
+
Ri = torch.linalg.norm(ball_pos[:, None, :] - pos_t, dim=-1) # (F,P)
|
|
364
|
+
Ri = torch.minimum(4 + Ri ** 3 / (18 ** 3 / 6), torch.tensor(10.0, device=device, dtype=dtype))
|
|
365
|
+
|
|
366
|
+
S11 = (1 + Srat) * Ri / 2
|
|
367
|
+
S22 = (1 - Srat) * Ri / 2
|
|
368
|
+
|
|
369
|
+
S = torch.zeros((*pos_t.shape[:-1], 2, 2), device=device, dtype=dtype) # (F,P,2,2)
|
|
370
|
+
S[..., 0, 0] = S11
|
|
371
|
+
S[..., 1, 1] = S22
|
|
372
|
+
|
|
373
|
+
Sigma = R @ S @ S @ R.transpose(-1, -2) # (F,P,2,2)
|
|
374
|
+
|
|
375
|
+
eye = torch.eye(2, device=device, dtype=dtype) * 1e-6
|
|
376
|
+
eye = eye.view(1, 1, 2, 2)
|
|
377
|
+
|
|
378
|
+
if dtype == torch.float16:
|
|
379
|
+
Sigma_inv = torch.linalg.inv((Sigma + eye).float()).to(dtype)
|
|
380
|
+
else:
|
|
381
|
+
Sigma_inv = torch.linalg.inv(Sigma + eye)
|
|
382
|
+
|
|
383
|
+
mu = pos_t + 0.5 * sxy # (F,P,2)
|
|
384
|
+
|
|
385
|
+
diff = locs.view(1, 1, -1, 2) # (1,1,N,2)
|
|
386
|
+
diff = diff - mu.unsqueeze(2) # (F,P,N,2)
|
|
387
|
+
|
|
388
|
+
maha = torch.einsum('fpni,fpij,fpnj->fpn', diff, Sigma_inv, diff) # (F,P,N)
|
|
389
|
+
|
|
390
|
+
# Replace NaNs that arise from invalid player positions with large value
|
|
391
|
+
# so their influence tends to zero after exponent, then eliminate residual NaNs.
|
|
392
|
+
maha = torch.nan_to_num(maha, nan=1e9, posinf=1e9, neginf=1e9)
|
|
393
|
+
|
|
394
|
+
out = torch.exp(-0.5 * maha) # (F,P,N)
|
|
395
|
+
|
|
396
|
+
return out.sum(dim=1) # sum over players
|
|
397
|
+
|
|
398
|
+
@staticmethod
|
|
399
|
+
def _stack_team_frames(players: list[np.ndarray], frames: np.ndarray, device: str, dtype: torch.dtype):
|
|
400
|
+
"""Stack positions for given frames into torch tensors (pos_t, pos_tp1)."""
|
|
401
|
+
pos_t = torch.tensor(
|
|
402
|
+
np.stack([p[frames] for p in players], axis=1), device=device, dtype=dtype
|
|
403
|
+
) # (F,P,2)
|
|
404
|
+
pos_tp1 = torch.tensor(
|
|
405
|
+
np.stack([p[frames + 1] for p in players], axis=1), device=device, dtype=dtype
|
|
406
|
+
)
|
|
407
|
+
return pos_t, pos_tp1
|
|
408
|
+
|
|
409
|
+
def _fit_full_pt(
|
|
410
|
+
self,
|
|
411
|
+
half: int,
|
|
412
|
+
dt: int,
|
|
413
|
+
device: str,
|
|
414
|
+
batch_size: int,
|
|
415
|
+
use_fp16: bool,
|
|
416
|
+
verbose: bool,
|
|
417
|
+
):
|
|
418
|
+
"""Internal helper with fully batched PyTorch implementation."""
|
|
419
|
+
|
|
420
|
+
dtype = torch.float16 if use_fp16 else torch.float32
|
|
421
|
+
xx_t, yy_t, locs_t = self._get_grid(dt, device, dtype)
|
|
422
|
+
xx, yy = xx_t.cpu().numpy(), yy_t.cpu().numpy()
|
|
423
|
+
|
|
424
|
+
T = len(self.t[half]) - 1
|
|
425
|
+
pc_all = np.empty((T, dt, dt), dtype=np.float32)
|
|
426
|
+
|
|
427
|
+
home_players = list(self.locs_home[half].values())
|
|
428
|
+
away_players = list(self.locs_away[half].values())
|
|
429
|
+
|
|
430
|
+
for start in tqdm(range(0, T, batch_size)):
|
|
431
|
+
end = min(start + batch_size, T)
|
|
432
|
+
frames = np.arange(start, end)
|
|
433
|
+
|
|
434
|
+
# deltas t
|
|
435
|
+
dt_secs = torch.tensor(
|
|
436
|
+
self.t[half][frames + 1] - self.t[half][frames],
|
|
437
|
+
device=device,
|
|
438
|
+
dtype=dtype,
|
|
439
|
+
)
|
|
440
|
+
|
|
441
|
+
ball_pos = torch.tensor(
|
|
442
|
+
self.locs_ball[half][frames], device=device, dtype=dtype
|
|
443
|
+
)
|
|
444
|
+
|
|
445
|
+
# stack teams
|
|
446
|
+
pos_t_h, pos_tp1_h = self._stack_team_frames(home_players, frames, device, dtype)
|
|
447
|
+
pos_t_a, pos_tp1_a = self._stack_team_frames(away_players, frames, device, dtype)
|
|
448
|
+
|
|
449
|
+
Zh = self._batch_team_influence_frames_pt(
|
|
450
|
+
pos_t_h, pos_tp1_h, ball_pos, dt_secs, locs_t, dtype
|
|
451
|
+
)
|
|
452
|
+
Za = self._batch_team_influence_frames_pt(
|
|
453
|
+
pos_t_a, pos_tp1_a, ball_pos, dt_secs, locs_t, dtype
|
|
454
|
+
)
|
|
455
|
+
|
|
456
|
+
res_batch = torch.sigmoid(Za - Zh).reshape(-1, dt, dt)
|
|
457
|
+
pc_all[start:end] = res_batch.cpu().numpy().astype(np.float32)
|
|
458
|
+
|
|
459
|
+
if verbose:
|
|
460
|
+
print(f"pt full: frames {start}-{end-1} done")
|
|
461
|
+
|
|
462
|
+
return pc_all, xx, yy
|
|
463
|
+
|
|
464
|
+
def _get_grid(self, dt: int, device: str, dtype: torch.dtype) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
465
|
+
"""Helper to create a grid of locations for torch-based pitch control."""
|
|
466
|
+
x = torch.linspace(0, 105, dt, device=device, dtype=dtype)
|
|
467
|
+
y = torch.linspace(0, 68, dt, device=device, dtype=dtype)
|
|
468
|
+
xx_t, yy_t = torch.meshgrid(x, y, indexing="xy") # (dt,dt)
|
|
469
|
+
locs_t = torch.stack([xx_t, yy_t], dim=-1).reshape(-1, 2) # (N,2)
|
|
470
|
+
return xx_t, yy_t, locs_t
|
|
471
|
+
|
|
472
|
+
def _fit_pt(
|
|
473
|
+
self,
|
|
474
|
+
half: int,
|
|
475
|
+
tp: int,
|
|
476
|
+
dt: int = 200,
|
|
477
|
+
device: str = "cpu",
|
|
478
|
+
verbose: bool = False,
|
|
479
|
+
use_fp16: bool = True,
|
|
480
|
+
):
|
|
481
|
+
"""Torch-accelerated pitch-control calculation.
|
|
482
|
+
|
|
483
|
+
Returns
|
|
484
|
+
-------
|
|
485
|
+
result : np.ndarray shape (dt,dt)
|
|
486
|
+
xx, yy : np.ndarray meshgrids (dt,dt)
|
|
487
|
+
"""
|
|
488
|
+
if torch is None:
|
|
489
|
+
raise ImportError("PyTorch is required for backend='torch'.")
|
|
490
|
+
|
|
491
|
+
dtype = torch.float16 if (use_fp16 and device != "cpu") else torch.float32
|
|
492
|
+
|
|
493
|
+
# ---- grid caching ----
|
|
494
|
+
key = (dt, device, dtype)
|
|
495
|
+
if key not in self._grid_cache:
|
|
496
|
+
xx_t, yy_t, locs_t = self._get_grid(dt, device, dtype)
|
|
497
|
+
# Store; keep cache small (max 3 grids)
|
|
498
|
+
if len(self._grid_cache) >= 3:
|
|
499
|
+
self._grid_cache.pop(next(iter(self._grid_cache)))
|
|
500
|
+
self._grid_cache[key] = (xx_t, yy_t, locs_t)
|
|
501
|
+
else:
|
|
502
|
+
xx_t, yy_t, locs_t = self._grid_cache[key]
|
|
503
|
+
|
|
504
|
+
# --- vectorised influence computation ---
|
|
505
|
+
Zh = self._batch_influence_pt(
|
|
506
|
+
self.locs_home[half], locs_t, tp, half, device, dtype
|
|
507
|
+
)
|
|
508
|
+
Za = self._batch_influence_pt(
|
|
509
|
+
self.locs_away[half], locs_t, tp, half, device, dtype
|
|
510
|
+
)
|
|
511
|
+
|
|
512
|
+
res_t = torch.sigmoid(Za - Zh).reshape(dt, dt)
|
|
513
|
+
|
|
514
|
+
# Convert to numpy for downstream plotting
|
|
515
|
+
return res_t.cpu().numpy(), xx_t.cpu().numpy(), yy_t.cpu().numpy()
|
|
516
|
+
|
|
517
|
+
def _fit_np(self, half: int, tp: int, dt: int, verbose: bool = False) -> tuple:
|
|
203
518
|
x = np.linspace(0, 105, dt)
|
|
204
519
|
y = np.linspace(0, 68, dt)
|
|
205
520
|
xx, yy = np.meshgrid(x, y)
|
|
@@ -211,10 +526,10 @@ class PitchControl:
|
|
|
211
526
|
|
|
212
527
|
for k in self.locs_home[half].keys():
|
|
213
528
|
# if len(self.locs_home[half][k]) >= tp:
|
|
214
|
-
Zh += self.
|
|
529
|
+
Zh += self.influence_np(k, locations, tp, 'h', half, verbose)
|
|
215
530
|
for k in self.locs_away[half].keys():
|
|
216
531
|
# if len(self.locs_away[half][k]) >= tp:
|
|
217
|
-
Za += self.
|
|
532
|
+
Za += self.influence_np(k, locations, tp, 'a', half, verbose)
|
|
218
533
|
|
|
219
534
|
Zh = Zh.reshape((dt, dt))
|
|
220
535
|
Za = Za.reshape((dt, dt))
|
|
@@ -222,6 +537,64 @@ class PitchControl:
|
|
|
222
537
|
|
|
223
538
|
return result, xx, yy
|
|
224
539
|
|
|
540
|
+
def fit(
|
|
541
|
+
self,
|
|
542
|
+
half: int,
|
|
543
|
+
tp: int,
|
|
544
|
+
dt: int = 100,
|
|
545
|
+
backend: str = "np",
|
|
546
|
+
device: str = "cpu",
|
|
547
|
+
verbose: bool = False,
|
|
548
|
+
use_fp16: bool = True,
|
|
549
|
+
):
|
|
550
|
+
"""Selects NumPy or PyTorch backend depending on `backend`."""
|
|
551
|
+
match backend:
|
|
552
|
+
case "np" | "numpy":
|
|
553
|
+
return self._fit_np(half, tp, dt, verbose)
|
|
554
|
+
case "torch" | "pt":
|
|
555
|
+
return self._fit_pt(half, tp, dt, device=device, verbose=verbose, use_fp16=use_fp16)
|
|
556
|
+
case _:
|
|
557
|
+
raise ValueError(f"Unknown backend '{backend}'. Use 'np' or 'torch'.")
|
|
558
|
+
|
|
559
|
+
def fit_full(
|
|
560
|
+
self,
|
|
561
|
+
half: int,
|
|
562
|
+
dt: int = 100,
|
|
563
|
+
backend: str = "np",
|
|
564
|
+
device: str = "cpu",
|
|
565
|
+
batch_size: int = 30*60,
|
|
566
|
+
use_fp16: bool = True,
|
|
567
|
+
verbose: bool = False,
|
|
568
|
+
):
|
|
569
|
+
"""Compute pitch-control map for *каждый* кадр тайма.
|
|
570
|
+
|
|
571
|
+
Returns
|
|
572
|
+
-------
|
|
573
|
+
maps : np.ndarray, shape (T, dt, dt)
|
|
574
|
+
Pitch-control probability for home team at every frame.
|
|
575
|
+
xx, yy : np.ndarray, shape (dt, dt)
|
|
576
|
+
Coordinate grids (общие для всех кадров).
|
|
577
|
+
"""
|
|
578
|
+
|
|
579
|
+
T = len(self.t[half]) - 1 # мы используем t и t+1, поэтому последний кадр T-1
|
|
580
|
+
|
|
581
|
+
match backend:
|
|
582
|
+
case "np" | "numpy":
|
|
583
|
+
pc_all = np.empty((T, dt, dt), dtype=np.float32)
|
|
584
|
+
for tp in tqdm(range(T)):
|
|
585
|
+
pc_map, xx, yy = self._fit_np(half, tp, dt, verbose=False)
|
|
586
|
+
pc_all[tp] = pc_map.astype(np.float32)
|
|
587
|
+
if verbose and tp % 500 == 0:
|
|
588
|
+
print(f"np full-match: done {tp}/{T}")
|
|
589
|
+
return pc_all, xx, yy
|
|
590
|
+
|
|
591
|
+
case "torch" | "pt":
|
|
592
|
+
return self._fit_full_pt(
|
|
593
|
+
half, dt, device, batch_size, use_fp16, verbose
|
|
594
|
+
)
|
|
595
|
+
case _:
|
|
596
|
+
raise ValueError("backend must be 'np' or 'pt'")
|
|
597
|
+
|
|
225
598
|
def draw_pitch_control(
|
|
226
599
|
self,
|
|
227
600
|
half: int,
|
|
@@ -233,8 +606,11 @@ class PitchControl:
|
|
|
233
606
|
):
|
|
234
607
|
if pitch_control is None:
|
|
235
608
|
pitch_control, xx, yy = self.fit(half, tp, dt)
|
|
609
|
+
else:
|
|
610
|
+
pitch_control, xx, yy = pitch_control
|
|
236
611
|
|
|
237
612
|
fig, ax = plt.subplots(figsize=(10.5, 6.8))
|
|
613
|
+
# mpl.field(fieldcolor="white", linecolor="black", alpha=1, show=False, ax=ax)
|
|
238
614
|
mpl.field("white", show=False, ax=ax)
|
|
239
615
|
|
|
240
616
|
plt.contourf(xx, yy, pitch_control)
|
rustat_python_api/processing.py
CHANGED
|
@@ -13,20 +13,27 @@ def process_list(x: pd.Series):
|
|
|
13
13
|
else:
|
|
14
14
|
return lst
|
|
15
15
|
|
|
16
|
+
def take_last(x: pd.Series):
|
|
17
|
+
lst = x.dropna().tolist()
|
|
18
|
+
if lst:
|
|
19
|
+
return lst[-1]
|
|
20
|
+
return np.nan
|
|
21
|
+
|
|
16
22
|
|
|
17
23
|
def gluing(df: pd.DataFrame) -> pd.DataFrame:
|
|
18
24
|
cols = ['player_id', 'half', 'second', 'pos_x', 'pos_y']
|
|
19
25
|
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
)
|
|
26
|
+
agg_rules = {}
|
|
27
|
+
|
|
28
|
+
for col_name in df.columns:
|
|
29
|
+
if col_name not in cols:
|
|
30
|
+
if col_name in ['action_name', 'action_id']:
|
|
31
|
+
agg_rules[col_name] = process_list
|
|
32
|
+
else:
|
|
33
|
+
agg_rules[col_name] = take_last
|
|
34
|
+
|
|
35
|
+
df_gb = df.groupby(cols).agg(agg_rules).reset_index()
|
|
36
|
+
|
|
30
37
|
df_gb['pos_dest_nan'] = (df_gb['pos_dest_x'].isna() & df_gb['pos_dest_y'].isna()).astype(int)
|
|
31
38
|
df_gb = df_gb.sort_values(by=['half', 'second', 'possession_number', 'pos_dest_nan']).reset_index(drop=True)
|
|
32
39
|
return df_gb
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: rustat-python-api
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.6.0
|
|
4
4
|
Summary: A Python wrapper for RuStat API
|
|
5
5
|
Home-page: https://github.com/dailydaniel/rustat-python-api
|
|
6
6
|
Author: Daniel Zholkovsky
|
|
@@ -18,6 +18,7 @@ Requires-Dist: tqdm==4.66.5
|
|
|
18
18
|
Requires-Dist: scipy==1.14.1
|
|
19
19
|
Requires-Dist: matplotlib
|
|
20
20
|
Requires-Dist: matplotsoccer
|
|
21
|
+
Requires-Dist: torch
|
|
21
22
|
|
|
22
23
|
# rustat-python-api
|
|
23
24
|
|
|
@@ -2,11 +2,11 @@ rustat_python_api/__init__.py,sha256=Ij-PAm2y5ss_XAZhKTZus35cRPLzvXFyIswDa_Iq3rs
|
|
|
2
2
|
rustat_python_api/config.py,sha256=eMvi1p8Cfvnbp6Cd4bBOwgehVN7thKnaQV5uzWyGZXM,1844
|
|
3
3
|
rustat_python_api/models_api.py,sha256=oHXEqeCupvZwjVEdoxf7W9LP7ELFKA8-9DuRXpQHLno,1701
|
|
4
4
|
rustat_python_api/parser.py,sha256=PrGvN3vY0oczMsJxyUxBO0Yb1P03Moc74AgGYHYr_X8,10216
|
|
5
|
-
rustat_python_api/pitch_control.py,sha256=
|
|
6
|
-
rustat_python_api/processing.py,sha256=
|
|
5
|
+
rustat_python_api/pitch_control.py,sha256=3Sokn8yqVxO_XSMfIq9DkGBP6OkoxsYES56LxihaU-I,25598
|
|
6
|
+
rustat_python_api/processing.py,sha256=WES2D77uu3ScpjN0nW8ozatHteCZJQmmF9WpHPgGfJo,2835
|
|
7
7
|
rustat_python_api/urls.py,sha256=iJTD31T6OyXPAhmhViwFXVehrzwsOjBDONA1SIVc_40,1068
|
|
8
|
-
rustat_python_api-0.
|
|
9
|
-
rustat_python_api-0.
|
|
10
|
-
rustat_python_api-0.
|
|
11
|
-
rustat_python_api-0.
|
|
12
|
-
rustat_python_api-0.
|
|
8
|
+
rustat_python_api-0.6.0.dist-info/LICENSE,sha256=4Cohqg5p6Mq1xyrzdEX8AvFSA62GSVvapEOr2xK_tgY,57
|
|
9
|
+
rustat_python_api-0.6.0.dist-info/METADATA,sha256=AgCVwTnDBQZ9IU5D14rStxIP3S2S0MyBVwULZMndlgc,1751
|
|
10
|
+
rustat_python_api-0.6.0.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
|
|
11
|
+
rustat_python_api-0.6.0.dist-info/top_level.txt,sha256=VK0hmkKZE9YThxolUcoE6JtGI67NFeKJMBLuet8kI4w,18
|
|
12
|
+
rustat_python_api-0.6.0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|