rustat-python-api 0.7.8__tar.gz → 0.7.10__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (25)
  1. {rustat-python-api-0.7.8/rustat_python_api.egg-info → rustat-python-api-0.7.10}/PKG-INFO +1 -1
  2. {rustat-python-api-0.7.8 → rustat-python-api-0.7.10}/rustat_python_api/pitch_control.py +244 -148
  3. {rustat-python-api-0.7.8 → rustat-python-api-0.7.10/rustat_python_api.egg-info}/PKG-INFO +1 -1
  4. rustat-python-api-0.7.10/setup.py +32 -0
  5. rustat-python-api-0.7.8/setup.py +0 -37
  6. {rustat-python-api-0.7.8 → rustat-python-api-0.7.10}/LICENSE +0 -0
  7. {rustat-python-api-0.7.8 → rustat-python-api-0.7.10}/README.md +0 -0
  8. {rustat-python-api-0.7.8 → rustat-python-api-0.7.10}/pyproject.toml +0 -0
  9. {rustat-python-api-0.7.8 → rustat-python-api-0.7.10}/rustat_python_api/__init__.py +0 -0
  10. {rustat-python-api-0.7.8 → rustat-python-api-0.7.10}/rustat_python_api/config.py +0 -0
  11. {rustat-python-api-0.7.8 → rustat-python-api-0.7.10}/rustat_python_api/kernels/__init__.py +0 -0
  12. {rustat-python-api-0.7.8 → rustat-python-api-0.7.10}/rustat_python_api/kernels/maha.py +0 -0
  13. {rustat-python-api-0.7.8 → rustat-python-api-0.7.10}/rustat_python_api/matching/__init__.py +0 -0
  14. {rustat-python-api-0.7.8 → rustat-python-api-0.7.10}/rustat_python_api/matching/dataloader.py +0 -0
  15. {rustat-python-api-0.7.8 → rustat-python-api-0.7.10}/rustat_python_api/matching/pc_adder.py +0 -0
  16. {rustat-python-api-0.7.8 → rustat-python-api-0.7.10}/rustat_python_api/matching/tr_adder.py +0 -0
  17. {rustat-python-api-0.7.8 → rustat-python-api-0.7.10}/rustat_python_api/models_api.py +0 -0
  18. {rustat-python-api-0.7.8 → rustat-python-api-0.7.10}/rustat_python_api/parser.py +0 -0
  19. {rustat-python-api-0.7.8 → rustat-python-api-0.7.10}/rustat_python_api/processing.py +0 -0
  20. {rustat-python-api-0.7.8 → rustat-python-api-0.7.10}/rustat_python_api/urls.py +0 -0
  21. {rustat-python-api-0.7.8 → rustat-python-api-0.7.10}/rustat_python_api.egg-info/SOURCES.txt +0 -0
  22. {rustat-python-api-0.7.8 → rustat-python-api-0.7.10}/rustat_python_api.egg-info/dependency_links.txt +0 -0
  23. {rustat-python-api-0.7.8 → rustat-python-api-0.7.10}/rustat_python_api.egg-info/requires.txt +0 -0
  24. {rustat-python-api-0.7.8 → rustat-python-api-0.7.10}/rustat_python_api.egg-info/top_level.txt +0 -0
  25. {rustat-python-api-0.7.8 → rustat-python-api-0.7.10}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: rustat-python-api
3
- Version: 0.7.8
3
+ Version: 0.7.10
4
4
  Summary: A Python wrapper for RuStat API
5
5
  Home-page: https://github.com/dailydaniel/rustat-python-api
6
6
  Author: Daniel Zholkovsky
@@ -12,52 +12,71 @@ try:
12
12
  except ImportError:
13
13
  triton_influence = None
14
14
 
15
+
15
16
  class PitchControl:
16
- def __init__(self, tracking: pd.DataFrame, events: pd.DataFrame, ball_data: pd.DataFrame = None):
17
- self.team_ids = tracking['team_id'].unique()
18
- sides = tracking.groupby('team_id')['side_1h'].unique()
19
- side_by_team = dict(zip(self.team_ids, sides[self.team_ids].apply(lambda x: x[0])))
17
+ def __init__(
18
+ self,
19
+ tracking: pd.DataFrame,
20
+ events: pd.DataFrame,
21
+ ball_data: pd.DataFrame = None,
22
+ ):
23
+ self.team_ids = tracking["team_id"].unique()
24
+ sides = tracking.groupby("team_id")["side_1h"].unique()
25
+ side_by_team = dict(
26
+ zip(self.team_ids, sides[self.team_ids].apply(lambda x: x[0]))
27
+ )
20
28
  self.side_by_half = {
21
29
  1: side_by_team,
22
- 2:
23
- {
24
- team: 'left' if side == 'right' else 'right'
25
- for team, side in side_by_team.items()
26
- }
30
+ 2: {
31
+ team: "left" if side == "right" else "right"
32
+ for team, side in side_by_team.items()
33
+ },
27
34
  }
28
35
 
