rustat-python-api 0.5.8__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -4,6 +4,8 @@ from scipy.stats import multivariate_normal as mvn
4
4
  import matplotlib.pyplot as plt
5
5
  import matplotlib.animation as animation
6
6
  import matplotsoccer as mpl
7
+ import torch
8
+ from tqdm import tqdm
7
9
 
8
10
 
9
11
  class PitchControl:
@@ -20,6 +22,8 @@ class PitchControl:
20
22
  }
21
23
  }
22
24
 
25
+ self._grid_cache: dict[tuple[int, str, torch.dtype], tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = {}
26
+
23
27
  self.locs_home, self.locs_away, self.locs_ball, self.t = self.get_locs(
24
28
  tracking,
25
29
  events,
@@ -32,12 +36,7 @@ class PitchControl:
32
36
  'half', 'second', 'pos_x', 'pos_y'
33
37
  ]]
34
38
 
35
- events.loc[:, 'pos_x'] = events.apply(
36
- lambda x: self.swap_coords(x, 'x'), axis=1
37
- )
38
- events.loc[:, 'pos_y'] = events.apply(
39
- lambda x: self.swap_coords(x, 'y'), axis=1
40
- )
39
+ events = self.swap_coords_batch(events)
41
40
 
42
41
  if ball_data is None:
43
42
  ball_data = self.interpolate_ball_data(
@@ -45,21 +44,7 @@ class PitchControl:
45
44
  tracking
46
45
  )
47
46
 
48
- locs_home = {
49
- half: {
50
- player_id: self.get_player_data(player_id, half, tracking)
51
- for player_id in tracking[tracking['side_1h'] == 'left']['player_id'].unique()
52
- }
53
- for half in tracking['half'].unique()
54
- }
55
-
56
- locs_away = {
57
- half: {
58
- player_id: self.get_player_data(player_id, half, tracking)
59
- for player_id in tracking[tracking['side_1h'] == 'right']['player_id'].unique()
60
- }
61
- for half in tracking['half'].unique()
62
- }
47
+ locs_home, locs_away = self.build_player_locs(tracking)
63
48
 
