rustat-python-api 0.6.6__py3-none-any.whl → 0.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,3 +1,6 @@
1
- from .maha import triton_influence
1
+ try:
2
+ from .maha import triton_influence
3
+ except ImportError:
4
+ triton_influence = None
2
5
 
3
6
  __all__ = ["triton_influence"]
@@ -1,72 +1,77 @@
1
1
  import torch
2
- import triton
3
- import triton.language as tl
4
-
5
-
6
- @triton.jit
7
- def influence_kernel(
8
- MU, # (F,P,2)
9
- SIGMA_INV, # (F,P,2,2)
10
- LOCS, # (N,2) – read-only
11
- OUT, # (F,N)
12
- F, P, N, # sizes
13
- BLOCK_N: tl.constexpr, # кол-во точек на блок
14
- ):
15
- pid_f = tl.program_id(0) # какой frame
16
- offs_n = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N) # chunk точек
17
- mask_n = offs_n < N
18
-
19
- # ----- загрузим координаты сетки (N,2) -----
20
- loc_ptr = LOCS + offs_n[:, None] * 2 # strides=(2,)
21
- x = tl.load(loc_ptr, mask=mask_n[:, None]) # (BLOCK_N,2)
22
-
23
- # аккум для суммирования влияний
24
- acc = tl.zeros([BLOCK_N], dtype=tl.float32)
25
-
26
- for p in tl.static_range(0, 32):
27
- is_valid = p < P
28
-
29
- # --- pointers ---
30
- mu_ptr = MU + pid_f * P * 2 + p * 2
31
- s_ptr = SIGMA_INV + pid_f * P * 4 + p * 4
32
-
33
- # --- scalar loads ---
34
- mu0 = tl.load(mu_ptr + 0, mask=is_valid)
35
- mu1 = tl.load(mu_ptr + 1, mask=is_valid)
36
-
37
- s00 = tl.load(s_ptr + 0, mask=is_valid)
38
- s01 = tl.load(s_ptr + 1, mask=is_valid)
39
- s10 = tl.load(s_ptr + 2, mask=is_valid)
40
- s11 = tl.load(s_ptr + 3, mask=is_valid)
41
-
42
- # --- grid coords (also scalar) ---
43
- x0 = tl.load(LOCS + offs_n * 2 + 0, mask=mask_n) # (BLOCK_N,)
44
- x1 = tl.load(LOCS + offs_n * 2 + 1, mask=mask_n)
45
-
46
- diff0 = x0 - mu0
47
- diff1 = x1 - mu1
48
-
49
- tmp0 = diff0 * s00 + diff1 * s01
50
- tmp1 = diff0 * s10 + diff1 * s11
51
- maha = diff0 * tmp0 + diff1 * tmp1
52
-
53
- acc += tl.where(is_valid, tl.exp(-0.5 * maha), 0.0)
54
-
55
- # запишем
56
- out_ptr = OUT + pid_f*N + offs_n
57
- tl.store(out_ptr, acc, mask=mask_n)
58
-
59
-
60
- def triton_influence(mu, sigma_inv, locs, BLOCK_N=64):
61
- F, P, _ = mu.shape
62
- N = locs.shape[0]
63
- out = torch.empty((F, N), device=mu.device, dtype=mu.dtype)
64
-
65
- grid = (F, triton.cdiv(N, BLOCK_N))
66
- influence_kernel[grid](
67
- mu, sigma_inv, locs, out,
68
- F, P, N,
69
- BLOCK_N=BLOCK_N,
70
- num_warps=4
71
- )
72
- return out
2
+
3
+ try:
4
+ import triton
5
+ import triton.language as tl
6
+
7
+ @triton.jit
8
+ def influence_kernel(
9
+ MU, # (F,P,2)
10
+ SIGMA_INV, # (F,P,2,2)
11
+ LOCS, # (N,2) – read-only
12
+ OUT, # (F,N)
13
+ F, P, N, # sizes
14
+ BLOCK_N: tl.constexpr, # кол-во точек на блок
15
+ ):
16
+ pid_f = tl.program_id(0) # какой frame
17
+ offs_n = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N) # chunk точек
18
+ mask_n = offs_n < N
19
+
20
+ loc_ptr = LOCS + offs_n[:, None] * 2 # strides=(2,)
21
+ x = tl.load(loc_ptr, mask=mask_n[:, None]) # (BLOCK_N,2)
22
+
23
+ acc = tl.zeros([BLOCK_N], dtype=tl.float32)
24
+
25
+ for p in tl.static_range(0, 32):
26
+ is_valid = p < P
27
+
28
+ # --- pointers ---
29
+ mu_ptr = MU + pid_f * P * 2 + p * 2
30
+ s_ptr = SIGMA_INV + pid_f * P * 4 + p * 4
31
+
32
+ # --- scalar loads ---
33
+ mu0 = tl.load(mu_ptr + 0, mask=is_valid)
34
+ mu1 = tl.load(mu_ptr + 1, mask=is_valid)
35
+
36
+ s00 = tl.load(s_ptr + 0, mask=is_valid)
37
+ s01 = tl.load(s_ptr + 1, mask=is_valid)
38
+ s10 = tl.load(s_ptr + 2, mask=is_valid)
39
+ s11 = tl.load(s_ptr + 3, mask=is_valid)
40
+
41
+ # --- grid coords (also scalar) ---
42
+ x0 = tl.load(LOCS + offs_n * 2 + 0, mask=mask_n) # (BLOCK_N,)
43
+ x1 = tl.load(LOCS + offs_n * 2 + 1, mask=mask_n)
44
+
45
+ diff0 = x0 - mu0
46
+ diff1 = x1 - mu1
47
+
48
+ tmp0 = diff0 * s00 + diff1 * s01
49
+ tmp1 = diff0 * s10 + diff1 * s11
50
+ maha = diff0 * tmp0 + diff1 * tmp1
51
+
52
+ acc += tl.where(is_valid, tl.exp(-0.5 * maha), 0.0)
53
+
54
+ out_ptr = OUT + pid_f*N + offs_n
55
+ tl.store(out_ptr, acc, mask=mask_n)
56
+
57
+ def triton_influence(mu, sigma_inv, locs, BLOCK_N=64):
58
+ F, P, _ = mu.shape
59
+ N = locs.shape[0]
60
+ out = torch.empty((F, N), device=mu.device, dtype=mu.dtype)
61
+
62
+ grid = (F, triton.cdiv(N, BLOCK_N))
63
+ influence_kernel[grid](
64
+ mu, sigma_inv, locs, out,
65
+ F, P, N,
66
+ BLOCK_N=BLOCK_N,
67
+ num_warps=4
68
+ )
69
+ return out
70
+
71
+ except ImportError:
72
+ def triton_influence(*args, **kwargs):
73
+ raise RuntimeError(
74
+ "Triton not available. This function requires Linux with NVIDIA GPU. "
75
+ "For CPU fallback, please use the .calculate_influence() method with device='cpu'. "
76
+ "On macOS, the package installs without GPU acceleration."
77
+ )
@@ -7,8 +7,10 @@ import matplotsoccer as mpl
7
7
  import torch