29
- self._grid_cache: dict[tuple[int, str, torch.dtype], tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = {}
36
+ self._grid_cache: dict[
37
+ tuple[int, str, torch.dtype],
38
+ tuple[torch.Tensor, torch.Tensor, torch.Tensor],
39
+ ] = {}
30
40
 
31
41
  self.locs_home, self.locs_away, self.locs_ball, self.t = self.get_locs(
32
- tracking,
33
- events,
34
- ball_data
42
+ tracking, events, ball_data
35
43
  )
36
44
 
37
- def get_locs(self, tracking: pd.DataFrame, events: pd.DataFrame, ball_data: pd.DataFrame | None) -> tuple:
38
- events = events[[
39
- 'possession_number', 'team_id', 'possession_team_id',
40
- 'half', 'second', 'pos_x', 'pos_y'
41
- ]]
45
+ def get_locs(
46
+ self,
47
+ tracking: pd.DataFrame,
48
+ events: pd.DataFrame,
49
+ ball_data: pd.DataFrame | None,
50
+ ) -> tuple:
51
+ events = events[
52
+ [
53
+ "possession_number",
54
+ "team_id",
55
+ "possession_team_id",
56
+ "half",
57
+ "second",
58
+ "pos_x",
59
+ "pos_y",
60
+ ]
61
+ ]
42
62
 
43
63
  events = self.swap_coords_batch(events)
44
64
 
45
65
  if ball_data is None:
46
66
  ball_data = self.interpolate_ball_data(
47
- events[['half', 'second', 'pos_x', 'pos_y']],
48
- tracking
67
+ events[["half", "second", "pos_x", "pos_y"]], tracking
49
68
  )
50
69
 
51
70
  locs_home, locs_away = self.build_player_locs(tracking)
52
71
 
53
72
  locs_ball = {
54
- half: ball_data[ball_data['half'] == half][['pos_x', 'pos_y']].values
55
- for half in tracking['half'].unique()
73
+ half: ball_data[ball_data["half"] == half][["pos_x", "pos_y"]].values
74
+ for half in tracking["half"].unique()
56
75
  }
57
76
 
58
77
  t = {
59
- half: ball_data[ball_data['half'] == half]['second'].values
60
- for half in tracking['half'].unique()
78
+ half: ball_data[ball_data["half"] == half]["second"].values
79
+ for half in tracking["half"].unique()
61
80
  }
62
81
 
63
82
  return locs_home, locs_away, locs_ball, t
@@ -96,20 +115,22 @@ class PitchControl:
96
115
  side_by_half = self.side_by_half # local alias for speed
97
116
 
98
117
  def needs_swap(row):
99
- team_id = row['team_id']
100
- poss = row['possession_team_id']
101
- half = row['half']
118
+ team_id = row["team_id"]
119
+ poss = row["possession_team_id"]
120
+ half = row["half"]
102
121
 
103
- current_left = team_id in poss if isinstance(poss, list) else team_id == poss
104
- current_side = 'left' if current_left else 'right'
122
+ current_left = (
123
+ team_id in poss if isinstance(poss, list) else team_id == poss
124
+ )
125
+ current_side = "left" if current_left else "right"
105
126
  real_side = side_by_half[half][str(int(team_id))]
106
127
  return current_side != real_side
107
128
 
108
129
  mask = events.apply(needs_swap, axis=1)
109
130
 
110
131
  # flip coords in bulk
111
- events.loc[mask, 'pos_x'] = 105 - events.loc[mask, 'pos_x']
112
- events.loc[mask, 'pos_y'] = 68 - events.loc[mask, 'pos_y']
132
+ events.loc[mask, "pos_x"] = 105 - events.loc[mask, "pos_x"]
133
+ events.loc[mask, "pos_y"] = 68 - events.loc[mask, "pos_y"]
113
134
  return events
114
135
 
115
136
  def build_player_locs(self, tracking: pd.DataFrame):
@@ -123,36 +144,35 @@ class PitchControl:
123
144
 
124
145
  # Work per half to keep order and avoid extra boolean checks.
125
146
  for half in (1, 2):
126
- half_df = tracking[tracking['half'] == half]
127
- for side, locs_out in [('left', locs_home), ('right', locs_away)]:
128
- side_df = half_df[half_df['side_1h'] == side]
129
- for pid, grp in side_df.groupby('player_id'):
130
- locs_out[half][pid] = grp[['pos_x', 'pos_y']].values
147
+ half_df = tracking[tracking["half"] == half]
148
+ for side, locs_out in [("left", locs_home), ("right", locs_away)]:
149
+ side_df = half_df[half_df["side_1h"] == side]
150
+ for pid, grp in side_df.groupby("player_id"):
151
+ locs_out[half][pid] = grp[["pos_x", "pos_y"]].values
131
152
  return locs_home, locs_away
132
153
 
133
154
  @staticmethod
134
155
  def interpolate_ball_data(
135
- ball_data: pd.DataFrame,
136
- player_data: pd.DataFrame
156
+ ball_data: pd.DataFrame, player_data: pd.DataFrame
137
157
  ) -> pd.DataFrame:
138
- ball_data = ball_data.drop_duplicates(subset=['second', 'half'])
158
+ ball_data = ball_data.drop_duplicates(subset=["second", "half"])
139
159
 
140
160
  interpolated_data = []
141
- for half in ball_data['half'].unique():
142
- ball_half = ball_data[ball_data['half'] == half]
143
- player_half = player_data[player_data['half'] == half]
161
+ for half in ball_data["half"].unique():
162
+ ball_half = ball_data[ball_data["half"] == half]
163
+ player_half = player_data[player_data["half"] == half]
144
164
 
145
- player_times = player_half['second'].unique()
165
+ player_times = player_half["second"].unique()
146
166
 
147
- ball_half = ball_half.sort_values(by='second')
148
- interpolated_half = pd.DataFrame({'second': player_times})
149
- interpolated_half['pos_x'] = np.interp(
150
- interpolated_half['second'], ball_half['second'], ball_half['pos_x']
167
+ ball_half = ball_half.sort_values(by="second")
168
+ interpolated_half = pd.DataFrame({"second": player_times})
169
+ interpolated_half["pos_x"] = np.interp(
170
+ interpolated_half["second"], ball_half["second"], ball_half["pos_x"]
151
171
  )
152
- interpolated_half['pos_y'] = np.interp(
153
- interpolated_half['second'], ball_half['second'], ball_half['pos_y']
172
+ interpolated_half["pos_y"] = np.interp(
173
+ interpolated_half["second"], ball_half["second"], ball_half["pos_y"]
154
174
  )
155
- interpolated_half['half'] = half
175
+ interpolated_half["half"] = half
156
176
  interpolated_data.append(interpolated_half)
157
177
 
158
178
  interpolated_ball_data = pd.concat(interpolated_data, ignore_index=True)
@@ -160,16 +180,15 @@ class PitchControl:
160
180
 
161
181
  @staticmethod
162
182
  def get_player_data(player_id, half, tracking):
163
- timestamps = tracking[tracking['half'] == half]['second'].unique()
183
+ timestamps = tracking[tracking["half"] == half]["second"].unique()
164
184
  player_data = tracking[
165
- (tracking['player_id'] == player_id)
166
- & (tracking['half'] == half)
167
- ][['second', 'pos_x', 'pos_y']]
185
+ (tracking["player_id"] == player_id) & (tracking["half"] == half)
186
+ ][["second", "pos_x", "pos_y"]]
168
187
 
169
- player_data_full = pd.DataFrame({'second': timestamps})
170
- player_data_full = player_data_full.merge(player_data, on='second', how='left')
188
+ player_data_full = pd.DataFrame({"second": timestamps})
189
+ player_data_full = player_data_full.merge(player_data, on="second", how="left")
171
190
 
172
- return player_data_full[['pos_x', 'pos_y']].values
191
+ return player_data_full[["pos_x", "pos_y"]].values
173
192
 
174
193
  def influence_np(
175
194
  self,
@@ -178,11 +197,11 @@ class PitchControl:
178
197
  time_index: int,
179
198
  home_or_away: str,
180
199
  half: int,
181
- verbose: bool = False
200
+ verbose: bool = False,
182
201
  ):
183
- if home_or_away == 'h':
202
+ if home_or_away == "h":
184
203
  data = self.locs_home[half].copy()
185
- elif home_or_away == 'a':
204
+ elif home_or_away == "a":
186
205
  data = self.locs_away[half].copy()
187
206
  else:
188
207
  raise ValueError("Enter either 'h' or 'a'.")
@@ -190,38 +209,44 @@ class PitchControl:
190
209
  locs_ball = self.locs_ball[half].copy()
191
210
  t = self.t[half].copy()
192
211
 
193
- if (
194
- np.all(np.isfinite(data[player_index][[time_index, time_index + 1], :]))
195
- & np.all(np.isfinite(locs_ball[time_index, :]))
196
- ):
197
- jitter = 1e-10 ## to prevent identically zero covariance matrices when velocity is zero
212
+ if np.all(
213
+ np.isfinite(data[player_index][[time_index, time_index + 1], :])
214
+ ) & np.all(np.isfinite(locs_ball[time_index, :])):
215
+ jitter = 1e-10 ## to prevent identically zero covariance matrices when velocity is zero
198
216
  ## compute velocity by fwd difference
199
- s = (
200
- np.linalg.norm(
201
- data[player_index][time_index + 1,:]
202
- - data[player_index][time_index,:] + jitter
203
- )
204
- / (t[time_index + 1] - t[time_index])
205
- )
217
+ s = np.linalg.norm(
218
+ data[player_index][time_index + 1, :]
219
+ - data[player_index][time_index, :]
220
+ + jitter
221
+ ) / (t[time_index + 1] - t[time_index])
206
222
  ## velocities in x,y directions
207
223
  sxy = (
208
- (data[player_index][time_index + 1, :] - data[player_index][time_index, :] + jitter)
209
- / (t[time_index + 1] - t[time_index])
210
- )
224
+ data[player_index][time_index + 1, :]
225
+ - data[player_index][time_index, :]
226
+ + jitter
227
+ ) / (t[time_index + 1] - t[time_index])
211
228
  ## angle between velocity vector & x-axis
212
229
  theta = np.arccos(sxy[0] / np.linalg.norm(sxy))
213
230
  ## rotation matrix
214
- R = np.array([[np.cos(theta), -np.sin(theta)],[np.sin(theta), np.cos(theta)]])
231
+ R = np.array(
232
+ [[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]
233
+ )
215
234
  mu = data[player_index][time_index, :] + sxy * 0.5
216
235
  Srat = (s / 13) ** 2
217
- Ri = np.linalg.norm(locs_ball[time_index, :] - data[player_index][time_index, :])
236
+ Ri = np.linalg.norm(
237
+ locs_ball[time_index, :] - data[player_index][time_index, :]
238
+ )
218
239
  ## don't think this function is specified in the paper but looks close enough to fig 9
219
- Ri = np.minimum(4 + Ri ** 3/ (18 ** 3 / 6), 10)
240
+ Ri = np.minimum(4 + Ri**3 / (18**3 / 6), 10)
220
241
  S = np.array([[(1 + Srat) * Ri / 2, 0], [0, (1 - Srat) * Ri / 2]])
221
242
  Sigma = np.matmul(R, S)
222
243
  Sigma = np.matmul(Sigma, S)
223
- Sigma = np.matmul(Sigma, np.linalg.inv(R)) ## this is not efficient, forgive me.
224
- out = mvn.pdf(location, mu, Sigma) / mvn.pdf(data[player_index][time_index, :], mu, Sigma)
244
+ Sigma = np.matmul(
245
+ Sigma, np.linalg.inv(R)
246
+ ) ## this is not efficient, forgive me.
247
+ out = mvn.pdf(location, mu, Sigma) / mvn.pdf(
248
+ data[player_index][time_index, :], mu, Sigma
249
+ )
225
250
  else:
226
251
  if verbose:
227
252
  print("Data is not finite.")
@@ -270,8 +295,12 @@ class PitchControl:
270
295
  if not pos_t_list:
271
296
  return torch.zeros(locs.shape[0], device=device, dtype=dtype)
272
297
 
273
- pos_t = torch.tensor(np.asarray(pos_t_list), device=device, dtype=dtype) # (P,2)
274
- pos_tp1 = torch.tensor(np.asarray(pos_tp1_list), device=device, dtype=dtype) # (P,2)
298
+ pos_t = torch.tensor(
299
+ np.asarray(pos_t_list), device=device, dtype=dtype
300
+ ) # (P,2)
301
+ pos_tp1 = torch.tensor(
302
+ np.asarray(pos_tp1_list), device=device, dtype=dtype
303
+ ) # (P,2)
275
304
 
276
305
  # Velocity, speed, rotation -------------------------
277
306
  dt_sec = float(self.t[half][time_index + 1] - self.t[half][time_index])
@@ -280,13 +309,15 @@ class PitchControl:
280
309
  speed = torch.linalg.norm(sxy, dim=1) # (P,)
281
310
  norm_sxy = speed.clamp(min=1e-6)
282
311
 
283
- theta = torch.acos(torch.clamp(sxy[:, 0] / norm_sxy, -1 + 1e-6, 1 - 1e-6)) # (P,)
312
+ theta = torch.acos(
313
+ torch.clamp(sxy[:, 0] / norm_sxy, -1 + 1e-6, 1 - 1e-6)
314
+ ) # (P,)
284
315
  cos_t, sin_t = torch.cos(theta), torch.sin(theta)
285
316
 
286
317
  R = torch.stack(
287
318
  [
288
319
  torch.stack([cos_t, -sin_t], dim=1),
289
- torch.stack([sin_t, cos_t], dim=1),
320
+ torch.stack([sin_t, cos_t], dim=1),
290
321
  ],
291
322
  dim=1,
292
323
  ) # (P,2,2)
@@ -298,7 +329,9 @@ class PitchControl:
298
329
  self.locs_ball[half][time_index], device=device, dtype=dtype
299
330
  ) # (2,)
300
331
  Ri = torch.linalg.norm(ball_pos - pos_t, dim=1) # (P,)
301
- Ri = torch.minimum(4 + Ri ** 3 / (18 ** 3 / 6), torch.tensor(10.0, device=device, dtype=dtype))
332
+ Ri = torch.minimum(
333
+ 4 + Ri**3 / (18**3 / 6), torch.tensor(10.0, device=device, dtype=dtype)
334
+ )
302
335
 
303
336
  S11 = (1 + Srat) * Ri / 2
304
337
  S22 = (1 - Srat) * Ri / 2
@@ -322,19 +355,18 @@ class PitchControl:
322
355
 
323
356
  # Grid diff & Mahalanobis ----------------------------
324
357
  diff = locs.view(1, -1, 2) # (1,N,2)
325
- diff = diff - mu.unsqueeze(1) # (P,N,2)
358
+ diff = diff - mu.unsqueeze(1) # (P,N,2)
326
359
 
327
360
  device = torch.device(device)
328
361
 
329
- if device.type == 'cuda':
362
+ if device.type == "cuda":
330
363
  out = triton_influence(
331
- mu.unsqueeze(0), Sigma_inv.unsqueeze(0),
332
- locs, BLOCK_N=64
364
+ mu.unsqueeze(0), Sigma_inv.unsqueeze(0), locs, BLOCK_N=64
333
365
  )[0] # (N,)
334
366
 
335
367
  return out
336
368
  else:
337
- maha = torch.einsum('pni,pij,pnj->pn', diff, Sigma_inv, diff) # (P,N)
369
+ maha = torch.einsum("pni,pij,pnj->pn", diff, Sigma_inv, diff) # (P,N)
338
370
  maha = torch.nan_to_num(maha, nan=1e9, posinf=1e9, neginf=1e9)
339
371
  out = torch.exp(-0.5 * maha) # (P,N)
340
372
 
@@ -342,19 +374,19 @@ class PitchControl:
342
374
 
343
375
  def _batch_team_influence_frames_pt(
344
376
  self,
345
- pos_t: torch.Tensor, # (F,P,2)
346
- pos_tp1: torch.Tensor, # (F,P,2)
347
- ball_pos: torch.Tensor, # (F,2)
348
- dt_secs: torch.Tensor, # (F,)
349
- locs: torch.Tensor, # (N,2)
377
+ pos_t: torch.Tensor, # (F,P,2)
378
+ pos_tp1: torch.Tensor, # (F,P,2)
379
+ ball_pos: torch.Tensor, # (F,2)
380
+ dt_secs: torch.Tensor, # (F,)
381
+ locs: torch.Tensor, # (N,2)
350
382
  dtype: torch.dtype,
351
- ) -> torch.Tensor: # returns (F,N)
383
+ ) -> torch.Tensor: # returns (F,N)
352
384
  """Vectorised influence for many frames & players of ОДНОЙ команды."""
