rustat-python-api 0.6.1__tar.gz → 0.6.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (20)
  1. {rustat-python-api-0.6.1/rustat_python_api.egg-info → rustat-python-api-0.6.3}/PKG-INFO +1 -1
  2. rustat-python-api-0.6.3/rustat_python_api/kernels/__init__.py +3 -0
  3. rustat-python-api-0.6.3/rustat_python_api/kernels/maha.py +72 -0
  4. {rustat-python-api-0.6.1 → rustat-python-api-0.6.3}/rustat_python_api/pitch_control.py +21 -14
  5. {rustat-python-api-0.6.1 → rustat-python-api-0.6.3/rustat_python_api.egg-info}/PKG-INFO +1 -1
  6. {rustat-python-api-0.6.1 → rustat-python-api-0.6.3}/rustat_python_api.egg-info/SOURCES.txt +3 -1
  7. {rustat-python-api-0.6.1 → rustat-python-api-0.6.3}/rustat_python_api.egg-info/requires.txt +1 -0
  8. {rustat-python-api-0.6.1 → rustat-python-api-0.6.3}/setup.py +3 -2
  9. {rustat-python-api-0.6.1 → rustat-python-api-0.6.3}/LICENSE +0 -0
  10. {rustat-python-api-0.6.1 → rustat-python-api-0.6.3}/README.md +0 -0
  11. {rustat-python-api-0.6.1 → rustat-python-api-0.6.3}/pyproject.toml +0 -0
  12. {rustat-python-api-0.6.1 → rustat-python-api-0.6.3}/rustat_python_api/__init__.py +0 -0
  13. {rustat-python-api-0.6.1 → rustat-python-api-0.6.3}/rustat_python_api/config.py +0 -0
  14. {rustat-python-api-0.6.1 → rustat-python-api-0.6.3}/rustat_python_api/models_api.py +0 -0
  15. {rustat-python-api-0.6.1 → rustat-python-api-0.6.3}/rustat_python_api/parser.py +0 -0
  16. {rustat-python-api-0.6.1 → rustat-python-api-0.6.3}/rustat_python_api/processing.py +0 -0
  17. {rustat-python-api-0.6.1 → rustat-python-api-0.6.3}/rustat_python_api/urls.py +0 -0
  18. {rustat-python-api-0.6.1 → rustat-python-api-0.6.3}/rustat_python_api.egg-info/dependency_links.txt +0 -0
  19. {rustat-python-api-0.6.1 → rustat-python-api-0.6.3}/rustat_python_api.egg-info/top_level.txt +0 -0
  20. {rustat-python-api-0.6.1 → rustat-python-api-0.6.3}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: rustat-python-api
3
- Version: 0.6.1
3
+ Version: 0.6.3
4
4
  Summary: A Python wrapper for RuStat API
5
5
  Home-page: https://github.com/dailydaniel/rustat-python-api
6
6
  Author: Daniel Zholkovsky
@@ -0,0 +1,3 @@
1
+ from .maha import triton_influence
2
+
3
+ __all__ = ["triton_influence"]
@@ -0,0 +1,72 @@
1
+ import torch
2
+ import triton
3
+ import triton.language as tl
4
+
5
+
6
@triton.jit
def influence_kernel(
    MU,  # (F, P, 2) player centres per frame, dense row-major
    SIGMA_INV,  # (F, P, 2, 2) inverse covariance per player, dense row-major
    LOCS,  # (N, 2) grid coordinates, dense row-major, read-only
    OUT,  # (F, N) output: influence summed over players
    F, P, N,  # sizes (F is only used by the launcher to size the grid)
    BLOCK_N: tl.constexpr,  # number of grid points handled by one program
):
    """Sum unnormalised Gaussian influence of every player over a grid chunk.

    Program (pid_f, pid_n) handles frame ``pid_f`` and grid points
    ``pid_n * BLOCK_N .. pid_n * BLOCK_N + BLOCK_N - 1`` and writes

        OUT[f, n] = sum_p exp(-0.5 * d^T SIGMA_INV[f, p] d),  d = LOCS[n] - MU[f, p]

    All P players are visited; there is no fixed upper bound on P. A player
    whose Mahalanobis term is NaN or +/-inf contributes zero. This matches the
    torch fallback, which maps such values to 1e9 before exponentiating.
    Arithmetic is done in float32 regardless of the input dtype.
    """
    # Frame index. It is widened to int64 so that pid_f * N and pid_f * P * 4
    # cannot overflow int32 offsets on large frame batches.
    pid_f = tl.program_id(0).to(tl.int64)
    offs_n = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)  # chunk of grid points
    mask_n = offs_n < N

    # Grid coordinates do not depend on the player, so load them once.
    x0 = tl.load(LOCS + offs_n * 2 + 0, mask=mask_n, other=0.0).to(tl.float32)
    x1 = tl.load(LOCS + offs_n * 2 + 1, mask=mask_n, other=0.0).to(tl.float32)

    # Accumulator for the per-player influence sum.
    acc = tl.zeros([BLOCK_N], dtype=tl.float32)

    mu_base = MU + pid_f * P * 2
    s_base = SIGMA_INV + pid_f * P * 4
    # Runtime-bounded loop, so every player is included. The original
    # static_range(0, 32) silently dropped players with index >= 32.
    for p in range(0, P):
        mu0 = tl.load(mu_base + p * 2 + 0).to(tl.float32)
        mu1 = tl.load(mu_base + p * 2 + 1).to(tl.float32)

        s00 = tl.load(s_base + p * 4 + 0).to(tl.float32)
        s01 = tl.load(s_base + p * 4 + 1).to(tl.float32)
        s10 = tl.load(s_base + p * 4 + 2).to(tl.float32)
        s11 = tl.load(s_base + p * 4 + 3).to(tl.float32)

        diff0 = x0 - mu0
        diff1 = x1 - mu1

        # d^T S d, expanded for a 2x2 matrix.
        tmp0 = diff0 * s00 + diff1 * s01
        tmp1 = diff0 * s10 + diff1 * s11
        maha = diff0 * tmp0 + diff1 * tmp1

        # NaN (maha != maha) or +/-inf terms come from invalid player
        # positions. Drop them instead of letting them poison the whole sum.
        finite = (maha == maha) & (tl.abs(maha) < float("inf"))
        acc += tl.where(finite, tl.exp(-0.5 * maha), 0.0)

    # Write the result for this frame and chunk of grid points.
    tl.store(OUT + pid_f * N + offs_n, acc, mask=mask_n)
