rustat-python-api 0.7.0__tar.gz → 0.7.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (21)
  1. {rustat-python-api-0.7.0/rustat_python_api.egg-info → rustat-python-api-0.7.1}/PKG-INFO +1 -1
  2. rustat-python-api-0.7.1/rustat_python_api/kernels/maha.py +77 -0
  3. {rustat-python-api-0.7.0 → rustat-python-api-0.7.1/rustat_python_api.egg-info}/PKG-INFO +1 -1
  4. {rustat-python-api-0.7.0 → rustat-python-api-0.7.1}/setup.py +1 -1
  5. rustat-python-api-0.7.0/rustat_python_api/kernels/maha.py +0 -77
  6. {rustat-python-api-0.7.0 → rustat-python-api-0.7.1}/LICENSE +0 -0
  7. {rustat-python-api-0.7.0 → rustat-python-api-0.7.1}/README.md +0 -0
  8. {rustat-python-api-0.7.0 → rustat-python-api-0.7.1}/pyproject.toml +0 -0
  9. {rustat-python-api-0.7.0 → rustat-python-api-0.7.1}/rustat_python_api/__init__.py +0 -0
  10. {rustat-python-api-0.7.0 → rustat-python-api-0.7.1}/rustat_python_api/config.py +0 -0
  11. {rustat-python-api-0.7.0 → rustat-python-api-0.7.1}/rustat_python_api/kernels/__init__.py +0 -0
  12. {rustat-python-api-0.7.0 → rustat-python-api-0.7.1}/rustat_python_api/models_api.py +0 -0
  13. {rustat-python-api-0.7.0 → rustat-python-api-0.7.1}/rustat_python_api/parser.py +0 -0
  14. {rustat-python-api-0.7.0 → rustat-python-api-0.7.1}/rustat_python_api/pitch_control.py +0 -0
  15. {rustat-python-api-0.7.0 → rustat-python-api-0.7.1}/rustat_python_api/processing.py +0 -0
  16. {rustat-python-api-0.7.0 → rustat-python-api-0.7.1}/rustat_python_api/urls.py +0 -0
  17. {rustat-python-api-0.7.0 → rustat-python-api-0.7.1}/rustat_python_api.egg-info/SOURCES.txt +0 -0
  18. {rustat-python-api-0.7.0 → rustat-python-api-0.7.1}/rustat_python_api.egg-info/dependency_links.txt +0 -0
  19. {rustat-python-api-0.7.0 → rustat-python-api-0.7.1}/rustat_python_api.egg-info/requires.txt +0 -0
  20. {rustat-python-api-0.7.0 → rustat-python-api-0.7.1}/rustat_python_api.egg-info/top_level.txt +0 -0
  21. {rustat-python-api-0.7.0 → rustat-python-api-0.7.1}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: rustat-python-api
3
- Version: 0.7.0
3
+ Version: 0.7.1
4
4
  Summary: A Python wrapper for RuStat API
5
5
  Home-page: https://github.com/dailydaniel/rustat-python-api
6
6
  Author: Daniel Zholkovsky
@@ -0,0 +1,77 @@
import torch

try:
    import triton
    import triton.language as tl
except ImportError:  # Triton wheels exist only for Linux + NVIDIA GPUs
    triton = None
    tl = None


