rwkv-ops 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rwkv-ops might be problematic. Click here for more details.

Files changed (43) hide show
  1. rwkv_ops/__init__.py +26 -0
  2. rwkv_ops/rwkv7_kernel/__init__.py +153 -0
  3. rwkv_ops/rwkv7_kernel/get_jax_devices_info.py +221 -0
  4. rwkv_ops/rwkv7_kernel/get_torch_devices_info.py +250 -0
  5. rwkv_ops/rwkv7_kernel/jax_kernel/__init__.py +9 -0
  6. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_bwd.py +95 -0
  7. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_fwd.py +60 -0
  8. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_bwd.py +78 -0
  9. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_fwd.py +80 -0
  10. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_bwd.py +150 -0
  11. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_fwd.py +45 -0
  12. rwkv_ops/rwkv7_kernel/jax_kernel/cumsum.py +34 -0
  13. rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_bwd.py +61 -0
  14. rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_fwd.py +86 -0
  15. rwkv_ops/rwkv7_kernel/jax_op.py +382 -0
  16. rwkv_ops/rwkv7_kernel/native_keras_op.py +95 -0
  17. rwkv_ops/rwkv7_kernel/torch_kernel/__init__.py +13 -0
  18. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_bwd.py +96 -0
  19. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_fwd.py +64 -0
  20. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_bwd.py +74 -0
  21. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_fwd.py +75 -0
  22. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_bwd.py +148 -0
  23. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_fwd.py +44 -0
  24. rwkv_ops/rwkv7_kernel/torch_kernel/cumsum.py +31 -0
  25. rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_bwd.py +63 -0
  26. rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_fwd.py +79 -0
  27. rwkv_ops/rwkv7_kernel/torch_op.py +523 -0
  28. rwkv_ops/rwkv7_kernel/triton_kernel/__init__.py +34 -0
  29. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_bwd.py +328 -0
  30. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_fwd.py +186 -0
  31. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_bwd.py +157 -0
  32. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_fwd.py +160 -0
  33. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_bwd.py +382 -0
  34. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_fwd.py +137 -0
  35. rwkv_ops/rwkv7_kernel/triton_kernel/cumsum.py +86 -0
  36. rwkv_ops/rwkv7_kernel/triton_kernel/utils.py +20 -0
  37. rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_bwd.py +193 -0
  38. rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_fwd.py +326 -0
  39. rwkv_ops-0.1.0.dist-info/LICENSE.txt +201 -0
  40. rwkv_ops-0.1.0.dist-info/METADATA +118 -0
  41. rwkv_ops-0.1.0.dist-info/RECORD +43 -0
  42. rwkv_ops-0.1.0.dist-info/WHEEL +5 -0
  43. rwkv_ops-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,75 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+
