rustat-python-api 0.6.0__tar.gz → 0.6.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {rustat-python-api-0.6.0/rustat_python_api.egg-info → rustat-python-api-0.6.2}/PKG-INFO +1 -1
- rustat-python-api-0.6.2/rustat_python_api/kernels/__init__.py +3 -0
- rustat-python-api-0.6.2/rustat_python_api/kernels/maha.py +72 -0
- {rustat-python-api-0.6.0 → rustat-python-api-0.6.2}/rustat_python_api/pitch_control.py +36 -16
- {rustat-python-api-0.6.0 → rustat-python-api-0.6.2/rustat_python_api.egg-info}/PKG-INFO +1 -1
- {rustat-python-api-0.6.0 → rustat-python-api-0.6.2}/rustat_python_api.egg-info/SOURCES.txt +3 -1
- {rustat-python-api-0.6.0 → rustat-python-api-0.6.2}/rustat_python_api.egg-info/requires.txt +1 -0
- {rustat-python-api-0.6.0 → rustat-python-api-0.6.2}/setup.py +3 -2
- {rustat-python-api-0.6.0 → rustat-python-api-0.6.2}/LICENSE +0 -0
- {rustat-python-api-0.6.0 → rustat-python-api-0.6.2}/README.md +0 -0
- {rustat-python-api-0.6.0 → rustat-python-api-0.6.2}/pyproject.toml +0 -0
- {rustat-python-api-0.6.0 → rustat-python-api-0.6.2}/rustat_python_api/__init__.py +0 -0
- {rustat-python-api-0.6.0 → rustat-python-api-0.6.2}/rustat_python_api/config.py +0 -0
- {rustat-python-api-0.6.0 → rustat-python-api-0.6.2}/rustat_python_api/models_api.py +0 -0
- {rustat-python-api-0.6.0 → rustat-python-api-0.6.2}/rustat_python_api/parser.py +0 -0
- {rustat-python-api-0.6.0 → rustat-python-api-0.6.2}/rustat_python_api/processing.py +0 -0
- {rustat-python-api-0.6.0 → rustat-python-api-0.6.2}/rustat_python_api/urls.py +0 -0
- {rustat-python-api-0.6.0 → rustat-python-api-0.6.2}/rustat_python_api.egg-info/dependency_links.txt +0 -0
- {rustat-python-api-0.6.0 → rustat-python-api-0.6.2}/rustat_python_api.egg-info/top_level.txt +0 -0
- {rustat-python-api-0.6.0 → rustat-python-api-0.6.2}/setup.cfg +0 -0
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import triton
|
|
3
|
+
import triton.language as tl
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
@triton.jit
def influence_kernel(
    MU,  # (F, P, 2) player centres, row-major contiguous
    SIGMA_INV,  # (F, P, 2, 2) inverse covariance per player, row-major contiguous
    LOCS,  # (N, 2) grid locations, read-only, row-major contiguous
    OUT,  # (F, N) output
    F, P, N,  # sizes
    BLOCK_N: tl.constexpr,  # number of grid points handled per program
    MAX_P: tl.constexpr = 32,  # compile-time upper bound on P (the player loop is unrolled)
):
    """Sum per-player Gaussian influences for one frame and one block of grid points.

    Program ``(f, b)`` writes, for every grid point ``n`` in
    ``[b * BLOCK_N, (b + 1) * BLOCK_N)``::

        OUT[f, n] = sum_{p < P} exp(-0.5 * d^T SIGMA_INV[f, p] d),  d = LOCS[n] - MU[f, p]

    Non-finite quadratic forms are replaced by 1e9 (contributing ~0), mirroring
    the ``torch.nan_to_num(..., nan=1e9, posinf=1e9, neginf=1e9)`` handling of
    the PyTorch fallback path. Only players ``p < MAX_P`` are visited, so the
    caller must pass ``MAX_P >= P``.
    """
    pid_f = tl.program_id(0)  # frame index
    offs_n = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)  # grid-point indices
    mask_n = offs_n < N

    # Grid coordinates do not depend on the player: load them once, outside
    # the player loop. Element (n, k) of the row-major (N, 2) LOCS is at n*2 + k.
    x0 = tl.load(LOCS + offs_n * 2 + 0, mask=mask_n, other=0.0)
    x1 = tl.load(LOCS + offs_n * 2 + 1, mask=mask_n, other=0.0)

    # Accumulator for the summed influence of all players.
    acc = tl.zeros([BLOCK_N], dtype=tl.float32)

    for p in tl.static_range(0, MAX_P):
        is_valid = p < P  # iterations past the real player count are masked out

        # Row-major offsets: MU[f, p, k] -> (f*P + p)*2 + k,
        # SIGMA_INV[f, p, i, j] -> (f*P + p)*4 + i*2 + j.
        mu_ptr = MU + pid_f * P * 2 + p * 2
        s_ptr = SIGMA_INV + pid_f * P * 4 + p * 4

        mu0 = tl.load(mu_ptr + 0, mask=is_valid, other=0.0)
        mu1 = tl.load(mu_ptr + 1, mask=is_valid, other=0.0)

        s00 = tl.load(s_ptr + 0, mask=is_valid, other=0.0)
        s01 = tl.load(s_ptr + 1, mask=is_valid, other=0.0)
        s10 = tl.load(s_ptr + 2, mask=is_valid, other=0.0)
        s11 = tl.load(s_ptr + 3, mask=is_valid, other=0.0)

        diff0 = x0 - mu0
        diff1 = x1 - mu1

        # d^T S d for the 2x2 matrix S, written out explicitly.
        tmp0 = diff0 * s00 + diff1 * s01
        tmp1 = diff0 * s10 + diff1 * s11
        maha = diff0 * tmp0 + diff1 * tmp1

        # NaN / +-inf -> 1e9 so such players contribute ~0, as in the CPU path.
        finite = (maha == maha) & (tl.abs(maha) < float("inf"))
        maha = tl.where(finite, maha, 1e9)

        acc += tl.where(is_valid, tl.exp(-0.5 * maha), 0.0)

    # Write the summed influence for this frame / block of grid points.
    out_ptr = OUT + pid_f * N + offs_n
    tl.store(out_ptr, acc, mask=mask_n)