353
385
 
354
386
  device = locs.device
355
387
 
356
388
  sxy = (pos_tp1 - pos_t) / dt_secs[:, None, None] # (F,P,2)
357
- speed = torch.linalg.norm(sxy, dim=-1) # (F,P)
389
+ speed = torch.linalg.norm(sxy, dim=-1) # (F,P)
358
390
  norm_sxy = speed.clamp(min=1e-6)
359
391
 
360
392
  theta = torch.acos(torch.clamp(sxy[..., 0] / norm_sxy, -1 + 1e-6, 1 - 1e-6))
@@ -363,7 +395,7 @@ class PitchControl:
363
395
  R = torch.stack(
364
396
  [
365
397
  torch.stack([cos_t, -sin_t], dim=-1),
366
- torch.stack([sin_t, cos_t], dim=-1),
398
+ torch.stack([sin_t, cos_t], dim=-1),
367
399
  ],
368
400
  dim=-2,
369
401
  ) # (F,P,2,2)
@@ -371,12 +403,16 @@ class PitchControl:
371
403
  Srat = (speed / 13) ** 2 # (F,P)
372
404
 
373
405
  Ri = torch.linalg.norm(ball_pos[:, None, :] - pos_t, dim=-1) # (F,P)
374
- Ri = torch.minimum(4 + Ri ** 3 / (18 ** 3 / 6), torch.tensor(10.0, device=device, dtype=dtype))
406
+ Ri = torch.minimum(
407
+ 4 + Ri**3 / (18**3 / 6), torch.tensor(10.0, device=device, dtype=dtype)
408
+ )
375
409
 