if triton is not None:

    @triton.jit
    def influence_kernel(
        MU,  # (F, P, 2), contiguous row-major
        SIGMA_INV,  # (F, P, 2, 2), contiguous row-major
        LOCS,  # (N, 2), contiguous row-major, read-only
        OUT,  # (F, N), contiguous row-major
        F, P, N,  # sizes (F is not read by the kernel body)
        BLOCK_N: tl.constexpr,  # number of grid points handled per program
        MAX_P: tl.constexpr = 32,  # compile-time unroll bound; must be >= P
    ):
        """Write OUT[f, n] = sum_p exp(-0.5 * d^T SIGMA_INV[f, p] d), d = LOCS[n] - MU[f, p].

        Launch with grid (F, cdiv(N, BLOCK_N)): axis 0 selects the frame,
        axis 1 selects a chunk of BLOCK_N grid points.
        """
        # int64 so that pid_f * N cannot overflow int32 for large (F, N) outputs.
        pid_f = tl.program_id(0).to(tl.int64)  # which frame
        offs_n = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)  # chunk of grid points
        mask_n = offs_n < N

        # Grid coordinates do not depend on p, so load them once outside the loop.
        x0 = tl.load(LOCS + offs_n * 2 + 0, mask=mask_n)  # (BLOCK_N,)
        x1 = tl.load(LOCS + offs_n * 2 + 1, mask=mask_n)

        acc = tl.zeros([BLOCK_N], dtype=tl.float32)

        # Unrolled at compile time; iterations with p >= P are masked out below.
        for p in tl.static_range(0, MAX_P):
            is_valid = p < P

            mu_ptr = MU + pid_f * P * 2 + p * 2
            s_ptr = SIGMA_INV + pid_f * P * 4 + p * 4

            # `other=0.0` keeps masked-out lanes finite instead of undefined.
            mu0 = tl.load(mu_ptr + 0, mask=is_valid, other=0.0)
            mu1 = tl.load(mu_ptr + 1, mask=is_valid, other=0.0)

            s00 = tl.load(s_ptr + 0, mask=is_valid, other=0.0)
            s01 = tl.load(s_ptr + 1, mask=is_valid, other=0.0)
            s10 = tl.load(s_ptr + 2, mask=is_valid, other=0.0)
            s11 = tl.load(s_ptr + 3, mask=is_valid, other=0.0)

            diff0 = x0 - mu0
            diff1 = x1 - mu1

            # Quadratic form d^T S d for the 2x2 matrix S.
            tmp0 = diff0 * s00 + diff1 * s01
            tmp1 = diff0 * s10 + diff1 * s11
            maha = diff0 * tmp0 + diff1 * tmp1

            acc += tl.where(is_valid, tl.exp(-0.5 * maha), 0.0)

        out_ptr = OUT + pid_f * N + offs_n
        tl.store(out_ptr, acc, mask=mask_n)

    def triton_influence(mu, sigma_inv, locs, BLOCK_N=64):
        """Sum exp(-0.5 * Mahalanobis^2) over all P components at every location.

        Args:
            mu: (F, P, 2) tensor of component centres.
            sigma_inv: (F, P, 2, 2) tensor of inverse covariance matrices.
            locs: (N, 2) tensor of query coordinates.
            BLOCK_N: number of locations processed by one Triton program.

        Returns:
            (F, N) tensor on ``mu.device`` with ``mu.dtype``
            (accumulation happens in float32 inside the kernel).

        Raises:
            ValueError: if the input shapes are inconsistent.
        """
        if mu.dim() != 3 or mu.shape[-1] != 2:
            raise ValueError(f"mu must have shape (F, P, 2), got {tuple(mu.shape)}")
        F, P, _ = mu.shape
        if tuple(sigma_inv.shape) != (F, P, 2, 2):
            raise ValueError(
                f"sigma_inv must have shape {(F, P, 2, 2)}, got {tuple(sigma_inv.shape)}"
            )
        if locs.dim() != 2 or locs.shape[-1] != 2:
            raise ValueError(f"locs must have shape (N, 2), got {tuple(locs.shape)}")
        N = locs.shape[0]

        out = torch.empty((F, N), device=mu.device, dtype=mu.dtype)
        if F == 0 or N == 0:
            return out  # nothing to compute; avoid a zero-sized launch grid

        # The kernel hard-codes row-major strides, so views must be materialised.
        mu = mu.contiguous()
        sigma_inv = sigma_inv.contiguous()
        locs = locs.contiguous()

        # Round the unroll bound up to a power of two (min 32) so that every
        # component is included without recompiling for each distinct P.
        max_p = max(32, triton.next_power_of_2(P))

        grid = (F, triton.cdiv(N, BLOCK_N))
        influence_kernel[grid](
            mu, sigma_inv, locs, out,
            F, P, N,
            BLOCK_N=BLOCK_N,
            MAX_P=max_p,
            num_warps=4,
        )
        return out

else:

    def triton_influence(*args, **kwargs):
        """Placeholder used when Triton cannot be imported; always raises RuntimeError."""
        raise RuntimeError(
            "Triton not available. This function requires Linux with NVIDIA GPU. "
            "For CPU fallback, please use the .calculate_influence() method with device='cpu'. "
            "On macOS, the package installs without GPU acceleration."
        )
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: rustat-python-api
3
- Version: 0.7.0
3
+ Version: 0.7.1
4
4
  Summary: A Python wrapper for RuStat API
5
5
  Home-page: https://github.com/dailydaniel/rustat-python-api
6
6
  Author: Daniel Zholkovsky
@@ -2,7 +2,7 @@ from setuptools import setup, find_packages
2
2
 
