rustat-python-api 0.7.7__tar.gz → 0.7.9__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {rustat-python-api-0.7.7/rustat_python_api.egg-info → rustat-python-api-0.7.9}/PKG-INFO +1 -1
- {rustat-python-api-0.7.7 → rustat-python-api-0.7.9}/rustat_python_api/pitch_control.py +229 -139
- {rustat-python-api-0.7.7 → rustat-python-api-0.7.9/rustat_python_api.egg-info}/PKG-INFO +1 -1
- {rustat-python-api-0.7.7 → rustat-python-api-0.7.9}/rustat_python_api.egg-info/requires.txt +0 -1
- rustat-python-api-0.7.9/setup.py +32 -0
- rustat-python-api-0.7.7/setup.py +0 -38
- {rustat-python-api-0.7.7 → rustat-python-api-0.7.9}/LICENSE +0 -0
- {rustat-python-api-0.7.7 → rustat-python-api-0.7.9}/README.md +0 -0
- {rustat-python-api-0.7.7 → rustat-python-api-0.7.9}/pyproject.toml +0 -0
- {rustat-python-api-0.7.7 → rustat-python-api-0.7.9}/rustat_python_api/__init__.py +0 -0
- {rustat-python-api-0.7.7 → rustat-python-api-0.7.9}/rustat_python_api/config.py +0 -0
- {rustat-python-api-0.7.7 → rustat-python-api-0.7.9}/rustat_python_api/kernels/__init__.py +0 -0
- {rustat-python-api-0.7.7 → rustat-python-api-0.7.9}/rustat_python_api/kernels/maha.py +0 -0
- {rustat-python-api-0.7.7 → rustat-python-api-0.7.9}/rustat_python_api/matching/__init__.py +0 -0
- {rustat-python-api-0.7.7 → rustat-python-api-0.7.9}/rustat_python_api/matching/dataloader.py +0 -0
- {rustat-python-api-0.7.7 → rustat-python-api-0.7.9}/rustat_python_api/matching/pc_adder.py +0 -0
- {rustat-python-api-0.7.7 → rustat-python-api-0.7.9}/rustat_python_api/matching/tr_adder.py +0 -0
- {rustat-python-api-0.7.7 → rustat-python-api-0.7.9}/rustat_python_api/models_api.py +0 -0
- {rustat-python-api-0.7.7 → rustat-python-api-0.7.9}/rustat_python_api/parser.py +0 -0
- {rustat-python-api-0.7.7 → rustat-python-api-0.7.9}/rustat_python_api/processing.py +0 -0
- {rustat-python-api-0.7.7 → rustat-python-api-0.7.9}/rustat_python_api/urls.py +0 -0
- {rustat-python-api-0.7.7 → rustat-python-api-0.7.9}/rustat_python_api.egg-info/SOURCES.txt +0 -0
- {rustat-python-api-0.7.7 → rustat-python-api-0.7.9}/rustat_python_api.egg-info/dependency_links.txt +0 -0
- {rustat-python-api-0.7.7 → rustat-python-api-0.7.9}/rustat_python_api.egg-info/top_level.txt +0 -0
- {rustat-python-api-0.7.7 → rustat-python-api-0.7.9}/setup.cfg +0 -0
|
@@ -12,52 +12,71 @@ try:
|
|
|
12
12
|
except ImportError:
|
|
13
13
|
triton_influence = None
|
|
14
14
|
|
|
15
|
+
|
|
15
16
|
class PitchControl:
|
|
16
|
-
def __init__(
|
|
17
|
-
self
|
|
18
|
-
|
|
19
|
-
|
|
17
|
+
def __init__(
|
|
18
|
+
self,
|
|
19
|
+
tracking: pd.DataFrame,
|
|
20
|
+
events: pd.DataFrame,
|
|
21
|
+
ball_data: pd.DataFrame = None,
|
|
22
|
+
):
|
|
23
|
+
self.team_ids = tracking["team_id"].unique()
|
|
24
|
+
sides = tracking.groupby("team_id")["side_1h"].unique()
|
|
25
|
+
side_by_team = dict(
|
|
26
|
+
zip(self.team_ids, sides[self.team_ids].apply(lambda x: x[0]))
|
|
27
|
+
)
|
|
20
28
|
self.side_by_half = {
|
|
21
29
|
1: side_by_team,
|
|
22
|
-
2:
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
}
|
|
30
|
+
2: {
|
|
31
|
+
team: "left" if side == "right" else "right"
|
|
32
|
+
for team, side in side_by_team.items()
|
|
33
|
+
},
|
|
27
34
|
}
|
|
28
35
|
|
|
29
|
-
self._grid_cache: dict[
|
|
36
|
+
self._grid_cache: dict[
|
|
37
|
+
tuple[int, str, torch.dtype],
|
|
38
|
+
tuple[torch.Tensor, torch.Tensor, torch.Tensor],
|
|
39
|
+
] = {}
|
|
30
40
|
|
|
31
41
|
self.locs_home, self.locs_away, self.locs_ball, self.t = self.get_locs(
|
|
32
|
-
tracking,
|
|
33
|
-
events,
|
|
34
|
-
ball_data
|
|
42
|
+
tracking, events, ball_data
|
|
35
43
|
)
|
|
36
44
|
|
|
37
|
-
def get_locs(
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
45
|
+
def get_locs(
|
|
46
|
+
self,
|
|
47
|
+
tracking: pd.DataFrame,
|
|
48
|
+
events: pd.DataFrame,
|
|
49
|
+
ball_data: pd.DataFrame | None,
|
|
50
|
+
) -> tuple:
|
|
51
|
+
events = events[
|
|
52
|
+
[
|
|
53
|
+
"possession_number",
|
|
54
|
+
"team_id",
|
|
55
|
+
"possession_team_id",
|
|
56
|
+
"half",
|
|
57
|
+
"second",
|
|
58
|
+
"pos_x",
|
|
59
|
+
"pos_y",
|
|
60
|
+
]
|
|
61
|
+
]
|
|
42
62
|
|
|
43
63
|
events = self.swap_coords_batch(events)
|
|
44
64
|
|
|
45
65
|
if ball_data is None:
|
|
46
66
|
ball_data = self.interpolate_ball_data(
|
|
47
|
-
events[[
|
|
48
|
-
tracking
|
|
67
|
+
events[["half", "second", "pos_x", "pos_y"]], tracking
|
|
49
68
|
)
|
|
50
69
|
|
|
51
70
|
locs_home, locs_away = self.build_player_locs(tracking)
|
|
52
71
|
|
|
53
72
|
locs_ball = {
|
|
54
|
-
half: ball_data[ball_data[
|
|
55
|
-
for half in tracking[
|
|
73
|
+
half: ball_data[ball_data["half"] == half][["pos_x", "pos_y"]].values
|
|
74
|
+
for half in tracking["half"].unique()
|
|
56
75
|
}
|
|
57
76
|
|
|
58
77
|
t = {
|
|
59
|
-
half: ball_data[ball_data[
|
|
60
|
-
for half in tracking[
|
|
78
|
+
half: ball_data[ball_data["half"] == half]["second"].values
|
|
79
|
+
for half in tracking["half"].unique()
|
|
61
80
|
}
|
|
62
81
|
|
|
63
82
|
return locs_home, locs_away, locs_ball, t
|
|
@@ -96,20 +115,22 @@ class PitchControl:
|
|
|
96
115
|
side_by_half = self.side_by_half # local alias for speed
|
|
97
116
|
|
|
98
117
|
def needs_swap(row):
|
|
99
|
-
team_id = row[
|
|
100
|
-
poss = row[
|
|
101
|
-
half = row[
|
|
118
|
+
team_id = row["team_id"]
|
|
119
|
+
poss = row["possession_team_id"]
|
|
120
|
+
half = row["half"]
|
|
102
121
|
|
|
103
|
-
current_left =
|
|
104
|
-
|
|
122
|
+
current_left = (
|
|
123
|
+
team_id in poss if isinstance(poss, list) else team_id == poss
|
|
124
|
+
)
|
|
125
|
+
current_side = "left" if current_left else "right"
|
|
105
126
|
real_side = side_by_half[half][str(int(team_id))]
|
|
106
127
|
return current_side != real_side
|
|
107
128
|
|
|
108
129
|
mask = events.apply(needs_swap, axis=1)
|
|
109
130
|
|
|
110
131
|
# flip coords in bulk
|
|
111
|
-
events.loc[mask,
|
|
112
|
-
events.loc[mask,
|
|
132
|
+
events.loc[mask, "pos_x"] = 105 - events.loc[mask, "pos_x"]
|
|
133
|
+
events.loc[mask, "pos_y"] = 68 - events.loc[mask, "pos_y"]
|
|
113
134
|
return events
|
|
114
135
|
|
|
115
136
|
def build_player_locs(self, tracking: pd.DataFrame):
|
|
@@ -123,36 +144,35 @@ class PitchControl:
|
|
|
123
144
|
|
|
124
145
|
# Work per half to keep order and avoid extra boolean checks.
|
|
125
146
|
for half in (1, 2):
|
|
126
|
-
half_df = tracking[tracking[
|
|
127
|
-
for side, locs_out in [(
|
|
128
|
-
side_df = half_df[half_df[
|
|
129
|
-
for pid, grp in side_df.groupby(
|
|
130
|
-
locs_out[half][pid] = grp[[
|
|
147
|
+
half_df = tracking[tracking["half"] == half]
|
|
148
|
+
for side, locs_out in [("left", locs_home), ("right", locs_away)]:
|
|
149
|
+
side_df = half_df[half_df["side_1h"] == side]
|
|
150
|
+
for pid, grp in side_df.groupby("player_id"):
|
|
151
|
+
locs_out[half][pid] = grp[["pos_x", "pos_y"]].values
|
|
131
152
|
return locs_home, locs_away
|
|
132
153
|
|
|
133
154
|
@staticmethod
|
|
134
155
|
def interpolate_ball_data(
|
|
135
|
-
ball_data: pd.DataFrame,
|
|
136
|
-
player_data: pd.DataFrame
|
|
156
|
+
ball_data: pd.DataFrame, player_data: pd.DataFrame
|
|
137
157
|
) -> pd.DataFrame:
|
|
138
|
-
ball_data = ball_data.drop_duplicates(subset=[
|
|
158
|
+
ball_data = ball_data.drop_duplicates(subset=["second", "half"])
|
|
139
159
|
|
|
140
160
|
interpolated_data = []
|
|
141
|
-
for half in ball_data[
|
|
142
|
-
ball_half = ball_data[ball_data[
|
|
143
|
-
player_half = player_data[player_data[
|
|
161
|
+
for half in ball_data["half"].unique():
|
|
162
|
+
ball_half = ball_data[ball_data["half"] == half]
|
|
163
|
+
player_half = player_data[player_data["half"] == half]
|
|
144
164
|
|
|
145
|
-
player_times = player_half[
|
|
165
|
+
player_times = player_half["second"].unique()
|
|
146
166
|
|
|
147
|
-
ball_half = ball_half.sort_values(by=
|
|
148
|
-
interpolated_half = pd.DataFrame({
|
|
149
|
-
interpolated_half[
|
|
150
|
-
interpolated_half[
|
|
167
|
+
ball_half = ball_half.sort_values(by="second")
|
|
168
|
+
interpolated_half = pd.DataFrame({"second": player_times})
|
|
169
|
+
interpolated_half["pos_x"] = np.interp(
|
|
170
|
+
interpolated_half["second"], ball_half["second"], ball_half["pos_x"]
|
|
151
171
|
)
|
|
152
|
-
interpolated_half[
|
|
153
|
-
interpolated_half[
|
|
172
|
+
interpolated_half["pos_y"] = np.interp(
|
|
173
|
+
interpolated_half["second"], ball_half["second"], ball_half["pos_y"]
|
|
154
174
|
)
|
|
155
|
-
interpolated_half[
|
|
175
|
+
interpolated_half["half"] = half
|
|
156
176
|
interpolated_data.append(interpolated_half)
|
|
157
177
|
|
|
158
178
|
interpolated_ball_data = pd.concat(interpolated_data, ignore_index=True)
|
|
@@ -160,16 +180,15 @@ class PitchControl:
|
|
|
160
180
|
|
|
161
181
|
@staticmethod
|
|
162
182
|
def get_player_data(player_id, half, tracking):
|
|
163
|
-
timestamps = tracking[tracking[
|
|
183
|
+
timestamps = tracking[tracking["half"] == half]["second"].unique()
|
|
164
184
|
player_data = tracking[
|
|
165
|
-
(tracking[
|
|
166
|
-
|
|
167
|
-
][['second', 'pos_x', 'pos_y']]
|
|
185
|
+
(tracking["player_id"] == player_id) & (tracking["half"] == half)
|
|
186
|
+
][["second", "pos_x", "pos_y"]]
|
|
168
187
|
|
|
169
|
-
player_data_full = pd.DataFrame({
|
|
170
|
-
player_data_full = player_data_full.merge(player_data, on=
|
|
188
|
+
player_data_full = pd.DataFrame({"second": timestamps})
|
|
189
|
+
player_data_full = player_data_full.merge(player_data, on="second", how="left")
|
|
171
190
|
|
|
172
|
-
return player_data_full[[
|
|
191
|
+
return player_data_full[["pos_x", "pos_y"]].values
|
|
173
192
|
|
|
174
193
|
def influence_np(
|
|
175
194
|
self,
|
|
@@ -178,11 +197,11 @@ class PitchControl:
|
|
|
178
197
|
time_index: int,
|
|
179
198
|
home_or_away: str,
|
|
180
199
|
half: int,
|
|
181
|
-
verbose: bool = False
|
|
200
|
+
verbose: bool = False,
|
|
182
201
|
):
|
|
183
|
-
if home_or_away ==
|
|
202
|
+
if home_or_away == "h":
|
|
184
203
|
data = self.locs_home[half].copy()
|
|
185
|
-
elif home_or_away ==
|
|
204
|
+
elif home_or_away == "a":
|
|
186
205
|
data = self.locs_away[half].copy()
|
|
187
206
|
else:
|
|
188
207
|
raise ValueError("Enter either 'h' or 'a'.")
|
|
@@ -190,38 +209,44 @@ class PitchControl:
|
|
|
190
209
|
locs_ball = self.locs_ball[half].copy()
|
|
191
210
|
t = self.t[half].copy()
|
|
192
211
|
|
|
193
|
-
if (
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
jitter = 1e-10 ## to prevent identically zero covariance matrices when velocity is zero
|
|
212
|
+
if np.all(
|
|
213
|
+
np.isfinite(data[player_index][[time_index, time_index + 1], :])
|
|
214
|
+
) & np.all(np.isfinite(locs_ball[time_index, :])):
|
|
215
|
+
jitter = 1e-10 ## to prevent identically zero covariance matrices when velocity is zero
|
|
198
216
|
## compute velocity by fwd difference
|
|
199
|
-
s = (
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
/ (t[time_index + 1] - t[time_index])
|
|
205
|
-
)
|
|
217
|
+
s = np.linalg.norm(
|
|
218
|
+
data[player_index][time_index + 1, :]
|
|
219
|
+
- data[player_index][time_index, :]
|
|
220
|
+
+ jitter
|
|
221
|
+
) / (t[time_index + 1] - t[time_index])
|
|
206
222
|
## velocities in x,y directions
|
|
207
223
|
sxy = (
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
224
|
+
data[player_index][time_index + 1, :]
|
|
225
|
+
- data[player_index][time_index, :]
|
|
226
|
+
+ jitter
|
|
227
|
+
) / (t[time_index + 1] - t[time_index])
|
|
211
228
|
## angle between velocity vector & x-axis
|
|
212
229
|
theta = np.arccos(sxy[0] / np.linalg.norm(sxy))
|
|
213
230
|
## rotation matrix
|
|
214
|
-
R = np.array(
|
|
231
|
+
R = np.array(
|
|
232
|
+
[[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]
|
|
233
|
+
)
|
|
215
234
|
mu = data[player_index][time_index, :] + sxy * 0.5
|
|
216
235
|
Srat = (s / 13) ** 2
|
|
217
|
-
Ri = np.linalg.norm(
|
|
236
|
+
Ri = np.linalg.norm(
|
|
237
|
+
locs_ball[time_index, :] - data[player_index][time_index, :]
|
|
238
|
+
)
|
|
218
239
|
## don't think this function is specified in the paper but looks close enough to fig 9
|
|
219
|
-
Ri = np.minimum(4 + Ri
|
|
240
|
+
Ri = np.minimum(4 + Ri**3 / (18**3 / 6), 10)
|
|
220
241
|
S = np.array([[(1 + Srat) * Ri / 2, 0], [0, (1 - Srat) * Ri / 2]])
|
|
221
242
|
Sigma = np.matmul(R, S)
|
|
222
243
|
Sigma = np.matmul(Sigma, S)
|
|
223
|
-
Sigma = np.matmul(
|
|
224
|
-
|
|
244
|
+
Sigma = np.matmul(
|
|
245
|
+
Sigma, np.linalg.inv(R)
|
|
246
|
+
) ## this is not efficient, forgive me.
|
|
247
|
+
out = mvn.pdf(location, mu, Sigma) / mvn.pdf(
|
|
248
|
+
data[player_index][time_index, :], mu, Sigma
|
|
249
|
+
)
|
|
225
250
|
else:
|
|
226
251
|
if verbose:
|
|
227
252
|
print("Data is not finite.")
|
|
@@ -270,8 +295,12 @@ class PitchControl:
|
|
|
270
295
|
if not pos_t_list:
|
|
271
296
|
return torch.zeros(locs.shape[0], device=device, dtype=dtype)
|
|
272
297
|
|
|
273
|
-
pos_t = torch.tensor(
|
|
274
|
-
|
|
298
|
+
pos_t = torch.tensor(
|
|
299
|
+
np.asarray(pos_t_list), device=device, dtype=dtype
|
|
300
|
+
) # (P,2)
|
|
301
|
+
pos_tp1 = torch.tensor(
|
|
302
|
+
np.asarray(pos_tp1_list), device=device, dtype=dtype
|
|
303
|
+
) # (P,2)
|
|
275
304
|
|
|
276
305
|
# Velocity, speed, rotation -------------------------
|
|
277
306
|
dt_sec = float(self.t[half][time_index + 1] - self.t[half][time_index])
|
|
@@ -280,13 +309,15 @@ class PitchControl:
|
|
|
280
309
|
speed = torch.linalg.norm(sxy, dim=1) # (P,)
|
|
281
310
|
norm_sxy = speed.clamp(min=1e-6)
|
|
282
311
|
|
|
283
|
-
theta = torch.acos(
|
|
312
|
+
theta = torch.acos(
|
|
313
|
+
torch.clamp(sxy[:, 0] / norm_sxy, -1 + 1e-6, 1 - 1e-6)
|
|
314
|
+
) # (P,)
|
|
284
315
|
cos_t, sin_t = torch.cos(theta), torch.sin(theta)
|
|
285
316
|
|
|
286
317
|
R = torch.stack(
|
|
287
318
|
[
|
|
288
319
|
torch.stack([cos_t, -sin_t], dim=1),
|
|
289
|
-
torch.stack([sin_t,
|
|
320
|
+
torch.stack([sin_t, cos_t], dim=1),
|
|
290
321
|
],
|
|
291
322
|
dim=1,
|
|
292
323
|
) # (P,2,2)
|
|
@@ -298,7 +329,9 @@ class PitchControl:
|
|
|
298
329
|
self.locs_ball[half][time_index], device=device, dtype=dtype
|
|
299
330
|
) # (2,)
|
|
300
331
|
Ri = torch.linalg.norm(ball_pos - pos_t, dim=1) # (P,)
|
|
301
|
-
Ri = torch.minimum(
|
|
332
|
+
Ri = torch.minimum(
|
|
333
|
+
4 + Ri**3 / (18**3 / 6), torch.tensor(10.0, device=device, dtype=dtype)
|
|
334
|
+
)
|
|
302
335
|
|
|
303
336
|
S11 = (1 + Srat) * Ri / 2
|
|
304
337
|
S22 = (1 - Srat) * Ri / 2
|
|
@@ -322,19 +355,18 @@ class PitchControl:
|
|
|
322
355
|
|
|
323
356
|
# Grid diff & Mahalanobis ----------------------------
|
|
324
357
|
diff = locs.view(1, -1, 2) # (1,N,2)
|
|
325
|
-
diff = diff - mu.unsqueeze(1)
|
|
358
|
+
diff = diff - mu.unsqueeze(1) # (P,N,2)
|
|
326
359
|
|
|
327
360
|
device = torch.device(device)
|
|
328
361
|
|
|
329
|
-
if device.type ==
|
|
362
|
+
if device.type == "cuda":
|
|
330
363
|
out = triton_influence(
|
|
331
|
-
mu.unsqueeze(0), Sigma_inv.unsqueeze(0),
|
|
332
|
-
locs, BLOCK_N=64
|
|
364
|
+
mu.unsqueeze(0), Sigma_inv.unsqueeze(0), locs, BLOCK_N=64
|
|
333
365
|
)[0] # (N,)
|
|
334
366
|
|
|
335
367
|
return out
|
|
336
368
|
else:
|
|
337
|
-
maha = torch.einsum(
|
|
369
|
+
maha = torch.einsum("pni,pij,pnj->pn", diff, Sigma_inv, diff) # (P,N)
|
|
338
370
|
maha = torch.nan_to_num(maha, nan=1e9, posinf=1e9, neginf=1e9)
|
|
339
371
|
out = torch.exp(-0.5 * maha) # (P,N)
|
|
340
372
|
|
|
@@ -342,19 +374,19 @@ class PitchControl:
|
|
|
342
374
|
|
|
343
375
|
def _batch_team_influence_frames_pt(
|
|
344
376
|
self,
|
|
345
|
-
pos_t: torch.Tensor,
|
|
346
|
-
pos_tp1: torch.Tensor,
|
|
347
|
-
ball_pos: torch.Tensor,
|
|
348
|
-
dt_secs: torch.Tensor,
|
|
349
|
-
locs: torch.Tensor,
|
|
377
|
+
pos_t: torch.Tensor, # (F,P,2)
|
|
378
|
+
pos_tp1: torch.Tensor, # (F,P,2)
|
|
379
|
+
ball_pos: torch.Tensor, # (F,2)
|
|
380
|
+
dt_secs: torch.Tensor, # (F,)
|
|
381
|
+
locs: torch.Tensor, # (N,2)
|
|
350
382
|
dtype: torch.dtype,
|
|
351
|
-
) -> torch.Tensor:
|
|
383
|
+
) -> torch.Tensor: # returns (F,N)
|
|
352
384
|
"""Vectorised influence for many frames & players of ОДНОЙ команды."""
|
|
353
385
|
|
|
354
386
|
device = locs.device
|
|
355
387
|
|
|
356
388
|
sxy = (pos_tp1 - pos_t) / dt_secs[:, None, None] # (F,P,2)
|
|
357
|
-
speed = torch.linalg.norm(sxy, dim=-1)
|
|
389
|
+
speed = torch.linalg.norm(sxy, dim=-1) # (F,P)
|
|
358
390
|
norm_sxy = speed.clamp(min=1e-6)
|
|
359
391
|
|
|
360
392
|
theta = torch.acos(torch.clamp(sxy[..., 0] / norm_sxy, -1 + 1e-6, 1 - 1e-6))
|
|
@@ -363,7 +395,7 @@ class PitchControl:
|
|
|
363
395
|
R = torch.stack(
|
|
364
396
|
[
|
|
365
397
|
torch.stack([cos_t, -sin_t], dim=-1),
|
|
366
|
-
torch.stack([sin_t,
|
|
398
|
+
torch.stack([sin_t, cos_t], dim=-1),
|
|
367
399
|
],
|
|
368
400
|
dim=-2,
|
|
369
401
|
) # (F,P,2,2)
|
|
@@ -371,12 +403,16 @@ class PitchControl:
|
|
|
371
403
|
Srat = (speed / 13) ** 2 # (F,P)
|
|
372
404
|
|
|
373
405
|
Ri = torch.linalg.norm(ball_pos[:, None, :] - pos_t, dim=-1) # (F,P)
|
|
374
|
-
Ri = torch.minimum(
|
|
406
|
+
Ri = torch.minimum(
|
|
407
|
+
4 + Ri**3 / (18**3 / 6), torch.tensor(10.0, device=device, dtype=dtype)
|
|
408
|
+
)
|
|
375
409
|
|
|
376
410
|
S11 = (1 + Srat) * Ri / 2
|
|
377
411
|
S22 = (1 - Srat) * Ri / 2
|
|
378
412
|
|
|
379
|
-
S = torch.zeros(
|
|
413
|
+
S = torch.zeros(
|
|
414
|
+
(*pos_t.shape[:-1], 2, 2), device=device, dtype=dtype
|
|
415
|
+
) # (F,P,2,2)
|
|
380
416
|
S[..., 0, 0] = S11
|
|
381
417
|
S[..., 1, 1] = S22
|
|
382
418
|
|
|
@@ -393,21 +429,23 @@ class PitchControl:
|
|
|
393
429
|
mu = pos_t + 0.5 * sxy # (F,P,2)
|
|
394
430
|
|
|
395
431
|
diff = locs.view(1, 1, -1, 2) # (1,1,N,2)
|
|
396
|
-
diff = diff - mu.unsqueeze(2)
|
|
432
|
+
diff = diff - mu.unsqueeze(2) # (F,P,N,2)
|
|
397
433
|
|
|
398
|
-
if device.type ==
|
|
434
|
+
if device.type == "cuda":
|
|
399
435
|
out = triton_influence(mu, Sigma_inv, locs, BLOCK_N=64) # (F,N)
|
|
400
436
|
|
|
401
437
|
return out
|
|
402
438
|
else:
|
|
403
|
-
maha = torch.einsum(
|
|
439
|
+
maha = torch.einsum("fpni,fpij,fpnj->fpn", diff, Sigma_inv, diff) # (F,P,N)
|
|
404
440
|
maha = torch.nan_to_num(maha, nan=1e9, posinf=1e9, neginf=1e9)
|
|
405
441
|
out = torch.exp(-0.5 * maha) # (F,P,N)
|
|
406
442
|
|
|
407
443
|
return out.sum(dim=1) # sum over players
|
|
408
444
|
|
|
409
445
|
@staticmethod
|
|
410
|
-
def _stack_team_frames(
|
|
446
|
+
def _stack_team_frames(
|
|
447
|
+
players: list[np.ndarray], frames: np.ndarray, device: str, dtype: torch.dtype
|
|
448
|
+
):
|
|
411
449
|
"""Stack positions for given frames into torch tensors (pos_t, pos_tp1)."""
|
|
412
450
|
# Ensure every player's trajectory is long enough; if not, pad by repeating
|
|
413
451
|
# the last available coordinate so that indexing `frames` and `frames+1` is safe.
|
|
@@ -426,7 +464,9 @@ class PitchControl:
|
|
|
426
464
|
np.stack([p[frames] for p in padded], axis=1), device=device, dtype=dtype
|
|
427
465
|
) # (F,P,2)
|
|
428
466
|
pos_tp1 = torch.tensor(
|
|
429
|
-
np.stack([p[frames + 1] for p in padded], axis=1),
|
|
467
|
+
np.stack([p[frames + 1] for p in padded], axis=1),
|
|
468
|
+
device=device,
|
|
469
|
+
dtype=dtype,
|
|
430
470
|
)
|
|
431
471
|
return pos_t, pos_tp1
|
|
432
472
|
|
|
@@ -438,6 +478,7 @@ class PitchControl:
|
|
|
438
478
|
batch_size: int,
|
|
439
479
|
use_fp16: bool,
|
|
440
480
|
verbose: bool,
|
|
481
|
+
return_teams: str = "both",
|
|
441
482
|
):
|
|
442
483
|
"""Internal helper with fully batched PyTorch implementation."""
|
|
443
484
|
|
|
@@ -467,8 +508,12 @@ class PitchControl:
|
|
|
467
508
|
)
|
|
468
509
|
|
|
469
510
|
# stack teams
|
|
470
|
-
pos_t_h, pos_tp1_h = self._stack_team_frames(
|
|
471
|
-
|
|
511
|
+
pos_t_h, pos_tp1_h = self._stack_team_frames(
|
|
512
|
+
home_players, frames, device, dtype
|
|
513
|
+
)
|
|
514
|
+
pos_t_a, pos_tp1_a = self._stack_team_frames(
|
|
515
|
+
away_players, frames, device, dtype
|
|
516
|
+
)
|
|
472
517
|
|
|
473
518
|
Zh = self._batch_team_influence_frames_pt(
|
|
474
519
|
pos_t_h, pos_tp1_h, ball_pos, dt_secs, locs_t, dtype
|
|
@@ -477,15 +522,30 @@ class PitchControl:
|
|
|
477
522
|
pos_t_a, pos_tp1_a, ball_pos, dt_secs, locs_t, dtype
|
|
478
523
|
)
|
|
479
524
|
|
|
480
|
-
|
|
525
|
+
# Return raw team influence or combined logistic transformation
|
|
526
|
+
if return_teams == "home":
|
|
527
|
+
res_batch = torch.clamp(Zh, 0, 1).reshape(
|
|
528
|
+
-1, dt, dt
|
|
529
|
+
) # home team influence only
|
|
530
|
+
elif return_teams == "away":
|
|
531
|
+
res_batch = torch.clamp(Za, 0, 1).reshape(
|
|
532
|
+
-1, dt, dt
|
|
533
|
+
) # away team influence only
|
|
534
|
+
else: # "both"
|
|
535
|
+
res_batch = torch.sigmoid(Za - Zh).reshape(
|
|
536
|
+
-1, dt, dt
|
|
537
|
+
) # logistic transformation
|
|
538
|
+
|
|
481
539
|
pc_all[start:end] = res_batch.cpu().numpy().astype(np.float32)
|
|
482
540
|
|
|
483
541
|
if verbose:
|
|
484
|
-
print(f"pt full: frames {start}-{end-1} done")
|
|
542
|
+
print(f"pt full: frames {start}-{end - 1} done")
|
|
485
543
|
|
|
486
544
|
return pc_all, xx, yy
|
|
487
545
|
|
|
488
|
-
def _get_grid(
|
|
546
|
+
def _get_grid(
|
|
547
|
+
self, dt: int, device: str, dtype: torch.dtype
|
|
548
|
+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
489
549
|
"""Helper to create a grid of locations for torch-based pitch control."""
|
|
490
550
|
x = torch.linspace(0, 105, dt, device=device, dtype=dtype)
|
|
491
551
|
y = torch.linspace(0, 68, dt, device=device, dtype=dtype)
|
|
@@ -501,6 +561,7 @@ class PitchControl:
|
|
|
501
561
|
device: str = "cpu",
|
|
502
562
|
verbose: bool = False,
|
|
503
563
|
use_fp16: bool = True,
|
|
564
|
+
return_teams: str = "both",
|
|
504
565
|
):
|
|
505
566
|
"""Torch-accelerated pitch-control calculation.
|
|
506
567
|
|
|
@@ -533,31 +594,51 @@ class PitchControl:
|
|
|
533
594
|
self.locs_away[half], locs_t, tp, half, device, dtype
|
|
534
595
|
)
|
|
535
596
|
|
|
536
|
-
|
|
597
|
+
# Return raw team influence or combined logistic transformation
|
|
598
|
+
if return_teams == "home":
|
|
599
|
+
res_t = torch.clamp(Zh, 0, 1).reshape(dt, dt) # home team influence only
|
|
600
|
+
elif return_teams == "away":
|
|
601
|
+
res_t = torch.clamp(Za, 0, 1).reshape(dt, dt) # away team influence only
|
|
602
|
+
else: # "both"
|
|
603
|
+
res_t = torch.sigmoid(Za - Zh).reshape(dt, dt) # logistic transformation
|
|
537
604
|
|
|
538
605
|
# Convert to numpy for downstream plotting
|
|
539
606
|
return res_t.cpu().numpy(), xx_t.cpu().numpy(), yy_t.cpu().numpy()
|
|
540
607
|
|
|
541
|
-
def _fit_np(
|
|
608
|
+
def _fit_np(
|
|
609
|
+
self,
|
|
610
|
+
half: int,
|
|
611
|
+
tp: int,
|
|
612
|
+
dt: int,
|
|
613
|
+
verbose: bool = False,
|
|
614
|
+
return_teams: str = "both",
|
|
615
|
+
) -> tuple:
|
|
542
616
|
x = np.linspace(0, 105, dt)
|
|
543
617
|
y = np.linspace(0, 68, dt)
|
|
544
618
|
xx, yy = np.meshgrid(x, y)
|
|
545
619
|
|
|
546
|
-
Zh = np.zeros(dt*dt)
|
|
547
|
-
Za = np.zeros(dt*dt)
|
|
620
|
+
Zh = np.zeros(dt * dt)
|
|
621
|
+
Za = np.zeros(dt * dt)
|
|
548
622
|
|
|
549
|
-
locations = np.c_[xx.flatten(),yy.flatten()]
|
|
623
|
+
locations = np.c_[xx.flatten(), yy.flatten()]
|
|
550
624
|
|
|
551
625
|
for k in self.locs_home[half].keys():
|
|
552
626
|
# if len(self.locs_home[half][k]) >= tp:
|
|
553
|
-
Zh += self.influence_np(k, locations, tp,
|
|
627
|
+
Zh += self.influence_np(k, locations, tp, "h", half, verbose)
|
|
554
628
|
for k in self.locs_away[half].keys():
|
|
555
629
|
# if len(self.locs_away[half][k]) >= tp:
|
|
556
|
-
Za += self.influence_np(k, locations, tp,
|
|
630
|
+
Za += self.influence_np(k, locations, tp, "a", half, verbose)
|
|
557
631
|
|
|
558
632
|
Zh = Zh.reshape((dt, dt))
|
|
559
633
|
Za = Za.reshape((dt, dt))
|
|
560
|
-
|
|
634
|
+
|
|
635
|
+
# Return raw team influence or combined logistic transformation
|
|
636
|
+
if return_teams == "home":
|
|
637
|
+
result = np.clip(Zh, 0, 1) # home team influence only
|
|
638
|
+
elif return_teams == "away":
|
|
639
|
+
result = np.clip(Za, 0, 1) # away team influence only
|
|
640
|
+
else: # "both"
|
|
641
|
+
result = 1 / (1 + np.exp(-Za + Zh)) # logistic transformation
|
|
561
642
|
|
|
562
643
|
return result, xx, yy
|
|
563
644
|
|
|
@@ -570,13 +651,22 @@ class PitchControl:
|
|
|
570
651
|
device: str = "cpu",
|
|
571
652
|
verbose: bool = False,
|
|
572
653
|
use_fp16: bool = True,
|
|
654
|
+
return_teams: str = "both",
|
|
573
655
|
):
|
|
574
656
|
"""Selects NumPy or PyTorch backend depending on `backend`."""
|
|
575
657
|
match backend:
|
|
576
658
|
case "np" | "numpy":
|
|
577
|
-
return self._fit_np(half, tp, dt, verbose)
|
|
659
|
+
return self._fit_np(half, tp, dt, verbose, return_teams)
|
|
578
660
|
case "torch" | "pt":
|
|
579
|
-
return self._fit_pt(
|
|
661
|
+
return self._fit_pt(
|
|
662
|
+
half,
|
|
663
|
+
tp,
|
|
664
|
+
dt,
|
|
665
|
+
device=device,
|
|
666
|
+
verbose=verbose,
|
|
667
|
+
use_fp16=use_fp16,
|
|
668
|
+
return_teams=return_teams,
|
|
669
|
+
)
|
|
580
670
|
case _:
|
|
581
671
|
raise ValueError(f"Unknown backend '{backend}'. Use 'np' or 'torch'.")
|
|
582
672
|
|
|
@@ -586,9 +676,10 @@ class PitchControl:
|
|
|
586
676
|
dt: int = 100,
|
|
587
677
|
backend: str = "np",
|
|
588
678
|
device: str = "cpu",
|
|
589
|
-
batch_size: int = 30*60,
|
|
679
|
+
batch_size: int = 30 * 60,
|
|
590
680
|
use_fp16: bool = True,
|
|
591
681
|
verbose: bool = False,
|
|
682
|
+
return_teams: str = "both",
|
|
592
683
|
):
|
|
593
684
|
"""Compute pitch-control map for *каждый* кадр тайма.
|
|
594
685
|
|
|
@@ -606,7 +697,9 @@ class PitchControl:
|
|
|
606
697
|
case "np" | "numpy":
|
|
607
698
|
pc_all = np.empty((T, dt, dt), dtype=np.float32)
|
|
608
699
|
for tp in tqdm(range(T)):
|
|
609
|
-
pc_map, xx, yy = self._fit_np(
|
|
700
|
+
pc_map, xx, yy = self._fit_np(
|
|
701
|
+
half, tp, dt, verbose=False, return_teams=return_teams
|
|
702
|
+
)
|
|
610
703
|
pc_all[tp] = pc_map.astype(np.float32)
|
|
611
704
|
if verbose and tp % 500 == 0:
|
|
612
705
|
print(f"np full-match: done {tp}/{T}")
|
|
@@ -614,7 +707,7 @@ class PitchControl:
|
|
|
614
707
|
|
|
615
708
|
case "torch" | "pt":
|
|
616
709
|
return self._fit_full_pt(
|
|
617
|
-
half, dt, device, batch_size, use_fp16, verbose
|
|
710
|
+
half, dt, device, batch_size, use_fp16, verbose, return_teams
|
|
618
711
|
)
|
|
619
712
|
case _:
|
|
620
713
|
raise ValueError("backend must be 'np' or 'pt'")
|
|
@@ -626,7 +719,7 @@ class PitchControl:
|
|
|
626
719
|
pitch_control: tuple = None,
|
|
627
720
|
save: bool = False,
|
|
628
721
|
dt: int = 200,
|
|
629
|
-
filename: str =
|
|
722
|
+
filename: str = "pitch_control",
|
|
630
723
|
):
|
|
631
724
|
if pitch_control is None:
|
|
632
725
|
pitch_control, xx, yy = self.fit(half, tp, dt)
|
|
@@ -645,7 +738,7 @@ class PitchControl:
|
|
|
645
738
|
plt.scatter(
|
|
646
739
|
self.locs_home[half][k][tp, 0],
|
|
647
740
|
self.locs_home[half][k][tp, 1],
|
|
648
|
-
color=
|
|
741
|
+
color="darkgrey",
|
|
649
742
|
)
|
|
650
743
|
|
|
651
744
|
for k in self.locs_away[half].keys():
|
|
@@ -653,17 +746,16 @@ class PitchControl:
|
|
|
653
746
|
if np.isfinite(self.locs_away[half][k][tp, :]).all():
|
|
654
747
|
plt.scatter(
|
|
655
748
|
self.locs_away[half][k][tp, 0],
|
|
656
|
-
self.locs_away[half][k][tp, 1],
|
|
749
|
+
self.locs_away[half][k][tp, 1],
|
|
750
|
+
color="black",
|
|
657
751
|
)
|
|
658
752
|
|
|
659
753
|
plt.scatter(
|
|
660
|
-
self.locs_ball[half][tp, 0],
|
|
661
|
-
self.locs_ball[half][tp, 1],
|
|
662
|
-
color='red'
|
|
754
|
+
self.locs_ball[half][tp, 0], self.locs_ball[half][tp, 1], color="red"
|
|
663
755
|
)
|
|
664
756
|
|
|
665
757
|
if save:
|
|
666
|
-
plt.savefig(f
|
|
758
|
+
plt.savefig(f"{filename}.png", dpi=300)
|
|
667
759
|
else:
|
|
668
760
|
plt.show()
|
|
669
761
|
|
|
@@ -674,7 +766,7 @@ class PitchControl:
|
|
|
674
766
|
filename: str = "pitch_control_animation",
|
|
675
767
|
dt: int = 200,
|
|
676
768
|
frames: int = 30,
|
|
677
|
-
interval: int = 1000
|
|
769
|
+
interval: int = 1000,
|
|
678
770
|
):
|
|
679
771
|
"""
|
|
680
772
|
ffmpeg should be installed on your machine.
|
|
@@ -686,7 +778,7 @@ class PitchControl:
|
|
|
686
778
|
pitch_control, xx, yy = self.fit(half, fr, dt)
|
|
687
779
|
|
|
688
780
|
mpl.field("white", show=False, ax=ax)
|
|
689
|
-
ax.axis(
|
|
781
|
+
ax.axis("off")
|
|
690
782
|
|
|
691
783
|
plt.contourf(xx, yy, pitch_control)
|
|
692
784
|
|
|
@@ -696,7 +788,7 @@ class PitchControl:
|
|
|
696
788
|
plt.scatter(
|
|
697
789
|
self.locs_home[half][k][fr, 0],
|
|
698
790
|
self.locs_home[half][k][fr, 1],
|
|
699
|
-
color=
|
|
791
|
+
color="darkgrey",
|
|
700
792
|
)
|
|
701
793
|
for k in self.locs_away[half].keys():
|
|
702
794
|
# if len(self.locs_away[half][k]) >= fr:
|
|
@@ -704,13 +796,11 @@ class PitchControl:
|
|
|
704
796
|
plt.scatter(
|
|
705
797
|
self.locs_away[half][k][fr, 0],
|
|
706
798
|
self.locs_away[half][k][fr, 1],
|
|
707
|
-
color=
|
|
799
|
+
color="black",
|
|
708
800
|
)
|
|
709
801
|
|
|
710
802
|
plt.scatter(
|
|
711
|
-
self.locs_ball[half][fr, 0],
|
|
712
|
-
self.locs_ball[half][fr, 1],
|
|
713
|
-
color='red'
|
|
803
|
+
self.locs_ball[half][fr, 0], self.locs_ball[half][fr, 1], color="red"
|
|
714
804
|
)
|
|
715
805
|
|
|
716
806
|
return ax
|
|
@@ -724,7 +814,7 @@ class PitchControl:
|
|
|
724
814
|
func=animate,
|
|
725
815
|
frames=min(frames, len(self.locs_ball[half]) - tp),
|
|
726
816
|
interval=interval,
|
|
727
|
-
blit=False
|
|
817
|
+
blit=False,
|
|
728
818
|
)
|
|
729
819
|
|
|
730
|
-
ani.save(f
|
|
820
|
+
ani.save(f"{filename}.mp4", writer="ffmpeg")
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
from setuptools import setup, find_packages
|
|
2
|
+
|
|
3
|
+
setup(
|
|
4
|
+
name="rustat-python-api",
|
|
5
|
+
version="0.7.9",
|
|
6
|
+
description="A Python wrapper for RuStat API",
|
|
7
|
+
long_description=open("README.md").read(),
|
|
8
|
+
long_description_content_type="text/markdown",
|
|
9
|
+
author="Daniel Zholkovsky",
|
|
10
|
+
author_email="daniel@zholkovsky.com",
|
|
11
|
+
url="https://github.com/dailydaniel/rustat-python-api",
|
|
12
|
+
license="MIT",
|
|
13
|
+
packages=find_packages(),
|
|
14
|
+
install_requires=[
|
|
15
|
+
"requests==2.32.3",
|
|
16
|
+
"pandas==2.2.3",
|
|
17
|
+
"tqdm==4.66.5",
|
|
18
|
+
"scipy==1.14.1",
|
|
19
|
+
"matplotlib",
|
|
20
|
+
"matplotsoccer",
|
|
21
|
+
],
|
|
22
|
+
extras_require={
|
|
23
|
+
"gpu": ["torch", "triton==3.0.0; platform_system=='Linux'"],
|
|
24
|
+
"cpu": ["torch"],
|
|
25
|
+
},
|
|
26
|
+
classifiers=[
|
|
27
|
+
"Programming Language :: Python :: 3",
|
|
28
|
+
"License :: OSI Approved :: MIT License",
|
|
29
|
+
"Operating System :: OS Independent",
|
|
30
|
+
],
|
|
31
|
+
python_requires=">=3.10",
|
|
32
|
+
)
|
rustat-python-api-0.7.7/setup.py
DELETED
|
@@ -1,38 +0,0 @@
|
|
|
1
|
-
from setuptools import setup, find_packages
|
|
2
|
-
|
|
3
|
-
setup(
|
|
4
|
-
name='rustat-python-api',
|
|
5
|
-
version='0.7.7',
|
|
6
|
-
description='A Python wrapper for RuStat API',
|
|
7
|
-
long_description=open('README.md').read(),
|
|
8
|
-
long_description_content_type='text/markdown',
|
|
9
|
-
author='Daniel Zholkovsky',
|
|
10
|
-
author_email='daniel@zholkovsky.com',
|
|
11
|
-
url='https://github.com/dailydaniel/rustat-python-api',
|
|
12
|
-
license='MIT',
|
|
13
|
-
packages=find_packages(),
|
|
14
|
-
install_requires=[
|
|
15
|
-
'requests==2.32.3',
|
|
16
|
-
'pandas==2.2.3',
|
|
17
|
-
'tqdm==4.66.5',
|
|
18
|
-
'scipy==1.14.1',
|
|
19
|
-
'matplotlib',
|
|
20
|
-
'matplotsoccer',
|
|
21
|
-
'torch',
|
|
22
|
-
],
|
|
23
|
-
extras_require={
|
|
24
|
-
"gpu": [
|
|
25
|
-
"torch",
|
|
26
|
-
"triton==3.0.0; platform_system=='Linux'",
|
|
27
|
-
],
|
|
28
|
-
"cpu": [
|
|
29
|
-
"torch"
|
|
30
|
-
]
|
|
31
|
-
},
|
|
32
|
-
classifiers=[
|
|
33
|
-
'Programming Language :: Python :: 3',
|
|
34
|
-
'License :: OSI Approved :: MIT License',
|
|
35
|
-
'Operating System :: OS Independent',
|
|
36
|
-
],
|
|
37
|
-
python_requires='>=3.10',
|
|
38
|
-
)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{rustat-python-api-0.7.7 → rustat-python-api-0.7.9}/rustat_python_api/matching/dataloader.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{rustat-python-api-0.7.7 → rustat-python-api-0.7.9}/rustat_python_api.egg-info/dependency_links.txt
RENAMED
|
File without changes
|
{rustat-python-api-0.7.7 → rustat-python-api-0.7.9}/rustat_python_api.egg-info/top_level.txt
RENAMED
|
File without changes
|
|
File without changes
|