376
410
  S11 = (1 + Srat) * Ri / 2
377
411
  S22 = (1 - Srat) * Ri / 2
378
412
 
379
- S = torch.zeros((*pos_t.shape[:-1], 2, 2), device=device, dtype=dtype) # (F,P,2,2)
413
+ S = torch.zeros(
414
+ (*pos_t.shape[:-1], 2, 2), device=device, dtype=dtype
415
+ ) # (F,P,2,2)
380
416
  S[..., 0, 0] = S11
381
417
  S[..., 1, 1] = S22
382
418
 
@@ -393,21 +429,23 @@ class PitchControl:
393
429
  mu = pos_t + 0.5 * sxy # (F,P,2)
394
430
 
395
431
  diff = locs.view(1, 1, -1, 2) # (1,1,N,2)
396
- diff = diff - mu.unsqueeze(2) # (F,P,N,2)
432
+ diff = diff - mu.unsqueeze(2) # (F,P,N,2)
397
433
 
398
- if device.type == 'cuda':
434
+ if device.type == "cuda":
399
435
  out = triton_influence(mu, Sigma_inv, locs, BLOCK_N=64) # (F,N)
400
436
 
401
437
  return out
402
438
  else:
403
- maha = torch.einsum('fpni,fpij,fpnj->fpn', diff, Sigma_inv, diff) # (F,P,N)
439
+ maha = torch.einsum("fpni,fpij,fpnj->fpn", diff, Sigma_inv, diff) # (F,P,N)
404
440
  maha = torch.nan_to_num(maha, nan=1e9, posinf=1e9, neginf=1e9)
