rustat-python-api 0.8.0__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rustat_python_api/parser.py +0 -2
- rustat_python_api/pitch_control.py +40 -8
- rustat_python_api/processing.py +37 -27
- {rustat_python_api-0.8.0.dist-info → rustat_python_api-0.9.0.dist-info}/METADATA +1 -1
- {rustat_python_api-0.8.0.dist-info → rustat_python_api-0.9.0.dist-info}/RECORD +8 -8
- {rustat_python_api-0.8.0.dist-info → rustat_python_api-0.9.0.dist-info}/LICENSE +0 -0
- {rustat_python_api-0.8.0.dist-info → rustat_python_api-0.9.0.dist-info}/WHEEL +0 -0
- {rustat_python_api-0.8.0.dist-info → rustat_python_api-0.9.0.dist-info}/top_level.txt +0 -0
rustat_python_api/parser.py
CHANGED
|
@@ -597,6 +597,8 @@ class PitchControl:
|
|
|
597
597
|
verbose: bool = False,
|
|
598
598
|
use_fp16: bool = True,
|
|
599
599
|
return_teams: str = "both",
|
|
600
|
+
include_players: list | None = None,
|
|
601
|
+
exclude_players: list | None = None,
|
|
600
602
|
):
|
|
601
603
|
"""Torch-accelerated pitch-control calculation.
|
|
602
604
|
|
|
@@ -621,12 +623,16 @@ class PitchControl:
|
|
|
621
623
|
else:
|
|
622
624
|
xx_t, yy_t, locs_t = self._grid_cache[key]
|
|
623
625
|
|
|
626
|
+
# --- filter players ---
|
|
627
|
+
home_dict = self._filter_players(self.locs_home[half], include_players, exclude_players)
|
|
628
|
+
away_dict = self._filter_players(self.locs_away[half], include_players, exclude_players)
|
|
629
|
+
|
|
624
630
|
# --- vectorised influence computation ---
|
|
625
631
|
Zh = self._batch_influence_pt(
|
|
626
|
-
|
|
632
|
+
home_dict, locs_t, tp, half, device, dtype
|
|
627
633
|
)
|
|
628
634
|
Za = self._batch_influence_pt(
|
|
629
|
-
|
|
635
|
+
away_dict, locs_t, tp, half, device, dtype
|
|
630
636
|
)
|
|
631
637
|
|
|
632
638
|
# Return raw team influence or combined logistic transformation
|
|
@@ -640,6 +646,14 @@ class PitchControl:
|
|
|
640
646
|
# Convert to numpy for downstream plotting
|
|
641
647
|
return res_t.cpu().numpy(), xx_t.cpu().numpy(), yy_t.cpu().numpy()
|
|
642
648
|
|
|
649
|
+
@staticmethod
|
|
650
|
+
def _filter_players(player_dict, include_players=None, exclude_players=None):
|
|
651
|
+
if include_players is not None:
|
|
652
|
+
player_dict = {k: v for k, v in player_dict.items() if k in include_players}
|
|
653
|
+
if exclude_players is not None:
|
|
654
|
+
player_dict = {k: v for k, v in player_dict.items() if k not in exclude_players}
|
|
655
|
+
return player_dict
|
|
656
|
+
|
|
643
657
|
def _fit_np(
|
|
644
658
|
self,
|
|
645
659
|
half: int,
|
|
@@ -647,6 +661,8 @@ class PitchControl:
|
|
|
647
661
|
dt: int,
|
|
648
662
|
verbose: bool = False,
|
|
649
663
|
return_teams: str = "both",
|
|
664
|
+
include_players: list | None = None,
|
|
665
|
+
exclude_players: list | None = None,
|
|
650
666
|
) -> tuple:
|
|
651
667
|
x = np.linspace(0, 105, dt)
|
|
652
668
|
y = np.linspace(0, 68, dt)
|
|
@@ -657,11 +673,12 @@ class PitchControl:
|
|
|
657
673
|
|
|
658
674
|
locations = np.c_[xx.flatten(), yy.flatten()]
|
|
659
675
|
|
|
660
|
-
|
|
661
|
-
|
|
676
|
+
home_dict = self._filter_players(self.locs_home[half], include_players, exclude_players)
|
|
677
|
+
away_dict = self._filter_players(self.locs_away[half], include_players, exclude_players)
|
|
678
|
+
|
|
679
|
+
for k in home_dict.keys():
|
|
662
680
|
Zh += self.influence_np(k, locations, tp, "h", half, verbose)
|
|
663
|
-
for k in
|
|
664
|
-
# if len(self.locs_away[half][k]) >= tp:
|
|
681
|
+
for k in away_dict.keys():
|
|
665
682
|
Za += self.influence_np(k, locations, tp, "a", half, verbose)
|
|
666
683
|
|
|
667
684
|
Zh = Zh.reshape((dt, dt))
|
|
@@ -687,11 +704,24 @@ class PitchControl:
|
|
|
687
704
|
verbose: bool = False,
|
|
688
705
|
use_fp16: bool = True,
|
|
689
706
|
return_teams: str = "both",
|
|
707
|
+
include_players: list | None = None,
|
|
708
|
+
exclude_players: list | None = None,
|
|
690
709
|
):
|
|
691
|
-
"""Selects NumPy or PyTorch backend depending on `backend`.
|
|
710
|
+
"""Selects NumPy or PyTorch backend depending on `backend`.
|
|
711
|
+
|
|
712
|
+
Parameters
|
|
713
|
+
----------
|
|
714
|
+
include_players : list of player_ids, optional
|
|
715
|
+
If provided, only these players are included in the computation.
|
|
716
|
+
exclude_players : list of player_ids, optional
|
|
717
|
+
If provided, these players are excluded from the computation.
|
|
718
|
+
"""
|
|
692
719
|
match backend:
|
|
693
720
|
case "np" | "numpy":
|
|
694
|
-
return self._fit_np(
|
|
721
|
+
return self._fit_np(
|
|
722
|
+
half, tp, dt, verbose, return_teams,
|
|
723
|
+
include_players, exclude_players,
|
|
724
|
+
)
|
|
695
725
|
case "torch" | "pt":
|
|
696
726
|
return self._fit_pt(
|
|
697
727
|
half,
|
|
@@ -701,6 +731,8 @@ class PitchControl:
|
|
|
701
731
|
verbose=verbose,
|
|
702
732
|
use_fp16=use_fp16,
|
|
703
733
|
return_teams=return_teams,
|
|
734
|
+
include_players=include_players,
|
|
735
|
+
exclude_players=exclude_players,
|
|
704
736
|
)
|
|
705
737
|
case _:
|
|
706
738
|
raise ValueError(f"Unknown backend '{backend}'. Use 'np' or 'torch'.")
|
rustat_python_api/processing.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
|
-
import pandas as pd
|
|
2
1
|
import numpy as np
|
|
2
|
+
import pandas as pd
|
|
3
3
|
|
|
4
4
|
from .config import columns, id2type
|
|
5
5
|
|
|
@@ -13,6 +13,7 @@ def process_list(x: pd.Series):
|
|
|
13
13
|
else:
|
|
14
14
|
return lst
|
|
15
15
|
|
|
16
|
+
|
|
16
17
|
def take_last(x: pd.Series):
|
|
17
18
|
lst = x.dropna().tolist()
|
|
18
19
|
if lst:
|
|
@@ -21,45 +22,49 @@ def take_last(x: pd.Series):
|
|
|
21
22
|
|
|
22
23
|
|
|
23
24
|
def gluing(df: pd.DataFrame) -> pd.DataFrame:
|
|
24
|
-
cols = [
|
|
25
|
+
cols = ["player_id", "half", "second", "pos_x", "pos_y"]
|
|
25
26
|
|
|
26
27
|
agg_rules = {}
|
|
27
28
|
|
|
28
29
|
for col_name in df.columns:
|
|
29
30
|
if col_name not in cols:
|
|
30
|
-
if col_name in [
|
|
31
|
+
if col_name in ["action_name", "action_id"]:
|
|
31
32
|
agg_rules[col_name] = process_list
|
|
32
33
|
else:
|
|
33
34
|
agg_rules[col_name] = take_last
|
|
34
35
|
|
|
35
36
|
df_gb = df.groupby(cols).agg(agg_rules).reset_index()
|
|
36
37
|
|
|
37
|
-
df_gb[
|
|
38
|
-
|
|
38
|
+
df_gb["pos_dest_nan"] = (
|
|
39
|
+
df_gb["pos_dest_x"].isna() & df_gb["pos_dest_y"].isna()
|
|
40
|
+
).astype(int)
|
|
41
|
+
df_gb = df_gb.sort_values(
|
|
42
|
+
by=["half", "second", "possession_number", "pos_dest_nan"]
|
|
43
|
+
).reset_index(drop=True)
|
|
39
44
|
return df_gb
|
|
40
45
|
|
|
41
46
|
|
|
42
47
|
def add_reciever(glued_df: pd.DataFrame) -> pd.DataFrame:
|
|
43
48
|
df = glued_df.copy()
|
|
44
|
-
df[
|
|
45
|
-
df[
|
|
49
|
+
df["receiver_id"] = df["player_id"].shift(-1)
|
|
50
|
+
df["receiver_name"] = df["player_name"].shift(-1)
|
|
46
51
|
|
|
47
52
|
mask = (
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
53
|
+
(df["action_name"] == "Ball receiving")
|
|
54
|
+
& (df["pos_x"] == df["pos_dest_x"].shift(1))
|
|
55
|
+
& (df["pos_y"] == df["pos_dest_y"].shift(1))
|
|
56
|
+
& (df["team_id"] == df["team_id"].shift(1))
|
|
57
|
+
& (df["player_id"] != df["player_id"].shift(1))
|
|
58
|
+
& (df["possession_number"] == df["possession_number"].shift(1))
|
|
54
59
|
)
|
|
55
60
|
|
|
56
61
|
idx = df[mask].index
|
|
57
|
-
remaining_idx = df.drop(idx-1).index
|
|
62
|
+
remaining_idx = df.drop(idx - 1).index
|
|
58
63
|
|
|
59
|
-
df.loc[remaining_idx,
|
|
60
|
-
df.loc[remaining_idx,
|
|
64
|
+
df.loc[remaining_idx, "receiver_id"] = np.nan
|
|
65
|
+
df.loc[remaining_idx, "receiver_name"] = np.nan
|
|
61
66
|
|
|
62
|
-
df = df[df[
|
|
67
|
+
df = df[df["action_name"] != "Ball receiving"].reset_index(drop=True)
|
|
63
68
|
|
|
64
69
|
return df
|
|
65
70
|
|
|
@@ -69,20 +74,23 @@ def filter_data(df: pd.DataFrame) -> pd.DataFrame:
|
|
|
69
74
|
if column not in df.columns:
|
|
70
75
|
df[column] = np.nan
|
|
71
76
|
|
|
72
|
-
return df[(~df[
|
|
77
|
+
return df[(~df["possession_number"].isna()) | (df["second"] != 0)][
|
|
78
|
+
columns
|
|
79
|
+
].reset_index(drop=True)
|
|
73
80
|
|
|
74
81
|
|
|
75
82
|
def tagging(df: pd.DataFrame) -> pd.DataFrame:
|
|
76
|
-
df = df.rename(columns={
|
|
77
|
-
df[
|
|
78
|
-
df[
|
|
79
|
-
lambda x:
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
+
df = df.rename(columns={"action_name": "sub_tags", "action_id": "sub_tags_ids"})
|
|
84
|
+
df["sub_tags"] = df["sub_tags"].apply(lambda x: x if isinstance(x, list) else [x])
|
|
85
|
+
df["sub_tags_ids"] = df["sub_tags_ids"].apply(
|
|
86
|
+
lambda x: (
|
|
87
|
+
list(set([int(t) // 1000 for t in x]))
|
|
88
|
+
if isinstance(x, list)
|
|
89
|
+
else [int(x) // 1000]
|
|
90
|
+
)
|
|
83
91
|
)
|
|
84
|
-
df[
|
|
85
|
-
df = df.rename(columns={
|
|
92
|
+
df["sub_tags_ids"] = df["sub_tags_ids"].apply(lambda x: [id2type[t] for t in x])
|
|
93
|
+
df = df.rename(columns={"sub_tags_ids": "tags"})
|
|
86
94
|
|
|
87
95
|
return df
|
|
88
96
|
|
|
@@ -93,4 +101,6 @@ def processing(df: pd.DataFrame) -> pd.DataFrame:
|
|
|
93
101
|
df = filter_data(df)
|
|
94
102
|
df = tagging(df)
|
|
95
103
|
|
|
104
|
+
df["event_chron_id"] = range(1, len(df) + 1)
|
|
105
|
+
|
|
96
106
|
return df
|
|
@@ -1,9 +1,9 @@
|
|
|
1
1
|
rustat_python_api/__init__.py,sha256=Ij-PAm2y5ss_XAZhKTZus35cRPLzvXFyIswDa_Iq3rs,164
|
|
2
2
|
rustat_python_api/config.py,sha256=eMvi1p8Cfvnbp6Cd4bBOwgehVN7thKnaQV5uzWyGZXM,1844
|
|
3
3
|
rustat_python_api/models_api.py,sha256=oHXEqeCupvZwjVEdoxf7W9LP7ELFKA8-9DuRXpQHLno,1701
|
|
4
|
-
rustat_python_api/parser.py,sha256=
|
|
5
|
-
rustat_python_api/pitch_control.py,sha256=
|
|
6
|
-
rustat_python_api/processing.py,sha256=
|
|
4
|
+
rustat_python_api/parser.py,sha256=hMPZER6CwPk_dm8GkLulq6G0_jwOyRiZjbbDw3eO1Rc,10863
|
|
5
|
+
rustat_python_api/pitch_control.py,sha256=XSakx2KxDAvv17qFoBtsBBIEYGn3OqI_3q3Xfq3_OoA,31688
|
|
6
|
+
rustat_python_api/processing.py,sha256=sjZdjSs2BKyM_f7T_Utu1CJSF1l-V8UsWrAj7dhqyNg,2930
|
|
7
7
|
rustat_python_api/urls.py,sha256=iJTD31T6OyXPAhmhViwFXVehrzwsOjBDONA1SIVc_40,1068
|
|
8
8
|
rustat_python_api/kernels/__init__.py,sha256=eFJ-BMY8VcNZSjf3XjOnZf_nfOQ5t-7Lp57DPCHYOo0,124
|
|
9
9
|
rustat_python_api/kernels/maha.py,sha256=k2PqY6VghgER2j9QH8xGYq61JLfPaHjirLXb4aLnjQw,2591
|
|
@@ -11,8 +11,8 @@ rustat_python_api/matching/__init__.py,sha256=QIqkft_NVban8p0rxJsim3x4defZN3eQBq
|
|
|
11
11
|
rustat_python_api/matching/dataloader.py,sha256=FZ95ZC5Z-6FrZEVy-d5S1rXt5nIBwYMqTT0hvFlc4mI,2753
|
|
12
12
|
rustat_python_api/matching/pc_adder.py,sha256=oP84iplIqmqF1jFDN9ernkjKw7wgsmWbuOVPEsoTNWM,7582
|
|
13
13
|
rustat_python_api/matching/tr_adder.py,sha256=0sAjYfoLitScDhg2yjh2PIKWTseoaSeUa5chPnVau-Q,16025
|
|
14
|
-
rustat_python_api-0.
|
|
15
|
-
rustat_python_api-0.
|
|
16
|
-
rustat_python_api-0.
|
|
17
|
-
rustat_python_api-0.
|
|
18
|
-
rustat_python_api-0.
|
|
14
|
+
rustat_python_api-0.9.0.dist-info/LICENSE,sha256=4Cohqg5p6Mq1xyrzdEX8AvFSA62GSVvapEOr2xK_tgY,57
|
|
15
|
+
rustat_python_api-0.9.0.dist-info/METADATA,sha256=9_I5OFJ6m9tWDX4g5elAfjNgXpAIbLyJ6kdrVHEm9aE,1920
|
|
16
|
+
rustat_python_api-0.9.0.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
|
|
17
|
+
rustat_python_api-0.9.0.dist-info/top_level.txt,sha256=VK0hmkKZE9YThxolUcoE6JtGI67NFeKJMBLuet8kI4w,18
|
|
18
|
+
rustat_python_api-0.9.0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|