3
3
  setup(
4
4
  name='rustat-python-api',
5
- version='0.7.0',
5
+ version='0.7.1',
6
6
  description='A Python wrapper for RuStat API',
7
7
  long_description=open('README.md').read(),
8
8
  long_description_content_type='text/markdown',
@@ -1,77 +0,0 @@
import torch

try:
    import triton
    import triton.language as tl
except ImportError:  # Triton wheels exist only for Linux + NVIDIA GPUs
    triton = None
    tl = None


# The kernel is only defined when Triton imported successfully; decorating with
# `@triton.jit` while `triton is None` would raise AttributeError at import time.
if triton is not None:

    @triton.jit
    def influence_kernel(
        MU,  # (F, P, 2), contiguous row-major
        SIGMA_INV,  # (F, P, 2, 2), contiguous row-major
        LOCS,  # (N, 2), contiguous row-major, read-only
        OUT,  # (F, N), contiguous row-major
        F, P, N,  # sizes (F is not read by the kernel body)
        BLOCK_N: tl.constexpr,  # number of grid points handled per program
        MAX_P: tl.constexpr = 32,  # compile-time unroll bound; must be >= P
    ):
        """Write OUT[f, n] = sum_p exp(-0.5 * d^T SIGMA_INV[f, p] d), d = LOCS[n] - MU[f, p].

        Launch with grid (F, cdiv(N, BLOCK_N)): axis 0 selects the frame,
        axis 1 selects a chunk of BLOCK_N grid points.
        """
        # int64 so that pid_f * N cannot overflow int32 for large (F, N) outputs.
        pid_f = tl.program_id(0).to(tl.int64)  # which frame
        offs_n = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)  # chunk of grid points
        mask_n = offs_n < N

        # Grid coordinates do not depend on p, so load them once outside the loop.
        x0 = tl.load(LOCS + offs_n * 2 + 0, mask=mask_n)  # (BLOCK_N,)
        x1 = tl.load(LOCS + offs_n * 2 + 1, mask=mask_n)

        # Accumulator for the summed influences.
        acc = tl.zeros([BLOCK_N], dtype=tl.float32)

        # Unrolled at compile time; iterations with p >= P are masked out below.
        for p in tl.static_range(0, MAX_P):
            is_valid = p < P

            mu_ptr = MU + pid_f * P * 2 + p * 2
            s_ptr = SIGMA_INV + pid_f * P * 4 + p * 4

            # `other=0.0` keeps masked-out lanes finite instead of undefined.
            mu0 = tl.load(mu_ptr + 0, mask=is_valid, other=0.0)
            mu1 = tl.load(mu_ptr + 1, mask=is_valid, other=0.0)

            s00 = tl.load(s_ptr + 0, mask=is_valid, other=0.0)
            s01 = tl.load(s_ptr + 1, mask=is_valid, other=0.0)
            s10 = tl.load(s_ptr + 2, mask=is_valid, other=0.0)
            s11 = tl.load(s_ptr + 3, mask=is_valid, other=0.0)

            diff0 = x0 - mu0
            diff1 = x1 - mu1

            # Quadratic form d^T S d for the 2x2 matrix S.
            tmp0 = diff0 * s00 + diff1 * s01
            tmp1 = diff0 * s10 + diff1 * s11
            maha = diff0 * tmp0 + diff1 * tmp1

            acc += tl.where(is_valid, tl.exp(-0.5 * maha), 0.0)

        # Write the result row for this frame.
        out_ptr = OUT + pid_f * N + offs_n
        tl.store(out_ptr, acc, mask=mask_n)

    def triton_influence(mu, sigma_inv, locs, BLOCK_N=64):
        """Sum exp(-0.5 * Mahalanobis^2) over all P components at every location.

        Args:
            mu: (F, P, 2) tensor of component centres.
            sigma_inv: (F, P, 2, 2) tensor of inverse covariance matrices.
            locs: (N, 2) tensor of query coordinates.
            BLOCK_N: number of locations processed by one Triton program.

        Returns:
            (F, N) tensor on ``mu.device`` with ``mu.dtype``
            (accumulation happens in float32 inside the kernel).

        Raises:
            ValueError: if the input shapes are inconsistent.
        """
        if mu.dim() != 3 or mu.shape[-1] != 2:
            raise ValueError(f"mu must have shape (F, P, 2), got {tuple(mu.shape)}")
        F, P, _ = mu.shape
        if tuple(sigma_inv.shape) != (F, P, 2, 2):
            raise ValueError(
                f"sigma_inv must have shape {(F, P, 2, 2)}, got {tuple(sigma_inv.shape)}"
            )
        if locs.dim() != 2 or locs.shape[-1] != 2:
            raise ValueError(f"locs must have shape (N, 2), got {tuple(locs.shape)}")
        N = locs.shape[0]

        out = torch.empty((F, N), device=mu.device, dtype=mu.dtype)
        if F == 0 or N == 0:
            return out  # nothing to compute; avoid a zero-sized launch grid

        # The kernel hard-codes row-major strides, so views must be materialised.
        mu = mu.contiguous()
        sigma_inv = sigma_inv.contiguous()
        locs = locs.contiguous()

        # Round the unroll bound up to a power of two (min 32) so that every
        # component is included without recompiling for each distinct P.
        max_p = max(32, triton.next_power_of_2(P))

        grid = (F, triton.cdiv(N, BLOCK_N))
        influence_kernel[grid](
            mu, sigma_inv, locs, out,
            F, P, N,
            BLOCK_N=BLOCK_N,
            MAX_P=max_p,
            num_warps=4,
        )
        return out

else:

    def triton_influence(*args, **kwargs):
        """Placeholder used when Triton cannot be imported; always raises RuntimeError."""
        raise RuntimeError(
            "Triton not available. This function requires Linux with NVIDIA GPU. "
            "For CPU fallback, please use the .calculate_influence() method with device='cpu'. "
            "On macOS, the package installs without GPU acceleration."
        )