405
441
  out = torch.exp(-0.5 * maha) # (F,P,N)
406
442
 
407
443
  return out.sum(dim=1) # sum over players
408
444
 
409
445
  @staticmethod
410
- def _stack_team_frames(players: list[np.ndarray], frames: np.ndarray, device: str, dtype: torch.dtype):
446
+ def _stack_team_frames(
447
+ players: list[np.ndarray], frames: np.ndarray, device: str, dtype: torch.dtype
448
+ ):
411
449
  """Stack positions for given frames into torch tensors (pos_t, pos_tp1)."""
412
450
  # Ensure every player's trajectory is long enough; if not, pad by repeating
413
451
  # the last available coordinate so that indexing `frames` and `frames+1` is safe.
@@ -426,7 +464,9 @@ class PitchControl:
426
464
  np.stack([p[frames] for p in padded], axis=1), device=device, dtype=dtype
427
465
  ) # (F,P,2)
428
466
  pos_tp1 = torch.tensor(
429
- np.stack([p[frames + 1] for p in padded], axis=1), device=device, dtype=dtype
467
+ np.stack([p[frames + 1] for p in padded], axis=1),
468
+ device=device,
469
+ dtype=dtype,
430
470
  )
431
471
  return pos_t, pos_tp1
432
472
 
@@ -438,6 +478,7 @@ class PitchControl:
438
478
  batch_size: int,