9
+ from ..get_torch_devices_info import check_shared_mem
10
+ from ..triton_kernel.chunk_h_fwd import *
11
+
12
+
13
def chunk_dplr_fwd_h(
    kg: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    u: torch.Tensor,
    bg: torch.Tensor,
    gk: torch.Tensor,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    chunk_size: int = 64,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Allocate the output buffers and launch ``chunk_dplr_fwd_kernel_h``.

    Args:
        kg: 4-D tensor; its shape is unpacked as ``(B, T, H, K)``.
        v: Forwarded unchanged to the kernel.
        w: Forwarded unchanged to the kernel.
        u: Forwarded to the kernel. Its last dimension is taken as ``V`` and
            ``v_new`` is allocated with the same shape and dtype.
        bg: Forwarded unchanged to the kernel.
        gk: Forwarded unchanged to the kernel.
        initial_state: Optional tensor handed to the kernel as ``h0``.
        output_final_state: When true, a float32 ``(B, H, K, V)`` buffer is
            allocated and handed to the kernel as ``ht``.
        chunk_size: Upper bound on the time-block size ``BT``.

    Returns:
        ``(h, v_new, final_state)``. ``h`` is allocated with shape
        ``(B, NT, H, K, V)``; ``final_state`` is ``None`` unless
        ``output_final_state`` is set.

    Raises:
        AssertionError: If ``next_power_of_2(K)`` is larger than 256.
    """
    B, T, H, K, V = *kg.shape, u.shape[-1]
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))

    N, NT = B, triton.cdiv(T, BT)
    BK = triton.next_power_of_2(K)
    assert BK <= 256, "current kernel does not support head dimension larger than 256."

    # Block sizes are chosen from whichever shared-memory tier the device
    # reports; H100 can have larger block size.
    if check_shared_mem("hopper", kg.device.index):
        BV = 64
        BC = 64 if K <= 128 else 32
    elif check_shared_mem("ampere", kg.device.index):  # A100
        BV = 32
        BC = 32
    else:
        BV = 16
        BC = 16

    # The sub-chunk size can never exceed the chunk itself.
    BC = min(BT, BC)
    NK = triton.cdiv(K, BK)
    NV = triton.cdiv(V, BV)
    assert NK == 1, (
        "NK > 1 is not supported because it involves time-consuming synchronization"
    )

    h = kg.new_empty(B, NT, H, K, V)
    final_state = (
        kg.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None
    )
    v_new = torch.empty_like(u)
    grid = (NK, NV, N * H)
    chunk_dplr_fwd_kernel_h[grid](
        kg=kg,
        v=v,
        w=w,
        bg=bg,
        u=u,
        v_new=v_new,
        h=h,
        gk=gk,
        h0=initial_state,
        ht=final_state,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BC=BC,
        BK=BK,
        BV=BV,
    )
    return h, v_new, final_state
@@ -0,0 +1,148 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Tuple
5
+
6
+ import torch
7
+ import triton
8
+
9
+ from ..get_torch_devices_info import check_shared_mem
10
+ from ..triton_kernel.chunk_o_bwd import *
11
+
12
+
13
def chunk_dplr_bwd_dv(
    A_qk: torch.Tensor,
    kg: torch.Tensor,
    do: torch.Tensor,
    dh: torch.Tensor,
    chunk_size: int = 64,
) -> torch.Tensor:
    """Launch ``chunk_dplr_bwd_kernel_dv`` and return the ``dv`` buffer it fills.

    ``kg`` must be 4-D (its shape is unpacked into four sizes); the last
    dimension of ``do`` is used as the value width. ``dv`` is allocated with
    the same shape and dtype as ``do``.
    """
    batch, seq_len, heads, key_dim = kg.shape
    value_dim = do.shape[-1]

    # Time-block size: a power of two of at least 16, capped by chunk_size.
    block_t = min(chunk_size, max(16, triton.next_power_of_2(seq_len)))
    num_chunks = triton.cdiv(seq_len, block_t)

    dv = torch.empty_like(do)

    def launch_grid(meta):
        # BV is not supplied by this wrapper; it is read from the launch
        # meta-parameters (presumably set by the kernel's autotuner).
        return (triton.cdiv(value_dim, meta["BV"]), num_chunks, batch * heads)

    kernel_args = dict(
        A_qk=A_qk,
        kg=kg,
        do=do,
        dv=dv,
        dh=dh,
        T=seq_len,
        H=heads,
        K=key_dim,
        V=value_dim,
        BT=block_t,
    )
    chunk_dplr_bwd_kernel_dv[launch_grid](**kernel_args)
    return dv
43
+
44
+
45
def chunk_dplr_bwd_o(
    k: torch.Tensor,
    b: torch.Tensor,
    v: torch.Tensor,
    v_new: torch.Tensor,
    gk: torch.Tensor,
    do: torch.Tensor,
    h: torch.Tensor,
    dh: torch.Tensor,
    dv: torch.Tensor,
    w: torch.Tensor,
    chunk_size: int = 64,
    scale: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Allocate gradient buffers and launch ``chunk_dplr_bwd_o_kernel``.

    Args:
        k: Forwarded to the kernel; ``dq`` and ``dk`` are allocated like it.
        b: Forwarded to the kernel; ``db`` is allocated like it.
        v: Forwarded to the kernel; its last dimension is taken as ``V``.
        v_new: Forwarded unchanged to the kernel.
        gk: Forwarded unchanged to the kernel.
        do: Forwarded unchanged to the kernel.
        h: Forwarded unchanged to the kernel.
        dh: Forwarded unchanged to the kernel.
        dv: Forwarded unchanged to the kernel.
        w: 4-D tensor whose shape is unpacked as ``(B, T, H, K)``; ``dw`` is
            allocated like it.
        chunk_size: Upper bound on the time-block size ``BT``.
        scale: Accepted for interface compatibility; this wrapper does not
            forward it to the kernel.

    Returns:
        ``(dq, dk, dw, db, dgk_last)``. ``dgk_last`` is a float32 tensor of
        shape ``(B, NT, H, K)`` on ``w``'s device.
    """
    B, T, H, K, V = *w.shape, v.shape[-1]

    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
    NT = triton.cdiv(T, BT)

    # One capability query decides the cap for both tile sizes.
    # NOTE(review): the sibling forward-h wrapper passes an arch name and a
    # device index to check_shared_mem; here its defaults are used — confirm
    # that is intended on multi-device hosts.
    tile_cap = 64 if check_shared_mem() else 32
    BK = min(triton.next_power_of_2(K), tile_cap)
    # Fixed: the low-memory branch used to derive BV from K instead of V.
    BV = min(triton.next_power_of_2(V), tile_cap)

    NK = triton.cdiv(K, BK)
    dq = torch.empty_like(k)
    dk = torch.empty_like(k)
    dw = torch.empty_like(w)
    db = torch.empty_like(b)
    grid = (NK, NT, B * H)

    dgk_last = torch.empty(B, NT, H, K, dtype=torch.float, device=w.device)

    chunk_dplr_bwd_o_kernel[grid](
        k=k,
        b=b,
        v=v,
        v_new=v_new,
        h=h,
        do=do,
        dh=dh,
        dq=dq,
        dk=dk,
        db=db,
        dgk_last=dgk_last,
        w=w,
        dv=dv,
        dw=dw,
        gk=gk,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
    )
    return dq, dk, dw, db, dgk_last
108
+
109
+
110
def chunk_dplr_bwd_dAu(
    v: torch.Tensor,
    v_new: torch.Tensor,
    do: torch.Tensor,
    A_qb: torch.Tensor,
    scale: float,
    chunk_size: int = 64,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Allocate gradient buffers and launch ``chunk_dplr_bwd_kernel_dAu``.

    Args:
        v: 4-D tensor whose shape is unpacked as ``(B, T, H, V)``.
        v_new: Forwarded to the kernel; ``dv_new`` is allocated like it.
        do: Forwarded unchanged to the kernel.
        A_qb: Forwarded unchanged to the kernel.
        scale: Forwarded unchanged to the kernel.
        chunk_size: Upper bound on the time-block size ``BT``.

    Returns:
        ``(dv_new, dA_qk, dA_qb)``. Both ``dA_*`` tensors are float32 with
        shape ``(B, T, H, BT)`` on ``v``'s device.
    """
    B, T, H, V = v.shape
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
    NT = triton.cdiv(T, BT)

    # Pick the widest value tile the device's shared-memory tier allows.
    if check_shared_mem("ampere"):  # A100
        BV = min(triton.next_power_of_2(V), 128)
    elif check_shared_mem("ada"):  # 4090
        BV = min(triton.next_power_of_2(V), 64)
    else:
        BV = min(triton.next_power_of_2(V), 32)

    grid = (NT, B * H)
    dA_qk = torch.empty(B, T, H, BT, dtype=torch.float, device=v.device)
    dA_qb = torch.empty(B, T, H, BT, dtype=torch.float, device=v.device)
    dv_new = torch.empty_like(v_new)
    chunk_dplr_bwd_kernel_dAu[grid](
        v=v,
        do=do,
        v_new=v_new,
        A_qb=A_qb,
        dA_qk=dA_qk,
        dA_qb=dA_qb,
        dv_new=dv_new,
        scale=scale,
        T=T,
        H=H,
        V=V,
        BT=BT,
        BV=BV,
    )
    return dv_new, dA_qk, dA_qb
@@ -0,0 +1,44 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+
5
+ import torch
6
+ import triton
7
+
8
+ from ..triton_kernel.chunk_o_fwd import *
9
+
10
+
11
def chunk_dplr_fwd_o(
    qg: torch.Tensor,
    v: torch.Tensor,
    v_new: torch.Tensor,
    A_qk: torch.Tensor,
    A_qb: torch.Tensor,
    h: torch.Tensor,
    chunk_size: int = 64,
) -> torch.Tensor:
    """Launch ``chunk_dplr_fwd_kernel_o`` and return the output buffer ``o``.

    ``qg`` must be 4-D (its shape is unpacked into four sizes); the last
    dimension of ``v`` is used as the value width. ``o`` is allocated with the
    same shape and dtype as ``v``.
    """
    batch, seq_len, heads, key_dim = qg.shape
    value_dim = v.shape[-1]

    # Time-block size: a power of two of at least 16, capped by chunk_size.
    block_t = min(chunk_size, max(16, triton.next_power_of_2(seq_len)))
    num_chunks = triton.cdiv(seq_len, block_t)

    o = torch.empty_like(v)

    def launch_grid(meta):
        # BV is not supplied by this wrapper; it is read from the launch
        # meta-parameters (presumably set by the kernel's autotuner).
        return (triton.cdiv(value_dim, meta["BV"]), num_chunks, batch * heads)

    kernel_args = dict(
        qg=qg,
        v=v,
        v_new=v_new,
        A_qk=A_qk,
        A_qb=A_qb,
        h=h,
        o=o,
        T=seq_len,
        H=heads,
        K=key_dim,
        V=value_dim,
        BT=block_t,
    )
    chunk_dplr_fwd_kernel_o[launch_grid](**kernel_args)
    return o
@@ -0,0 +1,31 @@
1
+ from ..triton_kernel.cumsum import *
2
+ import torch
3
+
4
+
5
def chunk_rwkv6_fwd_cumsum(
    g: torch.Tensor,
    chunk_size: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Launch ``chunk_rwkv6_fwd_cumsum_kernel`` and return its two outputs.

    Args:
        g: 4-D tensor whose shape is unpacked as ``(B, T, H, S)``.
        chunk_size: Used directly as the time-block size ``BT``.

    Returns:
        ``(gi, ge)``: two float32 tensors with the same shape as ``g``.
    """
    # NOTE(review): this module does not import `triton` or `typing.Tuple`
    # explicitly; both are presumably re-exported by the star import from
    # ..triton_kernel.cumsum — confirm, or import them directly.
    B, T, H, S = g.shape
    BT = chunk_size
    NT = triton.cdiv(T, BT)

    gi, ge = (
        torch.empty_like(g, dtype=torch.float),
        torch.empty_like(g, dtype=torch.float),
    )

    def grid(meta):
        # BS is read from the launch meta-parameters, not chosen here.
        return (triton.cdiv(meta["S"], meta["BS"]), NT, B * H)

    # keep cumulative normalizer in fp32
    # NOTE(review): the first four arguments are positional, so their order
    # (g, T, gi, ge) must match the kernel's parameter order — confirm
    # against the kernel definition.
    chunk_rwkv6_fwd_cumsum_kernel[grid](
        g,
        T,
        gi,
        ge,
        H=H,
        S=S,
        BT=BT,
    )
    return gi, ge
@@ -0,0 +1,63 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Tuple
5
+
6
+ import torch
7
+ import triton
8
+
9
+
10
+ from ..get_torch_devices_info import check_shared_mem
11
+ from ..triton_kernel.wy_fast_bwd import *
12
+
13
+
14
def chunk_dplr_bwd_wy(
    A_ab_inv: torch.Tensor,
    A_ak: torch.Tensor,
    v: torch.Tensor,
    ag: torch.Tensor,
    dw: torch.Tensor,
    du: torch.Tensor,
    dv0: torch.Tensor,
    chunk_size: int = 16,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Allocate gradient buffers and launch ``prepare_wy_repr_bwd_kernel``.

    Args:
        A_ab_inv: Made contiguous, then forwarded; ``dA_ab`` is allocated
            like it but in float32.
        A_ak: Made contiguous, then forwarded; ``dA_ak`` is allocated like it
            but in float32.
        v: Made contiguous, then forwarded; ``dv`` is allocated like it.
        ag: Made contiguous, then forwarded; ``dag`` is allocated like it.
        dw: Made contiguous; its shape is unpacked as ``(B, T, H, K)``.
        du: Made contiguous; its last dimension is taken as ``V``.
        dv0: Forwarded to the kernel as given (it is not part of the
            ``contiguous()`` pass).
        chunk_size: Upper bound on the time-block size ``BT``.

    Returns:
        ``(dA_ab, dA_ak, dv, dag)``.
    """
    A_ab_inv, A_ak, v, ag, dw, du = (
        t.contiguous() for t in (A_ab_inv, A_ak, v, ag, dw, du)
    )
    B, T, H, K, V = *dw.shape, du.shape[-1]
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))

    NT = triton.cdiv(T, BT)
    BK = min(triton.next_power_of_2(K), 64)
    # The value tile is halved when the shared-memory check fails.
    BV = min(triton.next_power_of_2(V), 64 if check_shared_mem() else 32)

    dA_ab = torch.empty_like(A_ab_inv, dtype=torch.float)
    dA_ak = torch.empty_like(A_ak, dtype=torch.float)
    dv = torch.empty_like(v)
    dag = torch.empty_like(ag)

    prepare_wy_repr_bwd_kernel[(NT, B * H)](
        A_ab_inv=A_ab_inv,
        A_ak=A_ak,
        ag=ag,
        v=v,
        dw=dw,
        du=du,
        dv=dv,
        dv0=dv0,
        dag=dag,
        dAak=dA_ak,
        dAab=dA_ab,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
    )
    return dA_ab, dA_ak, dv, dag
@@ -0,0 +1,79 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Tuple
5
+
6
+ import torch
7
+ import triton
8
+
9
+ from ..triton_kernel.wy_fast_fwd import *
10
+
11
+
12
def wu_fwd(
    ag: torch.Tensor,
    v: torch.Tensor,
    A_ak: torch.Tensor,
    A_ab_inv: torch.Tensor,
    chunk_size: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Launch ``wu_fwd_kernel`` and return the ``(w, u)`` buffers it fills.

    ``ag`` must be 4-D (its shape is unpacked into four sizes); the last
    dimension of ``v`` is used as the value width. ``w`` is allocated like
    ``ag`` and ``u`` like ``v``.
    """
    batch, seq_len, heads, key_dim = ag.shape
    value_dim = v.shape[-1]

    # Time-block size: a power of two of at least 16, capped by chunk_size.
    block_t = min(chunk_size, max(triton.next_power_of_2(seq_len), 16))
    num_chunks = triton.cdiv(seq_len, block_t)

    # Feature tiles are powers of two capped at 64.
    block_k = min(triton.next_power_of_2(key_dim), 64)
    block_v = min(triton.next_power_of_2(value_dim), 64)

    w = torch.empty_like(ag)
    u = torch.empty_like(v)

    launch_grid = (num_chunks, batch * heads)
    kernel_args = dict(
        ag=ag,
        v=v,
        A_ak=A_ak,
        A_ab_inv=A_ab_inv,
        w=w,
        u=u,
        T=seq_len,
        H=heads,
        K=key_dim,
        V=value_dim,
        BT=block_t,
        BK=block_k,
        BV=block_v,
    )
    wu_fwd_kernel[launch_grid](**kernel_args)
    return w, u
44
+
45
+
46
def prepare_wy_repr_fwd(
    ag: torch.Tensor,
    v: torch.Tensor,
    A_ak: torch.Tensor,
    A_ab: torch.Tensor,
    chunk_size: int = 64,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Fill ``A_ab_inv`` with a chunk kernel, then call ``wu_fwd`` with it.

    ``ag`` must be 4-D. The function returns ``(w, u, A_ab_inv)`` where
    ``A_ab_inv`` is allocated like ``A_ab`` and ``w``/``u`` come from
    ``wu_fwd``.
    """
    batch, seq_len, heads, _ = ag.shape

    # Time-block size: a power of two of at least 16, capped by chunk_size.
    block_t = min(chunk_size, max(triton.next_power_of_2(seq_len), 16))
    num_chunks = triton.cdiv(seq_len, block_t)
    sub_block = min(block_t, 32)

    # The 64-wide kernel is used only for an exact block size of 64;
    # every other block size goes to the 32-wide kernel.
    if block_t == 64:
        inverse_kernel = prepare_wy_repr_fwd_kernel_chunk64
    else:
        inverse_kernel = prepare_wy_repr_fwd_kernel_chunk32

    A_ab_inv = torch.empty_like(A_ab)
    launch_grid = (num_chunks, batch * heads)
    inverse_kernel[launch_grid](
        A_ab=A_ab,
        A_ab_inv=A_ab_inv,
        T=seq_len,
        H=heads,
        BT=block_t,
        BC=sub_block,
    )

    # The already-clamped block size is passed on as wu_fwd's chunk_size.
    w, u = wu_fwd(ag=ag, v=v, A_ak=A_ak, A_ab_inv=A_ab_inv, chunk_size=block_t)
    return w, u, A_ab_inv
75
+
76
+
77
# Alternate module-level names bound to the same function objects.
fwd_wu = wu_fwd
fwd_prepare_wy_repr = prepare_wy_repr_fwd