8
8
  from tqdm import tqdm
9
9
 
10
- from .kernels import triton_influence
11
-
10
+ try:
11
+ from .kernels import triton_influence
12
+ except ImportError:
13
+ triton_influence = None
12
14
 
13
15
  class PitchControl:
14
16
  def __init__(self, tracking: pd.DataFrame, events: pd.DataFrame, ball_data: pd.DataFrame = None):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: rustat-python-api
3
- Version: 0.6.6
3
+ Version: 0.7.1
4
4
  Summary: A Python wrapper for RuStat API
5
5
  Home-page: https://github.com/dailydaniel/rustat-python-api
6
6
  Author: Daniel Zholkovsky
@@ -2,13 +2,13 @@ rustat_python_api/__init__.py,sha256=Ij-PAm2y5ss_XAZhKTZus35cRPLzvXFyIswDa_Iq3rs
2
2
  rustat_python_api/config.py,sha256=eMvi1p8Cfvnbp6Cd4bBOwgehVN7thKnaQV5uzWyGZXM,1844
3
3
  rustat_python_api/models_api.py,sha256=oHXEqeCupvZwjVEdoxf7W9LP7ELFKA8-9DuRXpQHLno,1701
4
4
  rustat_python_api/parser.py,sha256=c5SW9p5N1PLYnza1olZbnvxkWtCAYYty2ifjyaUwXkw,10280
5
- rustat_python_api/pitch_control.py,sha256=s501qCpvMI7qNQAajDF_NXnlIDMloJQCgUxAfjKGOa0,26282
5
+ rustat_python_api/pitch_control.py,sha256=iJiF0_3j6yh3riV8kTZ_mtJibPqotQZOUgpumSyS-Zw,26338
6
6
  rustat_python_api/processing.py,sha256=WES2D77uu3ScpjN0nW8ozatHteCZJQmmF9WpHPgGfJo,2835
7
7
  rustat_python_api/urls.py,sha256=iJTD31T6OyXPAhmhViwFXVehrzwsOjBDONA1SIVc_40,1068
8
- rustat_python_api/kernels/__init__.py,sha256=fQrUfDakSr5IsFV0LqAxvLoEw9Og5PThfux1GQw1Tps,67
9
- rustat_python_api/kernels/maha.py,sha256=ARGJRJRp9VlX-zQbKipGMUKMMSqIEWg4Pg-skhkYMWc,2184
10
- rustat_python_api-0.6.6.dist-info/LICENSE,sha256=4Cohqg5p6Mq1xyrzdEX8AvFSA62GSVvapEOr2xK_tgY,57
11
- rustat_python_api-0.6.6.dist-info/METADATA,sha256=Ykv3P2-fdkRn6wYT25RIotnuX8t2ZPLgW0vYHl8aj7Y,1808
12
- rustat_python_api-0.6.6.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
13
- rustat_python_api-0.6.6.dist-info/top_level.txt,sha256=VK0hmkKZE9YThxolUcoE6JtGI67NFeKJMBLuet8kI4w,18
14
- rustat_python_api-0.6.6.dist-info/RECORD,,
8
+ rustat_python_api/kernels/__init__.py,sha256=eFJ-BMY8VcNZSjf3XjOnZf_nfOQ5t-7Lp57DPCHYOo0,124
9
+ rustat_python_api/kernels/maha.py,sha256=k2PqY6VghgER2j9QH8xGYq61JLfPaHjirLXb4aLnjQw,2591
10
+ rustat_python_api-0.7.1.dist-info/LICENSE,sha256=4Cohqg5p6Mq1xyrzdEX8AvFSA62GSVvapEOr2xK_tgY,57
11
+ rustat_python_api-0.7.1.dist-info/METADATA,sha256=Db46gm9pWb-1Zra3BrFFnUI8AU9aLJ4D93XQxnZc6PU,1808
12
+ rustat_python_api-0.7.1.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
13
+ rustat_python_api-0.7.1.dist-info/top_level.txt,sha256=VK0hmkKZE9YThxolUcoE6JtGI67NFeKJMBLuet8kI4w,18
14
+ rustat_python_api-0.7.1.dist-info/RECORD,,