439
479
  use_fp16: bool,
440
480
  verbose: bool,
481
+ return_teams: str = "both",
441
482
  ):
442
483
  """Internal helper with fully batched PyTorch implementation."""
443
484
 
@@ -467,8 +508,12 @@ class PitchControl:
467
508
  )
468
509
 
469
510
  # stack teams
470
- pos_t_h, pos_tp1_h = self._stack_team_frames(home_players, frames, device, dtype)
471
- pos_t_a, pos_tp1_a = self._stack_team_frames(away_players, frames, device, dtype)
511
+ pos_t_h, pos_tp1_h = self._stack_team_frames(
512
+ home_players, frames, device, dtype
513
+ )
514
+ pos_t_a, pos_tp1_a = self._stack_team_frames(
515
+ away_players, frames, device, dtype
516
+ )
472
517
 
473
518
  Zh = self._batch_team_influence_frames_pt(
474
519
  pos_t_h, pos_tp1_h, ball_pos, dt_secs, locs_t, dtype
@@ -477,15 +522,30 @@ class PitchControl:
477
522
  pos_t_a, pos_tp1_a, ball_pos, dt_secs, locs_t, dtype
478
523
  )
479
524
 
480
- res_batch = torch.sigmoid(Za - Zh).reshape(-1, dt, dt)
525
+ # Return raw team influence or combined logistic transformation
526
+ if return_teams == "home":
527
+ res_batch = torch.clamp(Zh, 0, 1).reshape(
528
+ -1, dt, dt
529
+ ) # home team influence only
530
+ elif return_teams == "away":
531
+ res_batch = torch.clamp(Za, 0, 1).reshape(
532
+ -1, dt, dt
533
+ ) # away team influence only
534
+ else: # "both"
535
+ res_batch = torch.sigmoid(Za - Zh).reshape(
536
+ -1, dt, dt
537
+ ) # logistic transformation
538
+
481
539
  pc_all[start:end] = res_batch.cpu().numpy().astype(np.float32)
482
540
 
483
541
  if verbose:
484
- print(f"pt full: frames {start}-{end-1} done")
542
+ print(f"pt full: frames {start}-{end - 1} done")
485
543
 
486
544
  return pc_all, xx, yy
487
545
 
488
- def _get_grid(self, dt: int, device: str, dtype: torch.dtype) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
546
+ def _get_grid(
547
+ self, dt: int, device: str, dtype: torch.dtype
548
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
489
549
  """Helper to create a grid of locations for torch-based pitch control."""
490
550
  x = torch.linspace(0, 105, dt, device=device, dtype=dtype)
491
551
  y = torch.linspace(0, 68, dt, device=device, dtype=dtype)
@@ -501,6 +561,7 @@ class PitchControl:
501
561
  device: str = "cpu",
502
562
  verbose: bool = False,
503
563
  use_fp16: bool = True,
564
+ return_teams: str = "both",
504
565
  ):
505
566
  """Torch-accelerated pitch-control calculation.
506
567
 
@@ -533,31 +594,51 @@ class PitchControl:
533
594
  self.locs_away[half], locs_t, tp, half, device, dtype
534
595
  )
535
596
 
536
- res_t = torch.sigmoid(Za - Zh).reshape(dt, dt)
597
+ # Return raw team influence or combined logistic transformation
598
+ if return_teams == "home":
599
+ res_t = torch.clamp(Zh, 0, 1).reshape(dt, dt) # home team influence only
600
+ elif return_teams == "away":
601
+ res_t = torch.clamp(Za, 0, 1).reshape(dt, dt) # away team influence only
602
+ else: # "both"
603
+ res_t = torch.sigmoid(Za - Zh).reshape(dt, dt) # logistic transformation
537
604
 
538
605
  # Convert to numpy for downstream plotting
539
606
  return res_t.cpu().numpy(), xx_t.cpu().numpy(), yy_t.cpu().numpy()
540
607
 
541
- def _fit_np(self, half: int, tp: int, dt: int, verbose: bool = False) -> tuple:
608
+ def _fit_np(
609
+ self,
610
+ half: int,
611
+ tp: int,
612
+ dt: int,
613
+ verbose: bool = False,
614
+ return_teams: str = "both",
615
+ ) -> tuple:
542
616
  x = np.linspace(0, 105, dt)
543
617
  y = np.linspace(0, 68, dt)
544
618
  xx, yy = np.meshgrid(x, y)
545
619
 
546
- Zh = np.zeros(dt*dt)
547
- Za = np.zeros(dt*dt)
620
+ Zh = np.zeros(dt * dt)
621
+ Za = np.zeros(dt * dt)
548
622
 
549
- locations = np.c_[xx.flatten(),yy.flatten()]
623
+ locations = np.c_[xx.flatten(), yy.flatten()]
550
624
 
551
625
  for k in self.locs_home[half].keys():
552
626
  # if len(self.locs_home[half][k]) >= tp:
553
- Zh += self.influence_np(k, locations, tp, 'h', half, verbose)
627
+ Zh += self.influence_np(k, locations, tp, "h", half, verbose)
554
628
  for k in self.locs_away[half].keys():
555
629
  # if len(self.locs_away[half][k]) >= tp:
556
- Za += self.influence_np(k, locations, tp, 'a', half, verbose)
630
+ Za += self.influence_np(k, locations, tp, "a", half, verbose)
557
631
 
558
632
  Zh = Zh.reshape((dt, dt))
559
633
  Za = Za.reshape((dt, dt))
560
- result = 1 / (1 + np.exp(-Za + Zh))
634
+
635
+ # Return raw team influence or combined logistic transformation
636
+ if return_teams == "home":
637
+ result = np.clip(Zh, 0, 1) # home team influence only
638
+ elif return_teams == "away":
639
+ result = np.clip(Za, 0, 1) # away team influence only
640
+ else: # "both"
641
+ result = 1 / (1 + np.exp(-Za + Zh)) # logistic transformation
561
642
 
562
643
  return result, xx, yy
563
644
 
@@ -570,13 +651,22 @@ class PitchControl:
570
651
  device: str = "cpu",
571
652
  verbose: bool = False,
572
653
  use_fp16: bool = True,
654
+ return_teams: str = "both",
573
655
  ):
574
656
  """Selects NumPy or PyTorch backend depending on `backend`."""
575
657
  match backend:
576
658
  case "np" | "numpy":
577
- return self._fit_np(half, tp, dt, verbose)
659
+ return self._fit_np(half, tp, dt, verbose, return_teams)
578
660
  case "torch" | "pt":
579
- return self._fit_pt(half, tp, dt, device=device, verbose=verbose, use_fp16=use_fp16)
661
+ return self._fit_pt(
662
+ half,
663
+ tp,
664
+ dt,
665
+ device=device,
666
+ verbose=verbose,
667
+ use_fp16=use_fp16,
668
+ return_teams=return_teams,
669
+ )
580
670
  case _:
581
671
  raise ValueError(f"Unknown backend '{backend}'. Use 'np' or 'torch'.")
582
672
 
@@ -586,9 +676,10 @@ class PitchControl:
586
676
  dt: int = 100,
587
677
  backend: str = "np",
588
678
  device: str = "cpu",
589
- batch_size: int = 30*60,
679
+ batch_size: int = 30 * 60,
590
680
  use_fp16: bool = True,
591
681
  verbose: bool = False,
682
+ return_teams: str = "both",
592
683
  ):
593
684
  """Compute pitch-control map for *каждый* кадр тайма.