64
49
  locs_ball = {
65
50
  half: ball_data[ball_data['half'] == half][['pos_x', 'pos_y']].values
@@ -73,28 +58,73 @@ class PitchControl:
73
58
 
74
59
  return locs_home, locs_away, locs_ball, t
75
60
 
61
+ # def swap_coords(self, row, how: str = 'x'):
62
+ # half = row['half']
63
+ # team_id = row['team_id']
64
+ # possession_team_id = row['possession_team_id']
65
+ # x = row['pos_x']
66
+ # y = row['pos_y']
67
+
68
+ # if isinstance(possession_team_id, list):
69
+ # current_side = 'left' if team_id in possession_team_id else 'right'
70
+ # real_side = self.side_by_half[half][str(int(team_id))]
71
+ # else:
72
+ # current_side = 'left' if team_id == possession_team_id else 'right'
73
+ # real_side = self.side_by_half[half][str(int(team_id))]
74
+
75
+ # if current_side != real_side:
76
+ # if how == 'x':
77
+ # x = 105 - x
78
+ # else:
79
+ # y = 68 - y
80
+
81
+ # return x if how == 'x' else y
82
+
83
def swap_coords_batch(self, events: pd.DataFrame) -> pd.DataFrame:
    """Flip event coordinates whose attacking direction disagrees with ``self.side_by_half``.

    For each row the "current" side is ``'left'`` when ``team_id`` equals
    ``possession_team_id`` (or is contained in it, when that cell holds a
    list) and ``'right'`` otherwise. The "real" side is
    ``self.side_by_half[half][str(int(team_id))]``. Rows where the two differ
    get ``pos_x -> 105 - pos_x`` and ``pos_y -> 68 - pos_y``.

    Parameters
    ----------
    events : pd.DataFrame
        Must contain ``team_id``, ``possession_team_id``, ``half``, ``pos_x``
        and ``pos_y``. Modified in place.

    Returns
    -------
    pd.DataFrame
        The same *events* object, for chaining.
    """
    side_by_half = self.side_by_half  # local alias for the per-row lookup

    def needs_swap(team_id, poss, half) -> bool:
        # possession_team_id may hold a list of team ids for merged events.
        current_left = team_id in poss if isinstance(poss, list) else team_id == poss
        current_side = 'left' if current_left else 'right'
        return current_side != side_by_half[half][str(int(team_id))]

    # Build the mask by iterating the three columns directly instead of
    # DataFrame.apply(axis=1): apply returns an empty DataFrame (not a boolean
    # Series) for an empty frame, which .loc cannot use as a row indexer, and
    # it also builds a Series object per row.
    mask = np.fromiter(
        (
            needs_swap(team_id, poss, half)
            for team_id, poss, half in zip(
                events['team_id'], events['possession_team_id'], events['half']
            )
        ),
        dtype=bool,
        count=len(events),
    )

    if mask.any():
        # Flip both coordinates of the affected rows in bulk.
        events.loc[mask, 'pos_x'] = 105 - events.loc[mask, 'pos_x']
        events.loc[mask, 'pos_y'] = 68 - events.loc[mask, 'pos_y']
    return events
110
+
111
def build_player_locs(self, tracking: pd.DataFrame):
    """Group tracking rows into per-half, per-player position arrays.

    A player goes into the home dict when their ``side_1h`` is ``'left'``
    and into the away dict when it is ``'right'``. Rows with any other
    side value are ignored.

    Parameters
    ----------
    tracking : pd.DataFrame
        Needs ``half``, ``side_1h``, ``player_id``, ``pos_x`` and ``pos_y``.

    Returns
    -------
    tuple[dict, dict]
        ``(locs_home, locs_away)``, each ``{half: {player_id: ndarray}}``.
        Every array is ``(n_rows, 2)`` of ``(pos_x, pos_y)``, kept in the
        row order of *tracking*. Halves 1 and 2 are always present, possibly
        empty. Any further halves found in the data (e.g. extra time) are
        included too; previously they were silently dropped.
    """
    # NOTE(review): each array holds only the rows present in *tracking* for
    # that player. Callers index these arrays by frame number and np.stack
    # them across players, which assumes every player has a row for every
    # frame of the half. Confirm, or route through get_player_data as the
    # earlier per-player construction did.
    extra_halves = sorted(
        h for h in tracking['half'].dropna().unique() if h not in (1, 2)
    )
    halves = (1, 2, *extra_halves)
    locs_home = {half: {} for half in halves}
    locs_away = {half: {} for half in halves}
    by_side = {'left': locs_home, 'right': locs_away}

    # One groupby replaces a boolean filter per (half, side) pair. Keys come
    # out sorted, so player ids are inserted in ascending order, and rows
    # inside each group keep their original order.
    grouped = tracking.groupby(['half', 'side_1h', 'player_id'], sort=True)
    for (half, side, pid), grp in grouped:
        locs_out = by_side.get(side)
        if locs_out is not None:
            locs_out[half][pid] = grp[['pos_x', 'pos_y']].values
    return locs_home, locs_away
98
128
 
99
129
  @staticmethod
100
130
  def interpolate_ball_data(
@@ -137,12 +167,7 @@ class PitchControl:
137
167
 
138
168
  return player_data_full[['pos_x', 'pos_y']].values
139
169
 
140
- # return tracking[
141
- # (tracking['player_id'] == player_id)
142
- # & (tracking['half'] == half)
143
- # ][['pos_x', 'pos_y']].values
144
-
145
- def influence_function(
170
+ def influence_np(
146
171
  self,
147
172
  player_index: str,
148
173
  location: np.ndarray,
@@ -199,7 +224,297 @@ class PitchControl:
199
224
  out = np.zeros(location.shape[0])
200
225
  return out
201
226
 
202
- def fit(self, half: int, tp: int, dt: int, verbose: bool = False) -> tuple:
227
def _batch_influence_pt(
    self,
    player_dict: dict,
    locs: torch.Tensor,
    time_index: int,
    half: int,
    device: str,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Sum the influence of one team's players over a grid at a single frame.

    A player is skipped when their trajectory is too short to contain frames
    ``time_index`` and ``time_index + 1``, or when either of those rows has
    non-finite coordinates. The remaining players are evaluated as a
    one-frame batch by ``_batch_team_influence_frames_pt``. The single-frame
    and full-half torch paths therefore share one implementation of the
    influence model instead of two copies of the same math.

    Parameters
    ----------
    player_dict : dict
        ``{player_id: ndarray}``. Each array is indexed by frame, with rows
        ``(x, y)``.
    locs : torch.Tensor
        ``(N, 2)`` grid locations, already on *device* with *dtype*.
    time_index : int
        Frame index ``t``. The velocity is taken from ``t`` to ``t + 1``.
    half : int
        Half number, used to look up ``self.t`` and ``self.locs_ball``.
    device, dtype
        Placement and precision of the intermediate tensors.

    Returns
    -------
    torch.Tensor
        ``(N,)`` summed influence. All zeros when no player is valid.
    """
    next_index = time_index + 1
    valid = [
        arr
        for arr in player_dict.values()
        if next_index < arr.shape[0]
        and np.isfinite(arr[[time_index, next_index], :]).all()
    ]
    if not valid:
        return torch.zeros(locs.shape[0], device=device, dtype=dtype)

    # Shape everything as a batch of F=1 frame: (1, P, 2) positions,
    # (1, 2) ball position, (1,) frame duration.
    pos_t = torch.tensor(
        np.asarray([arr[time_index] for arr in valid]), device=device, dtype=dtype
    ).unsqueeze(0)
    pos_tp1 = torch.tensor(
        np.asarray([arr[next_index] for arr in valid]), device=device, dtype=dtype
    ).unsqueeze(0)
    ball_pos = torch.tensor(
        self.locs_ball[half][time_index], device=device, dtype=dtype
    ).unsqueeze(0)
    dt_secs = torch.tensor(
        [float(self.t[half][next_index] - self.t[half][time_index])],
        device=device,
        dtype=dtype,
    )

    influence = self._batch_team_influence_frames_pt(
        pos_t, pos_tp1, ball_pos, dt_secs, locs, dtype
    )  # (1, N)
    return influence[0]
332
+
333
+ def _batch_team_influence_frames_pt(
334
+ self,
335
+ pos_t: torch.Tensor, # (F,P,2)
336
+ pos_tp1: torch.Tensor, # (F,P,2)
337
+ ball_pos: torch.Tensor, # (F,2)
338
+ dt_secs: torch.Tensor, # (F,)
339
+ locs: torch.Tensor, # (N,2)
340
+ dtype: torch.dtype,
341
+ ) -> torch.Tensor: # returns (F,N)
342
+ """Vectorised influence for many frames & players of ОДНОЙ команды."""
343
+
344
+ device = locs.device
345
+
346
+ sxy = (pos_tp1 - pos_t) / dt_secs[:, None, None] # (F,P,2)
347
+ speed = torch.linalg.norm(sxy, dim=-1) # (F,P)
348
+ norm_sxy = speed.clamp(min=1e-6)
349
+
350
+ theta = torch.acos(torch.clamp(sxy[..., 0] / norm_sxy, -1 + 1e-6, 1 - 1e-6))
351
+ cos_t, sin_t = torch.cos(theta), torch.sin(theta) # (F,P)
352
+
353
+ R = torch.stack(
354
+ [
355
+ torch.stack([cos_t, -sin_t], dim=-1),
356
+ torch.stack([sin_t, cos_t], dim=-1),
357
+ ],
358
+ dim=-2,
359
+ ) # (F,P,2,2)
360
+
361
+ Srat = (speed / 13) ** 2 # (F,P)
362
+
363
+ Ri = torch.linalg.norm(ball_pos[:, None, :] - pos_t, dim=-1) # (F,P)
364
+ Ri = torch.minimum(4 + Ri ** 3 / (18 ** 3 / 6), torch.tensor(10.0, device=device, dtype=dtype))
365
+
366
+ S11 = (1 + Srat) * Ri / 2
367
+ S22 = (1 - Srat) * Ri / 2
368
+
369
+ S = torch.zeros((*pos_t.shape[:-1], 2, 2), device=device, dtype=dtype) # (F,P,2,2)
370
+ S[..., 0, 0] = S11
371
+ S[..., 1, 1] = S22
372
+
373
+ Sigma = R @ S @ S @ R.transpose(-1, -2) # (F,P,2,2)
374
+
375
+ eye = torch.eye(2, device=device, dtype=dtype) * 1e-6
376
+ eye = eye.view(1, 1, 2, 2)
377
+
378
+ if dtype == torch.float16:
379
+ Sigma_inv = torch.linalg.inv((Sigma + eye).float()).to(dtype)
380
+ else:
381
+ Sigma_inv = torch.linalg.inv(Sigma + eye)
382
+
383
+ mu = pos_t + 0.5 * sxy # (F,P,2)
384
+
385
+ diff = locs.view(1, 1, -1, 2) # (1,1,N,2)
386
+ diff = diff - mu.unsqueeze(2) # (F,P,N,2)
387
+
388
+ maha = torch.einsum('fpni,fpij,fpnj->fpn', diff, Sigma_inv, diff) # (F,P,N)
389
+
390
+ # Replace NaNs that arise from invalid player positions with large value
391
+ # so their influence tends to zero after exponent, then eliminate residual NaNs.
392
+ maha = torch.nan_to_num(maha, nan=1e9, posinf=1e9, neginf=1e9)
393
+
394
+ out = torch.exp(-0.5 * maha) # (F,P,N)
395
+
396
+ return out.sum(dim=1) # sum over players
397
+
398
+ @staticmethod
399
+ def _stack_team_frames(players: list[np.ndarray], frames: np.ndarray, device: str, dtype: torch.dtype):
400
+ """Stack positions for given frames into torch tensors (pos_t, pos_tp1)."""
401
+ pos_t = torch.tensor(
402
+ np.stack([p[frames] for p in players], axis=1), device=device, dtype=dtype
403
+ ) # (F,P,2)
404
+ pos_tp1 = torch.tensor(
405
+ np.stack([p[frames + 1] for p in players], axis=1), device=device, dtype=dtype
406
+ )
407
+ return pos_t, pos_tp1
408
+
409
def _fit_full_pt(
    self,
    half: int,
    dt: int,
    device: str,
    batch_size: int,
    use_fp16: bool,
    verbose: bool,
):
    """Pitch-control maps for every frame of ``half`` using the batched torch path.

    Frames are processed in chunks of ``batch_size``. For each chunk, both
    teams' positions at ``t`` and ``t + 1`` are stacked, and all players and
    frames are evaluated in one call per team.

    Parameters
    ----------
    half : int
        Half number (key into ``self.t``, ``self.locs_home``, ``self.locs_away``,
        ``self.locs_ball``).
    dt : int
        Grid resolution; each map is ``(dt, dt)``.
    device : str
        Torch device for the computation.
    batch_size : int
        Number of frames per chunk; must be >= 1.
    use_fp16 : bool
        Compute in float16. Only honoured on non-CPU devices, as in ``_fit_pt``.
    verbose : bool
        Print a line after each chunk.

    Returns
    -------
    tuple
        ``(pc_all, xx, yy)``. ``pc_all`` is ``(T, dt, dt)`` float32, where
        ``T = len(self.t[half]) - 1``, holding ``sigmoid(Za - Zh)`` per frame.
        ``xx``/``yy`` are the ``(dt, dt)`` grids.

    Raises
    ------
    ValueError
        If ``batch_size`` < 1. A zero step would crash ``range``, and a
        negative step would silently return an unfilled ``np.empty`` array.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    # float16 only off-CPU, matching the single-frame torch path.
    is_cpu = torch.device(device).type == "cpu"
    dtype = torch.float16 if (use_fp16 and not is_cpu) else torch.float32
    xx_t, yy_t, locs_t = self._get_grid(dt, device, dtype)
    xx, yy = xx_t.cpu().numpy(), yy_t.cpu().numpy()

    times = self.t[half]
    # Each map uses frames t and t+1, so there is one map fewer than timestamps.
    T = max(len(times) - 1, 0)
    pc_all = np.empty((T, dt, dt), dtype=np.float32)

    home_players = list(self.locs_home[half].values())
    away_players = list(self.locs_away[half].values())

    for start in tqdm(range(0, T, batch_size)):
        end = min(start + batch_size, T)
        frames = np.arange(start, end)

        dt_secs = torch.tensor(
            times[frames + 1] - times[frames], device=device, dtype=dtype
        )
        ball_pos = torch.tensor(
            self.locs_ball[half][frames], device=device, dtype=dtype
        )

        pos_t_h, pos_tp1_h = self._stack_team_frames(home_players, frames, device, dtype)
        pos_t_a, pos_tp1_a = self._stack_team_frames(away_players, frames, device, dtype)

        Zh = self._batch_team_influence_frames_pt(
            pos_t_h, pos_tp1_h, ball_pos, dt_secs, locs_t, dtype
        )
        Za = self._batch_team_influence_frames_pt(
            pos_t_a, pos_tp1_a, ball_pos, dt_secs, locs_t, dtype
        )

        res_batch = torch.sigmoid(Za - Zh).reshape(-1, dt, dt)
        pc_all[start:end] = res_batch.cpu().numpy().astype(np.float32)

        if verbose:
            print(f"pt full: frames {start}-{end-1} done")

    return pc_all, xx, yy
463
+
464
+ def _get_grid(self, dt: int, device: str, dtype: torch.dtype) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
465
+ """Helper to create a grid of locations for torch-based pitch control."""
466
+ x = torch.linspace(0, 105, dt, device=device, dtype=dtype)
467
+ y = torch.linspace(0, 68, dt, device=device, dtype=dtype)
468
+ xx_t, yy_t = torch.meshgrid(x, y, indexing="xy") # (dt,dt)
469
+ locs_t = torch.stack([xx_t, yy_t], dim=-1).reshape(-1, 2) # (N,2)
470
+ return xx_t, yy_t, locs_t
471
+
472
def _fit_pt(
    self,
    half: int,
    tp: int,
    dt: int = 200,
    device: str = "cpu",
    verbose: bool = False,
    use_fp16: bool = True,
):
    """Torch-accelerated pitch-control map for a single frame.

    Parameters
    ----------
    half : int
        Half number.
    tp : int
        Frame index. Velocities use frames ``tp`` and ``tp + 1``.
    dt : int
        Grid resolution.
    device : str
        Torch device, e.g. ``"cpu"`` or ``"cuda"``. A ``torch.device`` is also accepted.
    verbose : bool
        Accepted for signature parity with the NumPy backend; unused here.
    use_fp16 : bool
        Compute in float16. Only honoured on non-CPU devices.

    Returns
    -------
    result : np.ndarray shape (dt,dt)
        ``sigmoid(Za - Zh)`` evaluated on the grid.
    xx, yy : np.ndarray meshgrids (dt,dt)
    """
    # Compare the device *type*. A plain string comparison let
    # torch.device("cpu") fall through to float16.
    is_cpu = torch.device(device).type == "cpu"
    dtype = torch.float16 if (use_fp16 and not is_cpu) else torch.float32

    # Small FIFO cache of grids keyed by (resolution, device, dtype). The
    # device is normalised to str so "cuda" and torch.device("cuda") share an entry.
    key = (dt, str(device), dtype)
    cached = self._grid_cache.get(key)
    if cached is None:
        cached = self._get_grid(dt, device, dtype)
        if len(self._grid_cache) >= 3:
            # Evict the oldest insertion (dicts preserve insertion order).
            self._grid_cache.pop(next(iter(self._grid_cache)))
        self._grid_cache[key] = cached
    xx_t, yy_t, locs_t = cached

    Zh = self._batch_influence_pt(self.locs_home[half], locs_t, tp, half, device, dtype)
    Za = self._batch_influence_pt(self.locs_away[half], locs_t, tp, half, device, dtype)

    res_t = torch.sigmoid(Za - Zh).reshape(dt, dt)

    # Convert to numpy for downstream plotting.
    return res_t.cpu().numpy(), xx_t.cpu().numpy(), yy_t.cpu().numpy()
516
+
517
+ def _fit_np(self, half: int, tp: int, dt: int, verbose: bool = False) -> tuple:
203
518
  x = np.linspace(0, 105, dt)
204
519
  y = np.linspace(0, 68, dt)
205
520
  xx, yy = np.meshgrid(x, y)
@@ -211,10 +526,10 @@ class PitchControl:
211
526
 
212
527
  for k in self.locs_home[half].keys():
213
528
  # if len(self.locs_home[half][k]) >= tp:
214
- Zh += self.influence_function(k, locations, tp, 'h', half, verbose)
529
+ Zh += self.influence_np(k, locations, tp, 'h', half, verbose)
215
530
  for k in self.locs_away[half].keys():
216
531
  # if len(self.locs_away[half][k]) >= tp:
217
- Za += self.influence_function(k, locations, tp, 'a', half, verbose)
532
+ Za += self.influence_np(k, locations, tp, 'a', half, verbose)
218
533
 
219
534
  Zh = Zh.reshape((dt, dt))
220
535
  Za = Za.reshape((dt, dt))
@@ -222,6 +537,64 @@ class PitchControl:
222
537
 
223
538
  return result, xx, yy
224
539
 
540
def fit(
    self,
    half: int,
    tp: int,
    dt: int = 100,
    backend: str = "np",
    device: str = "cpu",
    verbose: bool = False,
    use_fp16: bool = True,
):
    """Compute the pitch-control map for frame ``tp`` of ``half``.

    ``backend`` selects the implementation:

    * ``"np"``/``"numpy"`` runs ``_fit_np(half, tp, dt, verbose)``.
    * ``"torch"``/``"pt"`` runs ``_fit_pt``, forwarding ``device``,
      ``verbose`` and ``use_fp16`` as keywords.

    The chosen backend's result is returned unchanged. Any other
    ``backend`` raises ``ValueError``.
    """
    if backend in ("np", "numpy"):
        return self._fit_np(half, tp, dt, verbose)
    if backend in ("torch", "pt"):
        return self._fit_pt(half, tp, dt, device=device, verbose=verbose, use_fp16=use_fp16)
    raise ValueError(f"Unknown backend '{backend}'. Use 'np' or 'torch'.")
558
+
559
+ def fit_full(
560
+ self,
561
+ half: int,
562
+ dt: int = 100,
563
+ backend: str = "np",
564
+ device: str = "cpu",
565
+ batch_size: int = 30*60,
566
+ use_fp16: bool = True,
567
+ verbose: bool = False,
568
+ ):
569
+ """Compute pitch-control map for *каждый* кадр тайма.
570
+
571
+ Returns
572
+ -------
573
+ maps : np.ndarray, shape (T, dt, dt)
574
+ Pitch-control probability for home team at every frame.
575
+ xx, yy : np.ndarray, shape (dt, dt)
576
+ Coordinate grids (общие для всех кадров).
577
+ """
578
+
579
+ T = len(self.t[half]) - 1 # мы используем t и t+1, поэтому последний кадр T-1
580
+
581
+ match backend:
582
+ case "np" | "numpy":
583
+ pc_all = np.empty((T, dt, dt), dtype=np.float32)
584
+ for tp in tqdm(range(T)):
585
+ pc_map, xx, yy = self._fit_np(half, tp, dt, verbose=False)
586
+ pc_all[tp] = pc_map.astype(np.float32)
587
+ if verbose and tp % 500 == 0:
588
+ print(f"np full-match: done {tp}/{T}")
589
+ return pc_all, xx, yy
590
+
591
+ case "torch" | "pt":
592
+ return self._fit_full_pt(
593
+ half, dt, device, batch_size, use_fp16, verbose
594
+ )
595
+ case _:
596
+ raise ValueError("backend must be 'np' or 'pt'")
597
+
225
598
  def draw_pitch_control(
226
599
  self,
227
600
  half: int,
@@ -233,8 +606,11 @@ class PitchControl:
233
606
  ):
234
607
  if pitch_control is None:
235
608
  pitch_control, xx, yy = self.fit(half, tp, dt)
609
+ else:
610
+ pitch_control, xx, yy = pitch_control
236
611
 
237
612
  fig, ax = plt.subplots(figsize=(10.5, 6.8))
613
+ # mpl.field(fieldcolor="white", linecolor="black", alpha=1, show=False, ax=ax)
238
614
  mpl.field("white", show=False, ax=ax)
239
615
 
240
616
  plt.contourf(xx, yy, pitch_control)
@@ -13,20 +13,27 @@ def process_list(x: pd.Series):
13
13
  else:
14
14
  return lst
15
15
 
16
def take_last(x: pd.Series):
    """Return the last non-missing value of *x*, or ``np.nan`` if there is none.

    Used as a ``groupby().agg`` reducer. Values pass through ``tolist()``, so
    NumPy scalars come back as plain Python scalars.
    """
    non_missing = x.dropna().tolist()
    return next(reversed(non_missing), np.nan)
21
+
16
22
 
17
23
  def gluing(df: pd.DataFrame) -> pd.DataFrame:
18
24
  cols = ['player_id', 'half', 'second', 'pos_x', 'pos_y']
19
25
 
20
- df_gb = df.groupby(cols).agg(process_list).reset_index()
21
- df_gb['possession_number'] = df_gb['possession_number'].apply(
22
- lambda x: max(x) if isinstance(x, list) else x
23
- )
24
- df_gb['pos_dest_x'] = df_gb['pos_dest_x'].apply(
25
- lambda x: x[0] if isinstance(x, list) else x
26
- )
27
- df_gb['pos_dest_y'] = df_gb['pos_dest_y'].apply(
28
- lambda x: x[0] if isinstance(x, list) else x
29
- )
26
+ agg_rules = {}
27
+
28
+ for col_name in df.columns:
29
+ if col_name not in cols:
30
+ if col_name in ['action_name', 'action_id']:
31
+ agg_rules[col_name] = process_list
32
+ else:
33
+ agg_rules[col_name] = take_last
34
+
35
+ df_gb = df.groupby(cols).agg(agg_rules).reset_index()
36
+
30
37
  df_gb['pos_dest_nan'] = (df_gb['pos_dest_x'].isna() & df_gb['pos_dest_y'].isna()).astype(int)
31
38
  df_gb = df_gb.sort_values(by=['half', 'second', 'possession_number', 'pos_dest_nan']).reset_index(drop=True)
32
39
  return df_gb
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: rustat-python-api
3
- Version: 0.5.8
3
+ Version: 0.6.0
4
4
  Summary: A Python wrapper for RuStat API
5
5
  Home-page: https://github.com/dailydaniel/rustat-python-api
6
6
  Author: Daniel Zholkovsky
@@ -18,6 +18,7 @@ Requires-Dist: tqdm==4.66.5
18
18
  Requires-Dist: scipy==1.14.1
19
19
  Requires-Dist: matplotlib
20
20
  Requires-Dist: matplotsoccer
21
+ Requires-Dist: torch
21
22
 
22
23
  # rustat-python-api
23
24
 
@@ -2,11 +2,11 @@ rustat_python_api/__init__.py,sha256=Ij-PAm2y5ss_XAZhKTZus35cRPLzvXFyIswDa_Iq3rs
2
2
  rustat_python_api/config.py,sha256=eMvi1p8Cfvnbp6Cd4bBOwgehVN7thKnaQV5uzWyGZXM,1844
3
3
  rustat_python_api/models_api.py,sha256=oHXEqeCupvZwjVEdoxf7W9LP7ELFKA8-9DuRXpQHLno,1701
4
4
  rustat_python_api/parser.py,sha256=PrGvN3vY0oczMsJxyUxBO0Yb1P03Moc74AgGYHYr_X8,10216
5
- rustat_python_api/pitch_control.py,sha256=c2uX7B3IMjdKFIb3QulO3188RfzgPwT4BdR8N1XYf_w,11643
6
- rustat_python_api/processing.py,sha256=46O7wUGv5boWf4A38zYinbtqEC7ke-BwTdEg589XQaw,2816
5
+ rustat_python_api/pitch_control.py,sha256=3Sokn8yqVxO_XSMfIq9DkGBP6OkoxsYES56LxihaU-I,25598
6
+ rustat_python_api/processing.py,sha256=WES2D77uu3ScpjN0nW8ozatHteCZJQmmF9WpHPgGfJo,2835
7
7
  rustat_python_api/urls.py,sha256=iJTD31T6OyXPAhmhViwFXVehrzwsOjBDONA1SIVc_40,1068
8
- rustat_python_api-0.5.8.dist-info/LICENSE,sha256=4Cohqg5p6Mq1xyrzdEX8AvFSA62GSVvapEOr2xK_tgY,57
9
- rustat_python_api-0.5.8.dist-info/METADATA,sha256=w8nXtOwc4t5YT8MIdbYE7Zt3YScKS9fneZBYPK74Iu0,1730
10
- rustat_python_api-0.5.8.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
11
- rustat_python_api-0.5.8.dist-info/top_level.txt,sha256=VK0hmkKZE9YThxolUcoE6JtGI67NFeKJMBLuet8kI4w,18
12
- rustat_python_api-0.5.8.dist-info/RECORD,,
8
+ rustat_python_api-0.6.0.dist-info/LICENSE,sha256=4Cohqg5p6Mq1xyrzdEX8AvFSA62GSVvapEOr2xK_tgY,57
9
+ rustat_python_api-0.6.0.dist-info/METADATA,sha256=AgCVwTnDBQZ9IU5D14rStxIP3S2S0MyBVwULZMndlgc,1751
10
+ rustat_python_api-0.6.0.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
11
+ rustat_python_api-0.6.0.dist-info/top_level.txt,sha256=VK0hmkKZE9YThxolUcoE6JtGI67NFeKJMBLuet8kI4w,18
12
+ rustat_python_api-0.6.0.dist-info/RECORD,,