rustat-python-api 0.6.1__tar.gz → 0.6.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {rustat-python-api-0.6.1/rustat_python_api.egg-info → rustat-python-api-0.6.2}/PKG-INFO +1 -1
- rustat-python-api-0.6.2/rustat_python_api/kernels/__init__.py +3 -0
- rustat-python-api-0.6.2/rustat_python_api/kernels/maha.py +72 -0
- {rustat-python-api-0.6.1 → rustat-python-api-0.6.2}/rustat_python_api/pitch_control.py +21 -14
- {rustat-python-api-0.6.1 → rustat-python-api-0.6.2/rustat_python_api.egg-info}/PKG-INFO +1 -1
- {rustat-python-api-0.6.1 → rustat-python-api-0.6.2}/rustat_python_api.egg-info/SOURCES.txt +3 -1
- {rustat-python-api-0.6.1 → rustat-python-api-0.6.2}/rustat_python_api.egg-info/requires.txt +1 -0
- {rustat-python-api-0.6.1 → rustat-python-api-0.6.2}/setup.py +3 -2
- {rustat-python-api-0.6.1 → rustat-python-api-0.6.2}/LICENSE +0 -0
- {rustat-python-api-0.6.1 → rustat-python-api-0.6.2}/README.md +0 -0
- {rustat-python-api-0.6.1 → rustat-python-api-0.6.2}/pyproject.toml +0 -0
- {rustat-python-api-0.6.1 → rustat-python-api-0.6.2}/rustat_python_api/__init__.py +0 -0
- {rustat-python-api-0.6.1 → rustat-python-api-0.6.2}/rustat_python_api/config.py +0 -0
- {rustat-python-api-0.6.1 → rustat-python-api-0.6.2}/rustat_python_api/models_api.py +0 -0
- {rustat-python-api-0.6.1 → rustat-python-api-0.6.2}/rustat_python_api/parser.py +0 -0
- {rustat-python-api-0.6.1 → rustat-python-api-0.6.2}/rustat_python_api/processing.py +0 -0
- {rustat-python-api-0.6.1 → rustat-python-api-0.6.2}/rustat_python_api/urls.py +0 -0
- {rustat-python-api-0.6.1 → rustat-python-api-0.6.2}/rustat_python_api.egg-info/dependency_links.txt +0 -0
- {rustat-python-api-0.6.1 → rustat-python-api-0.6.2}/rustat_python_api.egg-info/top_level.txt +0 -0
- {rustat-python-api-0.6.1 → rustat-python-api-0.6.2}/setup.cfg +0 -0
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import triton
|
|
3
|
+
import triton.language as tl
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
@triton.jit
def influence_kernel(
    MU,            # (F,P,2)   player means, dense row-major
    SIGMA_INV,     # (F,P,2,2) inverse covariances, dense row-major
    LOCS,          # (N,2)     grid coordinates, read-only
    OUT,           # (F,N)     summed influence per frame and grid point
    F, P, N,       # sizes
    BLOCK_N: tl.constexpr,  # number of grid points handled per program
):
    """Sum exp(-0.5 * Mahalanobis distance) over players for one frame/chunk.

    Program axis 0 selects the frame and program axis 1 selects a chunk of
    BLOCK_N grid points. All pointer arithmetic below hard-codes dense
    row-major strides, so the caller must pass contiguous buffers.

    The player loop is unrolled over a fixed 32 slots; slots with index >= P
    are masked out, which means at most 32 players per frame are accumulated.
    """
    pid_f = tl.program_id(0)                                     # frame index
    offs_n = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)  # grid-point chunk
    mask_n = offs_n < N

    # Grid coordinates do not depend on the player, so load them once
    # instead of once per unrolled iteration.
    x0 = tl.load(LOCS + offs_n * 2 + 0, mask=mask_n)   # (BLOCK_N,)
    x1 = tl.load(LOCS + offs_n * 2 + 1, mask=mask_n)

    # accumulator for the per-player influences
    acc = tl.zeros([BLOCK_N], dtype=tl.float32)

    for p in tl.static_range(0, 32):
        is_valid = p < P

        # --- pointers ---
        mu_ptr = MU + pid_f * P * 2 + p * 2
        s_ptr = SIGMA_INV + pid_f * P * 4 + p * 4

        # --- scalar loads ---
        mu0 = tl.load(mu_ptr + 0, mask=is_valid)
        mu1 = tl.load(mu_ptr + 1, mask=is_valid)

        s00 = tl.load(s_ptr + 0, mask=is_valid)
        s01 = tl.load(s_ptr + 1, mask=is_valid)
        s10 = tl.load(s_ptr + 2, mask=is_valid)
        s11 = tl.load(s_ptr + 3, mask=is_valid)

        diff0 = x0 - mu0
        diff1 = x1 - mu1

        # maha = diff^T @ Sigma_inv @ diff, written out for the 2x2 case
        tmp0 = diff0 * s00 + diff1 * s01
        tmp1 = diff0 * s10 + diff1 * s11
        maha = diff0 * tmp0 + diff1 * tmp1

        # The non-Triton path replaces NaN/+inf/-inf distances with 1e9, whose
        # exp(-0.5 * 1e9) is 0. Mirror that by dropping non-finite terms; the
        # ordered comparison is false for NaN as well as for +/-inf.
        is_finite = tl.abs(maha) < float("inf")
        contrib = tl.where(is_finite, tl.exp(-0.5 * maha), 0.0)

        acc += tl.where(is_valid, contrib, 0.0)

    # write the result
    out_ptr = OUT + pid_f * N + offs_n
    tl.store(out_ptr, acc, mask=mask_n)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def triton_influence(mu, sigma_inv, locs, BLOCK_N=64):
    """Launch ``influence_kernel`` and return the summed influence.

    Args:
        mu: tensor whose shape unpacks as (F, P, 2) - player means.
        sigma_inv: inverse covariances; the kernel indexes it as (F, P, 2, 2).
        locs: grid coordinates; the kernel indexes it as (N, 2).
        BLOCK_N: number of grid points processed per kernel program.

    Returns:
        Tensor of shape (F, N) on ``mu.device`` with ``mu.dtype``.

    Raises:
        ValueError: if P exceeds the number of player slots the kernel
            unrolls, because the extra players would be silently ignored.
    """
    # Must match the unroll bound of the player loop in influence_kernel.
    max_players = 32

    F, P, _ = mu.shape
    N = locs.shape[0]
    if P > max_players:
        raise ValueError(
            f"triton_influence supports at most {max_players} players per "
            f"frame, got {P}"
        )

    out = torch.empty((F, N), device=mu.device, dtype=mu.dtype)
    if F == 0 or N == 0:
        # Nothing to compute; a launch grid with a zero dimension is invalid.
        return out

    # The kernel hard-codes dense row-major strides, so views with other
    # strides must be materialised first (a no-op when already contiguous).
    mu = mu.contiguous()
    sigma_inv = sigma_inv.contiguous()
    locs = locs.contiguous()

    grid = (F, triton.cdiv(N, BLOCK_N))
    influence_kernel[grid](
        mu, sigma_inv, locs, out,
        F, P, N,
        BLOCK_N=BLOCK_N,
        num_warps=4,
    )
    return out