594
685
 
@@ -606,7 +697,9 @@ class PitchControl:
606
697
  case "np" | "numpy":
607
698
  pc_all = np.empty((T, dt, dt), dtype=np.float32)
608
699
  for tp in tqdm(range(T)):
609
- pc_map, xx, yy = self._fit_np(half, tp, dt, verbose=False)
700
+ pc_map, xx, yy = self._fit_np(
701
+ half, tp, dt, verbose=False, return_teams=return_teams
702
+ )
610
703
  pc_all[tp] = pc_map.astype(np.float32)
611
704
  if verbose and tp % 500 == 0:
612
705
  print(f"np full-match: done {tp}/{T}")
@@ -614,7 +707,7 @@ class PitchControl:
614
707
 
615
708
  case "torch" | "pt":
616
709
  return self._fit_full_pt(
617
- half, dt, device, batch_size, use_fp16, verbose
710
+ half, dt, device, batch_size, use_fp16, verbose, return_teams
618
711
  )
619
712
  case _:
620
713
  raise ValueError("backend must be 'np' or 'pt'")
@@ -626,7 +719,7 @@ class PitchControl:
626
719
  pitch_control: tuple = None,
627
720
  save: bool = False,
628
721
  dt: int = 200,
629
- filename: str = 'pitch_control'
722
+ filename: str = "pitch_control",
630
723
  ):
631
724
  if pitch_control is None:
632
725
  pitch_control, xx, yy = self.fit(half, tp, dt)
@@ -641,29 +734,34 @@ class PitchControl:
641
734
 
642
735
  for k in self.locs_home[half].keys():
643
736
  # if len(self.locs_home[half][k]) >= tp:
644
- if np.isfinite(self.locs_home[half][k][tp, :]).all():
645
- plt.scatter(
646
- self.locs_home[half][k][tp, 0],
647
- self.locs_home[half][k][tp, 1],
648
- color='darkgrey'
649
- )
737
+ try:
738
+ if np.isfinite(self.locs_home[half][k][tp, :]).all():
739
+ plt.scatter(
740
+ self.locs_home[half][k][tp, 0],
741
+ self.locs_home[half][k][tp, 1],
742
+ color="darkgrey",
743
+ )
744
+ except Exception as e:
745
+ print(f"No data for player {k}: {e}")
650
746
 
651
747
  for k in self.locs_away[half].keys():
652
748
  # if len(self.locs_away[half][k]) >= tp:
653
- if np.isfinite(self.locs_away[half][k][tp, :]).all():
654
- plt.scatter(
655
- self.locs_away[half][k][tp, 0],
656
- self.locs_away[half][k][tp, 1], color='black'
657
- )
749
+ try:
750
+ if np.isfinite(self.locs_away[half][k][tp, :]).all():
751
+ plt.scatter(
752
+ self.locs_away[half][k][tp, 0],
753
+ self.locs_away[half][k][tp, 1],
754
+ color="black",
755
+ )
756
+ except Exception as e:
757
+ print(f"No data for player {k}: {e}")
658
758
 