58
+
59
+
60
def triton_influence(mu, sigma_inv, locs, BLOCK_N=64):
    """Sum Gaussian player influence over a grid on the GPU with ``influence_kernel``.

    Args:
        mu: (F, P, 2) tensor of player centres per frame, on a CUDA device.
        sigma_inv: (F, P, 2, 2) inverse covariance matrices matching ``mu``.
        locs: (N, 2) grid coordinates.
        BLOCK_N: grid points per Triton program. It must be a positive power
            of two, because the kernel builds its offsets with ``tl.arange``.

    Returns:
        (F, N) tensor on ``mu.device`` with dtype ``mu.dtype``, holding the
        influence summed over players.

    Raises:
        ValueError: if a tensor has the wrong shape, ``mu`` is not on a CUDA
            device, or ``BLOCK_N`` is not a positive power of two.
    """
    if mu.dim() != 3 or mu.shape[-1] != 2:
        raise ValueError(f"mu must have shape (F, P, 2), got {tuple(mu.shape)}")
    F, P = mu.shape[0], mu.shape[1]
    if tuple(sigma_inv.shape) != (F, P, 2, 2):
        raise ValueError(
            f"sigma_inv must have shape {(F, P, 2, 2)}, got {tuple(sigma_inv.shape)}"
        )
    if locs.dim() != 2 or locs.shape[-1] != 2:
        raise ValueError(f"locs must have shape (N, 2), got {tuple(locs.shape)}")
    if not mu.is_cuda:
        raise ValueError(f"triton_influence requires CUDA tensors, got device {mu.device}")
    if BLOCK_N <= 0 or BLOCK_N & (BLOCK_N - 1):
        raise ValueError(f"BLOCK_N must be a positive power of two, got {BLOCK_N}")

    N = locs.shape[0]
    device = mu.device
    out_dtype = mu.dtype

    # The kernel indexes with hard-coded dense strides (2 and 4 elements per
    # player, 2 per grid point), so every input must be contiguous. A float32
    # copy keeps the kernel's loop-carried accumulator type stable. For inputs
    # that are already contiguous float32 on the same device, .to() and
    # .contiguous() return the tensor unchanged.
    mu32 = mu.to(device=device, dtype=torch.float32).contiguous()
    sigma32 = sigma_inv.to(device=device, dtype=torch.float32).contiguous()
    locs32 = locs.to(device=device, dtype=torch.float32).contiguous()

    out = torch.empty((F, N), device=device, dtype=torch.float32)
    if F == 0 or N == 0:
        # Nothing to compute. Avoid launching a kernel with an empty grid.
        return out.to(out_dtype)

    grid = (F, triton.cdiv(N, BLOCK_N))
    influence_kernel[grid](
        mu32, sigma32, locs32, out,
        F, P, N,
        BLOCK_N=BLOCK_N,
        num_warps=4,
    )
    return out.to(out_dtype)
@@ -7,6 +7,8 @@ import matplotsoccer as mpl
7
7
  import torch
8
8
  from tqdm import tqdm
9
9
 
10
+ from .kernels import triton_influence
11
+
10
12
 