def triton_influence(mu, sigma_inv, locs, BLOCK_N=64):
    """Compute the summed Gaussian player influence on every grid point, per frame.

    Args:
        mu: (F, P, 2) CUDA tensor of player centres.
        sigma_inv: (F, P, 2, 2) CUDA tensor of per-player inverse covariances.
        locs: (N, 2) CUDA tensor of grid locations.
        BLOCK_N: grid points processed per Triton program (must be a power of 2,
            as required by ``tl.arange``).

    Returns:
        (F, N) tensor on ``mu.device`` with ``mu.dtype``; entry ``[f, n]`` is
        ``sum_p exp(-0.5 * d^T sigma_inv[f, p] d)`` with ``d = locs[n] - mu[f, p]``.

    Raises:
        ValueError: if the input shapes are inconsistent.
    """
    if mu.dim() != 3 or mu.shape[-1] != 2:
        raise ValueError(f"mu must have shape (F, P, 2), got {tuple(mu.shape)}")
    F, P, _ = mu.shape
    if tuple(sigma_inv.shape) != (F, P, 2, 2):
        raise ValueError(
            f"sigma_inv must have shape {(F, P, 2, 2)}, got {tuple(sigma_inv.shape)}"
        )
    if locs.dim() != 2 or locs.shape[-1] != 2:
        raise ValueError(f"locs must have shape (N, 2), got {tuple(locs.shape)}")
    N = locs.shape[0]

    out = torch.empty((F, N), device=mu.device, dtype=mu.dtype)
    if F == 0 or N == 0:
        # Nothing to compute; avoid launching a kernel with an empty grid.
        return out

    # The kernel uses hard-coded row-major strides, so make layouts explicit
    # (no-op for tensors that are already contiguous).
    mu = mu.contiguous()
    sigma_inv = sigma_inv.contiguous()
    locs = locs.contiguous()

    grid = (F, triton.cdiv(N, BLOCK_N))
    influence_kernel[grid](
        mu, sigma_inv, locs, out,
        F, P, N,
        BLOCK_N=BLOCK_N,
        # Unroll bound must cover every player; rounding to a power of two
        # limits the number of distinct compiled variants.
        MAX_P=triton.next_power_of_2(max(P, 1)),
        num_warps=4,
    )
    return out