659
759
  plt.scatter(
660
- self.locs_ball[half][tp, 0],
661
- self.locs_ball[half][tp, 1],
662
- color='red'
760
+ self.locs_ball[half][tp, 0], self.locs_ball[half][tp, 1], color="red"
663
761
  )
664
762
 
665
763
  if save:
666
- plt.savefig(f'{filename}.png', dpi=300)
764
+ plt.savefig(f"{filename}.png", dpi=300)
667
765
  else:
668
766
  plt.show()
669
767
 
@@ -674,7 +772,7 @@ class PitchControl:
674
772
  filename: str = "pitch_control_animation",
675
773
  dt: int = 200,
676
774
  frames: int = 30,
677
- interval: int = 1000
775
+ interval: int = 1000,
678
776
  ):
679
777
  """
680
778
  ffmpeg should be installed on your machine.
@@ -686,7 +784,7 @@ class PitchControl:
686
784
  pitch_control, xx, yy = self.fit(half, fr, dt)
687
785
 
688
786
  mpl.field("white", show=False, ax=ax)
689
- ax.axis('off')
787
+ ax.axis("off")
690
788
 
691
789
  plt.contourf(xx, yy, pitch_control)
692
790
 
@@ -696,7 +794,7 @@ class PitchControl:
696
794
  plt.scatter(
697
795
  self.locs_home[half][k][fr, 0],
698
796
  self.locs_home[half][k][fr, 1],
699
- color='darkgrey'
797
+ color="darkgrey",
700
798
  )
701
799
  for k in self.locs_away[half].keys():
702
800
  # if len(self.locs_away[half][k]) >= fr:
@@ -704,13 +802,11 @@ class PitchControl:
704
802
  plt.scatter(
705
803
  self.locs_away[half][k][fr, 0],
706
804
  self.locs_away[half][k][fr, 1],
707
- color='black'
805
+ color="black",
708
806
  )
709
807
 
710
808
  plt.scatter(
711
- self.locs_ball[half][fr, 0],
712
- self.locs_ball[half][fr, 1],
713
- color='red'
809
+ self.locs_ball[half][fr, 0], self.locs_ball[half][fr, 1], color="red"
714
810
  )
715
811
 
716
812
  return ax
@@ -724,7 +820,7 @@ class PitchControl:
724
820
  func=animate,
725
821
  frames=min(frames, len(self.locs_ball[half]) - tp),
726
822
  interval=interval,
727
- blit=False
823
+ blit=False,
728
824
  )
729
825
 
730
- ani.save(f'{filename}.mp4', writer='ffmpeg')
826
+ ani.save(f"{filename}.mp4", writer="ffmpeg")
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: rustat-python-api
3
- Version: 0.7.8
3
+ Version: 0.7.10
4
4
  Summary: A Python wrapper for RuStat API
5
5
  Home-page: https://github.com/dailydaniel/rustat-python-api
6
6
  Author: Daniel Zholkovsky
@@ -0,0 +1,32 @@
1
+ from setuptools import setup, find_packages
2
+
3
+ setup(
4
+ name="rustat-python-api",
5
+ version="0.7.10",
6
+ description="A Python wrapper for RuStat API",
7
+ long_description=open("README.md").read(),
8
+ long_description_content_type="text/markdown",
9
+ author="Daniel Zholkovsky",
10
+ author_email="daniel@zholkovsky.com",
11
+ url="https://github.com/dailydaniel/rustat-python-api",
12
+ license="MIT",
13
+ packages=find_packages(),
14
+ install_requires=[
15
+ "requests==2.32.3",
16
+ "pandas==2.2.3",
17
+ "tqdm==4.66.5",
18
+ "scipy==1.14.1",
19
+ "matplotlib",
20
+ "matplotsoccer",
21
+ ],
22
+ extras_require={
23
+ "gpu": ["torch", "triton==3.0.0; platform_system=='Linux'"],
24
+ "cpu": ["torch"],
25
+ },
26
+ classifiers=[
27
+ "Programming Language :: Python :: 3",
28
+ "License :: OSI Approved :: MIT License",
29
+ "Operating System :: OS Independent",
30
+ ],
31
+ python_requires=">=3.10",
32
+ )
@@ -1,37 +0,0 @@
1
- from setuptools import setup, find_packages
2
-
3
- setup(
4
- name='rustat-python-api',
5
- version='0.7.8',
6
- description='A Python wrapper for RuStat API',
7
- long_description=open('README.md').read(),
8
- long_description_content_type='text/markdown',
9
- author='Daniel Zholkovsky',
10
- author_email='daniel@zholkovsky.com',
11
- url='https://github.com/dailydaniel/rustat-python-api',
12
- license='MIT',
13
- packages=find_packages(),
14
- install_requires=[
15
- 'requests==2.32.3',
16
- 'pandas==2.2.3',
17
- 'tqdm==4.66.5',
18
- 'scipy==1.14.1',
19
- 'matplotlib',
20
- 'matplotsoccer'
21
- ],
22
- extras_require={
23
- "gpu": [
24
- "torch",
25
- "triton==3.0.0; platform_system=='Linux'"
26
- ],
27
- "cpu": [
28
- "torch"
29
- ]
30
- },
31
- classifiers=[
32
- 'Programming Language :: Python :: 3',
33
- 'License :: OSI Approved :: MIT License',
34
- 'Operating System :: OS Independent',
35
- ],
36
- python_requires='>=3.10',
37
- )