|
|
@@ -7,6 +7,8 @@ import matplotsoccer as mpl
|
|
|
7
7
|
import torch
|
|
8
8
|
from tqdm import tqdm
|
|
9
9
|
|
|
10
|
+
from .kernels import triton_influence
|
|
11
|
+
|
|
10
12
|
|
|
11
13
|
class PitchControl:
|
|
12
14
|
def __init__(self, tracking: pd.DataFrame, events: pd.DataFrame, ball_data: pd.DataFrame = None):
|
|
@@ -320,15 +322,19 @@ class PitchControl:
|
|
|
320
322
|
diff = locs.view(1, -1, 2) # (1,N,2)
|
|
321
323
|
diff = diff - mu.unsqueeze(1) # (P,N,2)
|
|
322
324
|
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
325
|
+
if device == "cuda":
|
|
326
|
+
out = triton_influence(
|
|
327
|
+
mu.unsqueeze(0), Sigma_inv.unsqueeze(0),
|
|
328
|
+
locs, BLOCK_N=64
|
|
329
|
+
)[0] # (N,)
|
|
328
330
|
|
|
329
|
-
|
|
331
|
+
return out
|
|
332
|
+
else:
|
|
333
|
+
maha = torch.einsum('pni,pij,pnj->pn', diff, Sigma_inv, diff) # (P,N)
|
|
334
|
+
maha = torch.nan_to_num(maha, nan=1e9, posinf=1e9, neginf=1e9)
|
|
335
|
+
out = torch.exp(-0.5 * maha) # (P,N)
|
|
330
336
|
|
|
331
|
-
|
|
337
|
+
return out.sum(dim=0) # sum over players
|
|
332
338
|
|
|
333
339
|
def _batch_team_influence_frames_pt(
|
|
334
340
|
self,
|
|
@@ -385,15 +391,16 @@ class PitchControl:
|
|
|
385
391
|
diff = locs.view(1, 1, -1, 2) # (1,1,N,2)
|
|
386
392
|
diff = diff - mu.unsqueeze(2) # (F,P,N,2)
|
|
387
393
|
|
|
388
|
-
|
|
394
|
+
if device == "cuda":
|
|
395
|
+
out = triton_influence(mu, Sigma_inv, locs, BLOCK_N=64) # (F,N)
|
|
389
396
|
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
397
|
+
return out
|
|
398
|
+
else:
|
|
399
|
+
maha = torch.einsum('fpni,fpij,fpnj->fpn', diff, Sigma_inv, diff) # (F,P,N)
|
|
400
|
+
maha = torch.nan_to_num(maha, nan=1e9, posinf=1e9, neginf=1e9)
|
|
401
|
+
out = torch.exp(-0.5 * maha) # (F,P,N)
|
|
395
402
|
|
|
396
|
-
|
|
403
|
+
return out.sum(dim=1) # sum over players
|
|
397
404
|
|
|
398
405
|
@staticmethod
|
|
399
406
|
def _stack_team_frames(players: list[np.ndarray], frames: np.ndarray, device: str, dtype: torch.dtype):
|
|
@@ -13,4 +13,6 @@ rustat_python_api.egg-info/PKG-INFO
|
|
|
13
13
|
rustat_python_api.egg-info/SOURCES.txt
|
|
14
14
|
rustat_python_api.egg-info/dependency_links.txt
|
|
15
15
|
rustat_python_api.egg-info/requires.txt
|
|
16
|
-
rustat_python_api.egg-info/top_level.txt
|
|
16
|
+
rustat_python_api.egg-info/top_level.txt
|
|
17
|
+
rustat_python_api/kernels/__init__.py
|
|
18
|
+
rustat_python_api/kernels/maha.py
|
|
@@ -2,7 +2,7 @@ from setuptools import setup, find_packages
|
|
|
2
2
|
|
|
3
3
|
setup(
|
|
4
4
|
name='rustat-python-api',
|
|
5
|
-
version='0.6.
|
|
5
|
+
version='0.6.2',
|
|
6
6
|
description='A Python wrapper for RuStat API',
|
|
7
7
|
long_description=open('README.md').read(),
|
|
8
8
|
long_description_content_type='text/markdown',
|
|
@@ -18,7 +18,8 @@ setup(
|
|
|
18
18
|
'scipy==1.14.1',
|
|
19
19
|
'matplotlib',
|
|
20
20
|
'matplotsoccer',
|
|
21
|
-
'torch'
|
|
21
|
+
'torch',
|
|
22
|
+
'triton==3.0.0',
|
|
22
23
|
],
|
|
23
24
|
classifiers=[
|
|
24
25
|
'Programming Language :: Python :: 3',
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{rustat-python-api-0.6.1 → rustat-python-api-0.6.2}/rustat_python_api.egg-info/dependency_links.txt
RENAMED
|
File without changes
|
{rustat-python-api-0.6.1 → rustat-python-api-0.6.2}/rustat_python_api.egg-info/top_level.txt
RENAMED
|
File without changes
|
|
File without changes
|