|
|
@@ -7,6 +7,8 @@ import matplotsoccer as mpl
|
|
|
7
7
|
import torch
|
|
8
8
|
from tqdm import tqdm
|
|
9
9
|
|
|
10
|
+
from .kernels import triton_influence
|
|
11
|
+
|
|
10
12
|
|
|
11
13
|
class PitchControl:
|
|
12
14
|
def __init__(self, tracking: pd.DataFrame, events: pd.DataFrame, ball_data: pd.DataFrame = None):
|
|
@@ -320,15 +322,19 @@ class PitchControl:
|
|
|
320
322
|
diff = locs.view(1, -1, 2) # (1,N,2)
|
|
321
323
|
diff = diff - mu.unsqueeze(1) # (P,N,2)
|
|
322
324
|
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
325
|
+
if device == "cuda":
|
|
326
|
+
out = triton_influence(
|
|
327
|
+
mu.unsqueeze(0), Sigma_inv.unsqueeze(0),
|
|
328
|
+
locs, BLOCK_N=64
|
|
329
|
+
)[0] # (N,)
|
|
328
330
|
|
|
329
|
-
|
|
331
|
+
return out
|
|
332
|
+
else:
|
|
333
|
+
maha = torch.einsum('pni,pij,pnj->pn', diff, Sigma_inv, diff) # (P,N)
|
|
334
|
+
maha = torch.nan_to_num(maha, nan=1e9, posinf=1e9, neginf=1e9)
|
|
335
|
+
out = torch.exp(-0.5 * maha) # (P,N)
|
|
330
336
|
|
|
331
|
-
|
|
337
|
+
return out.sum(dim=0) # sum over players
|
|
332
338
|
|
|
333
339
|
def _batch_team_influence_frames_pt(
|
|
334
340
|
self,
|
|
@@ -385,24 +391,38 @@ class PitchControl:
|
|
|
385
391
|
diff = locs.view(1, 1, -1, 2) # (1,1,N,2)
|
|
386
392
|
diff = diff - mu.unsqueeze(2) # (F,P,N,2)
|
|
387
393
|
|
|
388
|
-
|
|
394
|
+
if device == "cuda":
|
|
395
|
+
out = triton_influence(mu, Sigma_inv, locs, BLOCK_N=64) # (F,N)
|
|
389
396
|
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
397
|
+
return out
|
|
398
|
+
else:
|
|
399
|
+
maha = torch.einsum('fpni,fpij,fpnj->fpn', diff, Sigma_inv, diff) # (F,P,N)
|
|
400
|
+
maha = torch.nan_to_num(maha, nan=1e9, posinf=1e9, neginf=1e9)
|
|
401
|
+
out = torch.exp(-0.5 * maha) # (F,P,N)
|
|
395
402
|
|
|
396
|
-
|
|
403
|
+
return out.sum(dim=1) # sum over players
|
|
397
404
|
|
|
398
405
|
@staticmethod
|
|
399
406
|
def _stack_team_frames(players: list[np.ndarray], frames: np.ndarray, device: str, dtype: torch.dtype):
|
|
400
407
|
"""Stack positions for given frames into torch tensors (pos_t, pos_tp1)."""
|
|
408
|
+
# Ensure every player's trajectory is long enough; if not, pad by repeating
|
|
409
|
+
# the last available coordinate so that indexing `frames` and `frames+1` is safe.
|
|
410
|
+
max_needed = frames[-1] + 1 # we access idx and idx+1
|
|
411
|
+
|
|
412
|
+
padded = []
|
|
413
|
+
for p in players:
|
|
414
|
+
if len(p) <= max_needed:
|
|
415
|
+
pad_len = max_needed + 1 - len(p)
|
|
416
|
+
if pad_len > 0:
|
|
417
|
+
last = p[-1][None, :]
|
|
418
|
+
p = np.vstack([p, np.repeat(last, pad_len, axis=0)])
|
|
419
|
+
padded.append(p)
|
|
420
|
+
|
|
401
421
|
pos_t = torch.tensor(
|
|
402
|
-
np.stack([p[frames] for p in
|
|
422
|
+
np.stack([p[frames] for p in padded], axis=1), device=device, dtype=dtype
|
|
403
423
|
) # (F,P,2)
|
|
404
424
|
pos_tp1 = torch.tensor(
|
|
405
|
-
np.stack([p[frames + 1] for p in
|
|
425
|
+
np.stack([p[frames + 1] for p in padded], axis=1), device=device, dtype=dtype
|
|
406
426
|
)
|
|
407
427
|
return pos_t, pos_tp1
|
|
408
428
|
|
|
@@ -13,4 +13,6 @@ rustat_python_api.egg-info/PKG-INFO
|
|
|
13
13
|
rustat_python_api.egg-info/SOURCES.txt
|
|
14
14
|
rustat_python_api.egg-info/dependency_links.txt
|
|
15
15
|
rustat_python_api.egg-info/requires.txt
|
|
16
|
-
rustat_python_api.egg-info/top_level.txt
|
|
16
|
+
rustat_python_api.egg-info/top_level.txt
|
|
17
|
+
rustat_python_api/kernels/__init__.py
|
|
18
|
+
rustat_python_api/kernels/maha.py
|
|
@@ -2,7 +2,7 @@ from setuptools import setup, find_packages
|
|
|
2
2
|
|
|
3
3
|
setup(
|
|
4
4
|
name='rustat-python-api',
|
|
5
|
-
version='0.6.
|
|
5
|
+
version='0.6.2',
|
|
6
6
|
description='A Python wrapper for RuStat API',
|
|
7
7
|
long_description=open('README.md').read(),
|
|
8
8
|
long_description_content_type='text/markdown',
|
|
@@ -18,7 +18,8 @@ setup(
|
|
|
18
18
|
'scipy==1.14.1',
|
|
19
19
|
'matplotlib',
|
|
20
20
|
'matplotsoccer',
|
|
21
|
-
'torch'
|
|
21
|
+
'torch',
|
|
22
|
+
'triton==3.0.0',
|
|
22
23
|
],
|
|
23
24
|
classifiers=[
|
|
24
25
|
'Programming Language :: Python :: 3',
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{rustat-python-api-0.6.0 → rustat-python-api-0.6.2}/rustat_python_api.egg-info/dependency_links.txt
RENAMED
|
File without changes
|
{rustat-python-api-0.6.0 → rustat-python-api-0.6.2}/rustat_python_api.egg-info/top_level.txt
RENAMED
|
File without changes
|
|
File without changes
|