11
13
  class PitchControl:
12
14
  def __init__(self, tracking: pd.DataFrame, events: pd.DataFrame, ball_data: pd.DataFrame = None):
@@ -320,15 +322,19 @@ class PitchControl:
320
322
  diff = locs.view(1, -1, 2) # (1,N,2)
321
323
  diff = diff - mu.unsqueeze(1) # (P,N,2)
322
324
 
323
- maha = torch.einsum('pni,pij,pnj->pn', diff, Sigma_inv, diff) # (P,N)
324
-
325
- # Replace NaNs that arise from invalid player positions with large value
326
- # so their influence tends to zero after exponent, then eliminate residual NaNs.
327
- maha = torch.nan_to_num(maha, nan=1e9, posinf=1e9, neginf=1e9)
325
+ if "cuda" in device:
326
+ out = triton_influence(
327
+ mu.unsqueeze(0), Sigma_inv.unsqueeze(0),
328
+ locs, BLOCK_N=64
329
+ )[0] # (N,)
328
330
 
329
- out = torch.exp(-0.5 * maha) # (P,N)
331
+ return out
332
+ else:
333
+ maha = torch.einsum('pni,pij,pnj->pn', diff, Sigma_inv, diff) # (P,N)
334
+ maha = torch.nan_to_num(maha, nan=1e9, posinf=1e9, neginf=1e9)
335
+ out = torch.exp(-0.5 * maha) # (P,N)
330
336
 
331
- return out.sum(dim=0) # sum over players
337
+ return out.sum(dim=0) # sum over players
332
338
 
333
339
  def _batch_team_influence_frames_pt(
334
340
  self,
@@ -385,15 +391,16 @@ class PitchControl:
385
391
  diff = locs.view(1, 1, -1, 2) # (1,1,N,2)
386
392
  diff = diff - mu.unsqueeze(2) # (F,P,N,2)
387
393
 
388
- maha = torch.einsum('fpni,fpij,fpnj->fpn', diff, Sigma_inv, diff) # (F,P,N)
394
+ if "cuda" in device:
395
+ out = triton_influence(mu, Sigma_inv, locs, BLOCK_N=64) # (F,N)
389
396
 
390
- # Replace NaNs that arise from invalid player positions with large value
391
- # so their influence tends to zero after exponent, then eliminate residual NaNs.
392
- maha = torch.nan_to_num(maha, nan=1e9, posinf=1e9, neginf=1e9)
393
-
394
- out = torch.exp(-0.5 * maha) # (F,P,N)
397
+ return out
398
+ else:
399
+ maha = torch.einsum('fpni,fpij,fpnj->fpn', diff, Sigma_inv, diff) # (F,P,N)
400
+ maha = torch.nan_to_num(maha, nan=1e9, posinf=1e9, neginf=1e9)
401
+ out = torch.exp(-0.5 * maha) # (F,P,N)
395
402
 
396
- return out.sum(dim=1) # sum over players
403
+ return out.sum(dim=1) # sum over players
397
404
 
398
405
  @staticmethod
399
406
  def _stack_team_frames(players: list[np.ndarray], frames: np.ndarray, device: str, dtype: torch.dtype):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: rustat-python-api
3
- Version: 0.6.1
3
+ Version: 0.6.3
4
4
  Summary: A Python wrapper for RuStat API
5
5
  Home-page: https://github.com/dailydaniel/rustat-python-api
6
6
  Author: Daniel Zholkovsky
@@ -13,4 +13,6 @@ rustat_python_api.egg-info/PKG-INFO
13
13
  rustat_python_api.egg-info/SOURCES.txt
14
14
  rustat_python_api.egg-info/dependency_links.txt
15
15
  rustat_python_api.egg-info/requires.txt
16
- rustat_python_api.egg-info/top_level.txt
16
+ rustat_python_api.egg-info/top_level.txt
17
+ rustat_python_api/kernels/__init__.py
18
+ rustat_python_api/kernels/maha.py
@@ -5,3 +5,4 @@ scipy==1.14.1
5
5
  matplotlib
6
6
  matplotsoccer
7
7
  torch
8
+ triton==3.0.0
@@ -2,7 +2,7 @@ from setuptools import setup, find_packages
2
2
 
3
3
  setup(
4
4
  name='rustat-python-api',
5
- version='0.6.1',
5
+ version='0.6.3',
6
6
  description='A Python wrapper for RuStat API',
7
7
  long_description=open('README.md').read(),
8
8
  long_description_content_type='text/markdown',
@@ -18,7 +18,8 @@ setup(
18
18
  'scipy==1.14.1',
19
19
  'matplotlib',
20
20
  'matplotsoccer',
21
- 'torch'
21
+ 'torch',
22
+ 'triton==3.0.0',
22
23
  ],
23
24
  classifiers=[
24
25
  'Programming Language :: Python :: 3',