rwkv_ops-0.1.0-py3-none-any.whl

Files changed (43)
  1. rwkv_ops/__init__.py +26 -0
  2. rwkv_ops/rwkv7_kernel/__init__.py +153 -0
  3. rwkv_ops/rwkv7_kernel/get_jax_devices_info.py +221 -0
  4. rwkv_ops/rwkv7_kernel/get_torch_devices_info.py +250 -0
  5. rwkv_ops/rwkv7_kernel/jax_kernel/__init__.py +9 -0
  6. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_bwd.py +95 -0
  7. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_fwd.py +60 -0
  8. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_bwd.py +78 -0
  9. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_fwd.py +80 -0
  10. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_bwd.py +150 -0
  11. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_fwd.py +45 -0
  12. rwkv_ops/rwkv7_kernel/jax_kernel/cumsum.py +34 -0
  13. rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_bwd.py +61 -0
  14. rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_fwd.py +86 -0
  15. rwkv_ops/rwkv7_kernel/jax_op.py +382 -0
  16. rwkv_ops/rwkv7_kernel/native_keras_op.py +95 -0
  17. rwkv_ops/rwkv7_kernel/torch_kernel/__init__.py +13 -0
  18. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_bwd.py +96 -0
  19. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_fwd.py +64 -0
  20. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_bwd.py +74 -0
  21. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_fwd.py +75 -0
  22. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_bwd.py +148 -0
  23. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_fwd.py +44 -0
  24. rwkv_ops/rwkv7_kernel/torch_kernel/cumsum.py +31 -0
  25. rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_bwd.py +63 -0
  26. rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_fwd.py +79 -0
  27. rwkv_ops/rwkv7_kernel/torch_op.py +523 -0
  28. rwkv_ops/rwkv7_kernel/triton_kernel/__init__.py +34 -0
  29. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_bwd.py +328 -0
  30. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_fwd.py +186 -0
  31. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_bwd.py +157 -0
  32. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_fwd.py +160 -0
  33. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_bwd.py +382 -0
  34. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_fwd.py +137 -0
  35. rwkv_ops/rwkv7_kernel/triton_kernel/cumsum.py +86 -0
  36. rwkv_ops/rwkv7_kernel/triton_kernel/utils.py +20 -0
  37. rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_bwd.py +193 -0
  38. rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_fwd.py +326 -0
  39. rwkv_ops-0.1.0.dist-info/LICENSE.txt +201 -0
  40. rwkv_ops-0.1.0.dist-info/METADATA +118 -0
  41. rwkv_ops-0.1.0.dist-info/RECORD +43 -0
  42. rwkv_ops-0.1.0.dist-info/WHEEL +5 -0
  43. rwkv_ops-0.1.0.dist-info/top_level.txt +1 -0
rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_bwd.py
@@ -0,0 +1,95 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2023-2025,Qingwen Lin
+
+ import jax_triton as jt
+ import jax
+ import triton
+ from ..triton_kernel.chunk_A_bwd import *
+ from ..triton_kernel.utils import is_gather_supported
+ from ..get_torch_devices_info import check_shared_mem
+
+
+ def chunk_dplr_bwd_dqk_intra(
+     q: jax.Array,
+     k: jax.Array,
+     a: jax.Array,
+     b: jax.Array,
+     gi: jax.Array,
+     ge: jax.Array,
+     dAqk: jax.Array,
+     dAqb: jax.Array,
+     dAak: jax.Array,
+     dAab: jax.Array,
+     dqg: jax.Array,
+     dkg: jax.Array,
+     dag: jax.Array,
+     dbg: jax.Array,
+     dgk_last: jax.Array,
+     scale: float = 1.0,
+     chunk_size: int = 16,
+ ):
+     B, T, H, K = q.shape
+     BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
+     BK = (
+         min(64, triton.next_power_of_2(K))
+         if check_shared_mem()
+         else min(32, triton.next_power_of_2(K))
+     )
+
+     NT = triton.cdiv(T, BT)
+     NK = triton.cdiv(K, BK)
+     grid = (NK, NT, B * H)
+
+     out_shapes = [
+         jax.ShapeDtypeStruct(q.shape, q.dtype),
+         jax.ShapeDtypeStruct(k.shape, k.dtype),
+         jax.ShapeDtypeStruct(a.shape, a.dtype),
+         jax.ShapeDtypeStruct(b.shape, b.dtype),
+         jax.ShapeDtypeStruct(gi.shape, "float32"),
+         jax.ShapeDtypeStruct(gi.shape, "float32"),
+     ]
+
+     dq, dk, da, db, dgk, dgk_offset = jt.triton_call(
+         q,
+         k,
+         a,
+         b,
+         gi,
+         ge,
+         dAqk,
+         dAqb,
+         dAak,
+         dAab,
+         dqg,
+         dkg,
+         dag,
+         dbg,
+         T,
+         scale=scale,
+         H=H,
+         K=K,
+         BT=BT,
+         BC=BT,
+         BK=BK,
+         GATHER_SUPPORTED=is_gather_supported,
+         kernel=chunk_dplr_bwd_kernel_intra,
+         out_shape=out_shapes,
+         grid=grid,
+     )
+
+     def grid(meta):
+         return (NT, triton.cdiv(K, meta["BK"]), B * H)
+
+     dgk_output = jt.triton_call(
+         dgk,
+         dgk_offset,
+         dgk_last,
+         T,
+         H=H,
+         K=K,
+         BT=BT,
+         kernel=chunk_dplr_bwd_dgk_kernel,
+         out_shape=[jax.ShapeDtypeStruct(dgk.shape, "float32")],
+         grid=grid,
+     )[0]
+     return dq, dk, da, db, dgk_output
rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_fwd.py
@@ -0,0 +1,60 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2023-2025,Qingwen Lin
+
+
+ import jax_triton as jt
+ import jax
+ import triton
+
+ from ..triton_kernel.utils import is_gather_supported
+
+ from ..triton_kernel.chunk_A_fwd import *
+
+
+ def chunk_dplr_fwd_intra(
+     q: jax.Array,
+     k: jax.Array,
+     a: jax.Array,
+     b: jax.Array,
+     gi: jax.Array,
+     ge: jax.Array,
+     scale: float,
+     chunk_size: int,
+ ):
+     B, T, H, K = k.shape
+     BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
+
+     NT = triton.cdiv(T, BT)
+     shape = [B, T, H, BT]
+     out_shapes = [
+         jax.ShapeDtypeStruct(q.shape, q.dtype),
+         jax.ShapeDtypeStruct(k.shape, q.dtype),
+         jax.ShapeDtypeStruct(a.shape, q.dtype),
+         jax.ShapeDtypeStruct(b.shape, q.dtype),
+         jax.ShapeDtypeStruct(shape, q.dtype),
+         jax.ShapeDtypeStruct(shape, q.dtype),
+         jax.ShapeDtypeStruct(shape, "float32"),
+         jax.ShapeDtypeStruct(shape, "float32"),
+     ]
+     grid = (NT, B, H)
+     BK = triton.next_power_of_2(K)
+     qg, kg, ag, bg, Aqk, Aqb, Aab, Aak = jt.triton_call(
+         q,
+         k,
+         a,
+         b,
+         gi,
+         ge,
+         T,
+         scale=scale,
+         H=H,
+         K=K,
+         BT=BT,
+         BC=BT,
+         BK=BK,
+         GATHER_SUPPORTED=is_gather_supported,
+         kernel=chunk_dplr_fwd_A_kernel_intra_sub_intra,
+         out_shape=out_shapes,
+         grid=grid,
+     )
+     return Aab, Aqk, Aak, Aqb, qg, kg, ag, bg
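
For orientation, here is a minimal usage sketch for chunk_dplr_fwd_intra as defined above. It is illustrative only and not part of the wheel: it assumes a CUDA GPU with jax, triton and jax_triton installed, uses the [B, T, H, K] layout of the wrapper, and fills the inputs with dummy values, with gi/ge in float32 to match the dtypes produced by the package's cumsum wrapper.

import jax
import jax.numpy as jnp

from rwkv_ops.rwkv7_kernel.jax_kernel.chunk_A_fwd import chunk_dplr_fwd_intra

# Dummy inputs in the [B, T, H, K] layout expected by the wrapper.
B, T, H, K = 1, 64, 2, 64
keys = jax.random.split(jax.random.PRNGKey(0), 6)
q, k, a, b = (jax.random.normal(kk, (B, T, H, K), dtype=jnp.bfloat16) for kk in keys[:4])
# Log-decay terms; float32, matching the out_shapes of chunk_rwkv6_fwd_cumsum below.
gi = jax.random.normal(keys[4], (B, T, H, K), dtype=jnp.float32)
ge = jax.random.normal(keys[5], (B, T, H, K), dtype=jnp.float32)

Aab, Aqk, Aak, Aqb, qg, kg, ag, bg = chunk_dplr_fwd_intra(
    q, k, a, b, gi, ge, scale=1.0, chunk_size=16
)
print(Aqk.shape)  # (1, 64, 2, 16): [B, T, H, BT] per the out_shapes declared above
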
rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_bwd.py
@@ -0,0 +1,78 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2023-2025,Qingwen Lin
+
+ from typing import Optional, Tuple
+
+ import jax_triton as jt
+ import jax
+ import triton
+
+ from ..get_jax_devices_info import check_shared_mem
+ from ..triton_kernel.chunk_h_bwd import *
+
+
+ def chunk_dplr_bwd_dhu(
+     qg: jax.Array,
+     bg: jax.Array,
+     w: jax.Array,
+     gk: jax.Array,
+     h0: jax.Array,
+     dht: Optional[jax.Array],
+     do: jax.Array,
+     dv: jax.Array,
+     chunk_size: int = 64,
+ ) -> Tuple[jax.Array, jax.Array, jax.Array]:
+     B, T, H, K, V = *qg.shape, do.shape[-1]
+     BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
+     BK = triton.next_power_of_2(K)
+     assert BK <= 256, (
+         "current kernel does not support head dimension being larger than 256."
+     )
+     # H100
+     if check_shared_mem("hopper"):
+         BV = 64
+         BC = 64 if K <= 128 else 32
+     elif check_shared_mem("ampere"):  # A100
+         BV = 32
+         BC = 32
+     else:  # Etc: 4090
+         BV = 16
+         BC = 16
+
+     N, NT = B, triton.cdiv(T, BT)
+     BC = min(BT, BC)
+     NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
+     assert NK == 1, (
+         "NK > 1 is not supported because it involves time-consuming synchronization"
+     )
+     dh_shape = (B, NT, H, K, V)
+     out_shapes = [
+         jax.ShapeDtypeStruct(dh_shape, dv.dtype),
+         jax.ShapeDtypeStruct((B, H, K, V), "float32"),
+         jax.ShapeDtypeStruct(dv.shape, dv.dtype),
+     ]
+
+     grid = (NK, NV, N * H)
+     dh, dh0, dv2 = jt.triton_call(
+         qg,
+         bg,
+         w,
+         gk,
+         dht,
+         dv,
+         do,
+         T,
+         H=H,
+         K=K,
+         V=V,
+         BT=BT,
+         BC=BC,
+         BK=BK,
+         BV=BV,
+         kernel=chunk_dplr_bwd_kernel_dhu.fn,
+         out_shape=out_shapes,
+         grid=grid,
+         USE_FINAL_STATE_GRADIENT=dht is not None,
+         USE_INITIAL_STATE=h0 is not None,
+     )
+     return dh, dh0, dv2
rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_fwd.py
@@ -0,0 +1,80 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2023-2025,Qingwen Lin
+
+ from typing import Optional, Tuple
+
+ import jax_triton as jt
+ import jax
+ import triton
+
+ from ..get_jax_devices_info import check_shared_mem
+ from ..triton_kernel.chunk_h_fwd import *
+
+
+ def chunk_dplr_fwd_h(
+     kg: jax.Array,
+     v: jax.Array,
+     w: jax.Array,
+     u: jax.Array,
+     bg: jax.Array,
+     gk: jax.Array,
+     initial_state: Optional[jax.Array] = None,
+     output_final_state: bool = False,
+     chunk_size: int = 64,
+ ) -> Tuple[jax.Array, jax.Array]:
+     B, T, H, K, V = *kg.shape, u.shape[-1]
+     BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
+
+     N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
+     BK = triton.next_power_of_2(K)
+     assert BK <= 256, "current kernel does not support head dimension larger than 256."
+     # H100 can have larger block size
+
+     if check_shared_mem("hopper"):
+         BV = 64
+         BC = 64 if K <= 128 else 32
+     elif check_shared_mem("ampere"):  # A100
+         BV = 32
+         BC = 32
+     else:
+         BV = 16
+         BC = 16
+
+     BC = min(BT, BC)
+     NK = triton.cdiv(K, BK)
+     NV = triton.cdiv(V, BV)
+     assert NK == 1, (
+         "NK > 1 is not supported because it involves time-consuming synchronization"
+     )
+
+     out_shapes = [
+         jax.ShapeDtypeStruct((B, NT, H, K, V), kg.dtype),
+         jax.ShapeDtypeStruct([N, H, K, V], "float32"),
+         jax.ShapeDtypeStruct(u.shape, u.dtype),
+     ]
+     grid = (NK, NV, N * H)
+     if initial_state is None:
+         initial_state = jax.numpy.zeros([N, H, K, V], "float32")
+     h, final_state, v_new = jt.triton_call(
+         kg,
+         v,
+         w,
+         bg,
+         u,
+         gk,
+         initial_state,
+         T,
+         H=H,
+         K=K,
+         V=V,
+         BT=BT,
+         BC=BC,
+         BK=BK,
+         BV=BV,
+         kernel=chunk_dplr_fwd_kernel_h.fn,
+         out_shape=out_shapes,
+         grid=grid,
+         STORE_FINAL_STATE=True,
+         USE_INITIAL_STATE=True,
+     )
+     return h, v_new, final_state
rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_bwd.py
@@ -0,0 +1,150 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2023-2025,Qingwen Lin
+
+ from typing import Tuple
+
+ import jax_triton as jt
+ import jax
+ import triton
+
+ from ..get_torch_devices_info import check_shared_mem
+ from ..triton_kernel.chunk_o_bwd import *
+
+
+ def chunk_dplr_bwd_dv(
+     A_qk: jax.Array,
+     kg: jax.Array,
+     do: jax.Array,
+     dh: jax.Array,
+     chunk_size: int = 64,
+ ) -> jax.Array:
+     B, T, H, K, V = *kg.shape, do.shape[-1]
+     BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
+
+     NT = triton.cdiv(T, BT)
+
+     def grid(meta):
+         return (triton.cdiv(V, meta["BV"]), NT, B * H)
+
+     dv = jt.triton_call(
+         A_qk,
+         kg,
+         do,
+         dh,
+         T,
+         H=H,
+         K=K,
+         V=V,
+         BT=BT,
+         kernel=chunk_dplr_bwd_kernel_dv,
+         out_shape=jax.ShapeDtypeStruct(do.shape, do.dtype),
+         grid=grid,
+     )
+     return dv
+
+
+ def chunk_dplr_bwd_o(
+     k: jax.Array,
+     b: jax.Array,
+     v: jax.Array,
+     v_new: jax.Array,
+     gk: jax.Array,
+     do: jax.Array,
+     h: jax.Array,
+     dh: jax.Array,
+     dv: jax.Array,
+     w: jax.Array,
+     chunk_size: int = 64,
+     scale: float = 1.0,
+ ) -> Tuple[jax.Array, jax.Array, jax.Array]:
+     B, T, H, K, V = *w.shape, v.shape[-1]
+
+     BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
+     NT = triton.cdiv(T, BT)
+
+     BK = (
+         min(triton.next_power_of_2(K), 64)
+         if check_shared_mem()
+         else min(triton.next_power_of_2(K), 32)
+     )
+     BV = (
+         min(triton.next_power_of_2(V), 64)
+         if check_shared_mem()
+         else min(triton.next_power_of_2(K), 32)
+     )
+     NK = triton.cdiv(K, BK)
+     grid = (NK, NT, B * H)
+
+     out_shapes = [
+         jax.ShapeDtypeStruct(k.shape, k.dtype),
+         jax.ShapeDtypeStruct(k.shape, k.dtype),
+         jax.ShapeDtypeStruct(w.shape, w.dtype),
+         jax.ShapeDtypeStruct(b.shape, b.dtype),
+         jax.ShapeDtypeStruct([B, NT, H, K], "float32"),
+     ]
+     dq, dk, dw, db, dgk_last = jt.triton_call(
+         v,
+         v_new,
+         h,
+         do,
+         dh,
+         w,
+         dv,
+         gk,
+         k,
+         b,
+         T,
+         H=H,
+         K=K,
+         V=V,
+         BT=BT,
+         BK=BK,
+         BV=BV,
+         kernel=chunk_dplr_bwd_o_kernel,
+         out_shape=out_shapes,
+         grid=grid,
+     )
+     return dq, dk, dw, db, dgk_last
+
+
+ def chunk_dplr_bwd_dAu(
+     v: jax.Array,
+     v_new: jax.Array,
+     do: jax.Array,
+     A_qb: jax.Array,
+     scale: float,
+     chunk_size: int = 64,
+ ) -> jax.Array:
+     B, T, H, V = v.shape
+     BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
+     NT = triton.cdiv(T, BT)
+
+     if check_shared_mem("ampere"):  # A100
+         BV = min(triton.next_power_of_2(V), 128)
+     elif check_shared_mem("ada"):  # 4090
+         BV = min(triton.next_power_of_2(V), 64)
+     else:
+         BV = min(triton.next_power_of_2(V), 32)
+
+     grid = (NT, B * H)
+     out_shapes = [
+         jax.ShapeDtypeStruct([B, T, H, BT], "float32"),
+         jax.ShapeDtypeStruct([B, T, H, BT], "float32"),
+         jax.ShapeDtypeStruct(v_new.shape, v_new.dtype),
+     ]
+     dA_qk, dA_qb, dv_new = jt.triton_call(
+         v,
+         do,
+         v_new,
+         A_qb,
+         T,
+         scale=scale,
+         H=H,
+         V=V,
+         BT=BT,
+         BV=BV,
+         grid=grid,
+         out_shape=out_shapes,
+         kernel=chunk_dplr_bwd_kernel_dAu,
+     )
+     return dv_new, dA_qk, dA_qb
rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_fwd.py
@@ -0,0 +1,45 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2023-2025,Qingwen Lin
+
+
+ import jax_triton as jt
+ import jax
+ import triton
+
+ from ..triton_kernel.chunk_o_fwd import *
+
+
+ def chunk_dplr_fwd_o(
+     qg: jax.Array,
+     v: jax.Array,
+     v_new: jax.Array,
+     A_qk: jax.Array,
+     A_qb: jax.Array,
+     h: jax.Array,
+     chunk_size: int = 64,
+ ) -> jax.Array:
+     B, T, H, K, V = *qg.shape, v.shape[-1]
+     BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
+
+     NT = triton.cdiv(T, BT)
+
+     def grid(meta):
+         return (triton.cdiv(V, meta["BV"]), NT, B * H)
+
+     o = jt.triton_call(
+         qg,
+         v,
+         v_new,
+         A_qk,
+         A_qb,
+         h,
+         T,
+         H=H,
+         K=K,
+         V=V,
+         BT=BT,
+         kernel=chunk_dplr_fwd_kernel_o,
+         out_shape=jax.ShapeDtypeStruct(v.shape, v.dtype),
+         grid=grid,
+     )
+     return o
rwkv_ops/rwkv7_kernel/jax_kernel/cumsum.py
@@ -0,0 +1,34 @@
+ from ..triton_kernel.cumsum import *
+ import jax_triton as jt
+ import jax
+ import triton
+
+
+ def chunk_rwkv6_fwd_cumsum(
+     g: jax.Array,
+     chunk_size: int,
+ ) -> jax.Array:
+     B, T, H, S = g.shape
+     BT = chunk_size
+     NT = triton.cdiv(T, BT)
+
+     out_shapes = [
+         jax.ShapeDtypeStruct(g.shape, "float32"),
+         jax.ShapeDtypeStruct(g.shape, "float32"),
+     ]
+
+     def grid(meta):
+         return (triton.cdiv(meta["S"], meta["BS"]), NT, B * H)
+
+     gi, ge = jt.triton_call(
+         g,
+         T,
+         H=H,
+         S=S,
+         BT=BT,
+         grid=grid,
+         kernel=chunk_rwkv6_fwd_cumsum_kernel,
+         out_shape=out_shapes,
+     )
+
+     return gi, ge
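
As a quick illustration (not part of the diff), the cumsum wrapper above can be exercised on its own. The gi/ge naming suggests chunk-local inclusive and exclusive cumulative sums of the log-decay input; treat that reading, plus the dummy shapes below and the requirement of a CUDA GPU with Triton available, as assumptions.

import jax
import jax.numpy as jnp

from rwkv_ops.rwkv7_kernel.jax_kernel.cumsum import chunk_rwkv6_fwd_cumsum

B, T, H, S = 1, 128, 2, 64
# Negative log-decay values, so exp(g) stays in (0, 1] as is typical for RWKV decay.
g = -jnp.abs(jax.random.normal(jax.random.PRNGKey(0), (B, T, H, S), dtype=jnp.float32))

gi, ge = chunk_rwkv6_fwd_cumsum(g, chunk_size=16)
print(gi.shape, ge.shape, gi.dtype)  # (1, 128, 2, 64) (1, 128, 2, 64) float32
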
rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_bwd.py
@@ -0,0 +1,61 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2023-2025,Qingwen Lin
+
+ from typing import Tuple
+
+ import jax_triton as jt
+ import jax
+ import triton
+
+
+ from ..get_torch_devices_info import check_shared_mem
+ from ..triton_kernel.wy_fast_bwd import *
+
+
+ def chunk_dplr_bwd_wy(
+     A_ab_inv: jax.Array,
+     A_ak: jax.Array,
+     v: jax.Array,
+     ag: jax.Array,
+     dw: jax.Array,
+     du: jax.Array,
+     dv0: jax.Array,
+     chunk_size: int = 16,
+ ) -> Tuple[jax.Array, jax.Array, jax.Array]:
+     B, T, H, K, V = *dw.shape, du.shape[-1]
+     BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
+
+     NT = triton.cdiv(T, BT)
+     BK = min(triton.next_power_of_2(K), 64)
+     BV = (
+         min(triton.next_power_of_2(V), 64)
+         if check_shared_mem()
+         else min(triton.next_power_of_2(V), 32)
+     )
+     grid = (NT, B * H)
+     out_shapes = [
+         jax.ShapeDtypeStruct(A_ak.shape, "float32"),
+         jax.ShapeDtypeStruct(A_ab_inv.shape, "float32"),
+         jax.ShapeDtypeStruct(v.shape, v.dtype),
+         jax.ShapeDtypeStruct(ag.shape, ag.dtype),
+     ]
+     dA_ak, dA_ab, dv, dag = jt.triton_call(
+         A_ab_inv,
+         A_ak,
+         ag,
+         v,
+         dw,
+         du,
+         dv0,
+         T,
+         H=H,
+         K=K,
+         V=V,
+         BT=BT,
+         BK=BK,
+         BV=BV,
+         grid=grid,
+         kernel=prepare_wy_repr_bwd_kernel,
+         out_shape=out_shapes,
+     )
+     return dA_ab, dA_ak, dv, dag
rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_fwd.py
@@ -0,0 +1,86 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2023-2025,Qingwen Lin
+
+ from typing import Tuple
+
+ import jax_triton as jt
+ import jax
+ import triton
+
+
+ from ..triton_kernel.wy_fast_fwd import *
+
+
+ def wu_fwd(
+     ag: jax.Array,
+     v: jax.Array,
+     A_ak: jax.Array,
+     A_ab_inv: jax.Array,
+     chunk_size: int,
+ ) -> Tuple[jax.Array, jax.Array]:
+     B, T, H, K, V = *ag.shape, v.shape[-1]
+     BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
+
+     NT = triton.cdiv(T, BT)
+     BK = min(triton.next_power_of_2(K), 64)
+     BV = min(triton.next_power_of_2(V), 64)
+
+     out_shapes = [
+         jax.ShapeDtypeStruct(v.shape, v.dtype),
+         jax.ShapeDtypeStruct(ag.shape, ag.dtype),
+     ]
+     grid = (NT, B * H)
+     w, u = jt.triton_call(
+         ag,
+         v,
+         A_ab_inv,
+         A_ak,
+         T,
+         H=H,
+         K=K,
+         V=V,
+         BT=BT,
+         BK=BK,
+         BV=BV,
+         grid=grid,
+         kernel=wu_fwd_kernel,
+         out_shape=out_shapes,
+     )
+     return w, u
+
+
+ def prepare_wy_repr_fwd(
+     ag: jax.Array,
+     v: jax.Array,
+     A_ak: jax.Array,
+     A_ab: jax.Array,
+     chunk_size: int = 64,
+ ) -> Tuple[jax.Array, jax.Array, jax.Array]:
+     B, T, H, _ = ag.shape
+     BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
+
+     NT = triton.cdiv(T, BT)
+     BC = min(BT, 32)
+     fwd_fn = (
+         prepare_wy_repr_fwd_kernel_chunk64
+         if BT == 64
+         else prepare_wy_repr_fwd_kernel_chunk32
+     )
+     grid = (NT, B * H)
+     A_ab_inv = jt.triton_call(
+         A_ab,
+         T,
+         H=H,
+         BT=BT,
+         BC=BC,
+         grid=grid,
+         kernel=fwd_fn,
+         out_shape=jax.ShapeDtypeStruct(A_ab.shape, A_ab.dtype),
+     )
+     w, u = wu_fwd(ag=ag, v=v, A_ak=A_ak, A_ab_inv=A_ab_inv, chunk_size=BT)
+     return w, u, A_ab_inv
+
+
+ fwd_prepare_wy_repr = prepare_wy_repr_fwd
+
+ fwd_wu = wu_fwd
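
Taken together, the forward wrappers in this excerpt mirror a chunked DPLR pipeline: cumsum of the log decay, intra-chunk blocks, WY representation, inter-chunk state, and the output projection. The actual orchestration lives in rwkv_ops/rwkv7_kernel/jax_op.py, which this diff view does not display, so the argument wiring in the sketch below is an assumption inferred from the signatures above rather than the package's own code.

from rwkv_ops.rwkv7_kernel.jax_kernel.cumsum import chunk_rwkv6_fwd_cumsum
from rwkv_ops.rwkv7_kernel.jax_kernel.chunk_A_fwd import chunk_dplr_fwd_intra
from rwkv_ops.rwkv7_kernel.jax_kernel.wy_fast_fwd import prepare_wy_repr_fwd
from rwkv_ops.rwkv7_kernel.jax_kernel.chunk_h_fwd import chunk_dplr_fwd_h
from rwkv_ops.rwkv7_kernel.jax_kernel.chunk_o_fwd import chunk_dplr_fwd_o


def chunk_dplr_fwd_sketch(q, k, v, a, b, gk, scale, initial_state=None, chunk_size=16):
    """Hypothetical chaining of the forward wrappers; the wiring is an assumption."""
    # Chunk-local cumulative sums of the log decay (gi/ge).
    gi, ge = chunk_rwkv6_fwd_cumsum(gk, chunk_size)
    # Intra-chunk blocks plus decay-weighted q/k/a/b.
    A_ab, A_qk, A_ak, A_qb, qg, kg, ag, bg = chunk_dplr_fwd_intra(
        q, k, a, b, gi, ge, scale, chunk_size
    )
    # WY representation: w/u built from the inverted A_ab and from A_ak.
    w, u, _A_ab_inv = prepare_wy_repr_fwd(ag, v, A_ak, A_ab, chunk_size)
    # Inter-chunk recurrent state, updated values, and (optionally) the final state.
    h, v_new, final_state = chunk_dplr_fwd_h(
        kg, v, w, u, bg, gi, initial_state=initial_state, chunk_size=chunk_size
    )
    # Output combining intra- and inter-chunk contributions.
    o = chunk_dplr_fwd_o(qg, v, v_new, A_qk, A_qb, h, chunk_size)
    return o, final_state
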