rwkv-ops 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rwkv-ops might be problematic; consult the registry's advisory details for more information.

Files changed (43)
  1. rwkv_ops/__init__.py +26 -0
  2. rwkv_ops/rwkv7_kernel/__init__.py +153 -0
  3. rwkv_ops/rwkv7_kernel/get_jax_devices_info.py +221 -0
  4. rwkv_ops/rwkv7_kernel/get_torch_devices_info.py +250 -0
  5. rwkv_ops/rwkv7_kernel/jax_kernel/__init__.py +9 -0
  6. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_bwd.py +95 -0
  7. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_fwd.py +60 -0
  8. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_bwd.py +78 -0
  9. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_fwd.py +80 -0
  10. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_bwd.py +150 -0
  11. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_fwd.py +45 -0
  12. rwkv_ops/rwkv7_kernel/jax_kernel/cumsum.py +34 -0
  13. rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_bwd.py +61 -0
  14. rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_fwd.py +86 -0
  15. rwkv_ops/rwkv7_kernel/jax_op.py +382 -0
  16. rwkv_ops/rwkv7_kernel/native_keras_op.py +95 -0
  17. rwkv_ops/rwkv7_kernel/torch_kernel/__init__.py +13 -0
  18. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_bwd.py +96 -0
  19. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_fwd.py +64 -0
  20. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_bwd.py +74 -0
  21. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_fwd.py +75 -0
  22. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_bwd.py +148 -0
  23. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_fwd.py +44 -0
  24. rwkv_ops/rwkv7_kernel/torch_kernel/cumsum.py +31 -0
  25. rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_bwd.py +63 -0
  26. rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_fwd.py +79 -0
  27. rwkv_ops/rwkv7_kernel/torch_op.py +523 -0
  28. rwkv_ops/rwkv7_kernel/triton_kernel/__init__.py +34 -0
  29. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_bwd.py +328 -0
  30. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_fwd.py +186 -0
  31. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_bwd.py +157 -0
  32. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_fwd.py +160 -0
  33. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_bwd.py +382 -0
  34. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_fwd.py +137 -0
  35. rwkv_ops/rwkv7_kernel/triton_kernel/cumsum.py +86 -0
  36. rwkv_ops/rwkv7_kernel/triton_kernel/utils.py +20 -0
  37. rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_bwd.py +193 -0
  38. rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_fwd.py +326 -0
  39. rwkv_ops-0.1.0.dist-info/LICENSE.txt +201 -0
  40. rwkv_ops-0.1.0.dist-info/METADATA +118 -0
  41. rwkv_ops-0.1.0.dist-info/RECORD +43 -0
  42. rwkv_ops-0.1.0.dist-info/WHEEL +5 -0
  43. rwkv_ops-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,160 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+
5
+ import triton
6
+ import triton.language as tl
7
+
8
+ from ..triton_kernel.utils import exp, use_cuda_graph
9
+
10
+
11
@triton.heuristics(
    {
        # Branch flags are derived from whether the optional state tensors were passed.
        "USE_INITIAL_STATE": lambda args: args["h0"] is not None,
        "STORE_FINAL_STATE": lambda args: args["ht"] is not None,
    }
)
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8, 16, 32]
        for num_stages in [2, 3, 4]
    ],
    key=["BT", "BK", "BV"],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=["T"])
def chunk_dplr_fwd_kernel_h(
    kg,      # decayed keys, laid out [B*T, H, K]
    v,       # values, [B*T, H, V]
    w,       # low-rank input projection, [B*T, H, K]
    bg,      # decayed low-rank vectors, [B*T, H, K]
    u,       # additive term for v_new, [B*T, H, V]
    gk,      # cumulative log-decay per key dim, [B*T, H, K]
    h0,      # optional initial state, [B*H, K, V] (may be None)
    T,       # sequence length (runtime value, not specialized)
    h,       # per-chunk state checkpoints, written here
    ht,      # optional final state output (may be None)
    v_new,   # transformed values, written here
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,   # chunk (block-time) size
    BC: tl.constexpr,   # sub-chunk size; see SRAM note below
    BK: tl.constexpr,   # key tile size
    BV: tl.constexpr,   # value tile size
    USE_INITIAL_STATE: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr,
):
    # One program instance owns a (BK, BV) tile of the state for one
    # (batch, head) pair and walks the whole sequence chunk by chunk.
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_h = i_nh // H, i_nh % H
    if False:
        # Dead branch: variable-length (cu_seqlens/chunk_offsets) path kept
        # from the upstream source but disabled here; the names it references
        # are undefined in this module, which is harmless since the constant
        # condition is never traced by the Triton JIT.
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
        NT = tl.cdiv(T, BT)
        boh = tl.load(chunk_offsets + i_n).to(tl.int32)
    else:
        # Fixed-length layout: sequence i_n spans rows [i_n*T, i_n*T + T).
        bos, eos = i_n * T, i_n * T + T
        NT = tl.cdiv(T, BT)
        boh = i_n * NT  # first chunk index of this sequence in `h`
    o_k = i_k * BK + tl.arange(0, BK)

    # [BK, BV] running state tile, accumulated in fp32.
    b_h = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_INITIAL_STATE:
        p_h0 = tl.make_block_ptr(
            h0 + i_nh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)
        )
        b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)

    for i_t in range(NT):
        # Checkpoint the state *entering* chunk i_t (used by other kernels).
        p_h = tl.make_block_ptr(
            h + ((boh + i_t) * H + i_h) * K * V,
            (K, V),
            (V, 1),
            (i_k * BK, i_v * BV),
            (BK, BV),
            (1, 0),
        )
        tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))

        b_hc = tl.zeros([BK, BV], dtype=tl.float32)
        # Since we need all of DK in SRAM we face a severe SRAM burden;
        # sub-chunking by BC alleviates it (comment preserved from upstream).
        for i_c in range(tl.cdiv(min(BT, T - i_t * BT), BC)):
            p_kg = tl.make_block_ptr(
                kg + (bos * H + i_h) * K,
                (K, T),
                (1, H * K),
                (i_k * BK, i_t * BT + i_c * BC),
                (BK, BC),
                (0, 1),
            )
            p_bg = tl.make_block_ptr(
                bg + (bos * H + i_h) * K,
                (K, T),
                (1, H * K),
                (i_k * BK, i_t * BT + i_c * BC),
                (BK, BC),
                (0, 1),
            )
            p_w = tl.make_block_ptr(
                w + (bos * H + i_h) * K,
                (T, K),
                (H * K, 1),
                (i_t * BT + i_c * BC, i_k * BK),
                (BC, BK),
                (1, 0),
            )
            p_v = tl.make_block_ptr(
                v + (bos * H + i_h) * V,
                (T, V),
                (H * V, 1),
                (i_t * BT + i_c * BC, i_v * BV),
                (BC, BV),
                (1, 0),
            )
            p_u = tl.make_block_ptr(
                u + (bos * H + i_h) * V,
                (T, V),
                (H * V, 1),
                (i_t * BT + i_c * BC, i_v * BV),
                (BC, BV),
                (1, 0),
            )
            p_v_new = tl.make_block_ptr(
                v_new + (bos * H + i_h) * V,
                (T, V),
                (H * V, 1),
                (i_t * BT + i_c * BC, i_v * BV),
                (BC, BV),
                (1, 0),
            )
            # [BK, BC] tiles (k-major) and [BC, BK]/[BC, BV] row-major tiles.
            b_kg = tl.load(p_kg, boundary_check=(0, 1))
            b_v = tl.load(p_v, boundary_check=(0, 1))
            b_w = tl.load(p_w, boundary_check=(0, 1))
            b_bg = tl.load(p_bg, boundary_check=(0, 1))
            # v_new = w @ h + u : values transformed through the current state.
            b_v2 = tl.dot(b_w, b_h.to(b_w.dtype)) + tl.load(p_u, boundary_check=(0, 1))
            # Chunk-local state update: kg @ v plus the low-rank bg @ v_new term.
            b_hc += tl.dot(b_kg, b_v)
            b_hc += tl.dot(b_bg.to(b_hc.dtype), b_v2)
            tl.store(p_v_new, b_v2.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))

        # Decay the carried state by exp(gk) at the chunk's last timestep,
        # then fold in the contribution accumulated within this chunk.
        last_idx = min((i_t + 1) * BT, T) - 1
        b_g_last = tl.load(
            gk + (bos + last_idx) * H * K + i_h * K + o_k, mask=o_k < K
        ).to(tl.float32)
        b_h *= exp(b_g_last[:, None])
        b_h += b_hc

    if STORE_FINAL_STATE:
        p_ht = tl.make_block_ptr(
            ht + i_nh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)
        )
        tl.store(
            p_ht,
            # round-to-nearest-even on the final downcast for accuracy
            b_h.to(p_ht.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )
@@ -0,0 +1,382 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+
5
+ import triton
6
+ import triton.language as tl
7
+
8
+ from ..triton_kernel.utils import (
9
+ exp,
10
+ check_shared_mem,
11
+ use_cuda_graph,
12
+ )
13
+
14
+ BK_LIST = [32, 64, 128] if check_shared_mem() else [16, 32]
15
+
16
+
17
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8, 16, 32]
        for num_stages in [2, 3, 4]
    ],
    key=["BV", "BT"],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=["T"])
def chunk_dplr_bwd_kernel_dAu(
    v,        # values, [B*T, H, V]
    do,       # gradient of the output o, [B*T, H, V]
    v_new,    # transformed values from the forward pass, [B*T, H, V]
    A_qb,     # intra-chunk q·b attention scores, [B*T, H, BT]
    T,        # sequence length (runtime value)
    dA_qk,    # output: gradient of the q·k score matrix
    dA_qb,    # output: gradient of the q·b score matrix
    dv_new,   # output: gradient of v_new from the intra-chunk path
    scale: tl.constexpr,
    H: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BV: tl.constexpr,
):
    # One program per (chunk, batch*head); the value dimension is reduced
    # inside this kernel by looping over BV-wide tiles.
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if False:
        # Dead branch: disabled variable-length path from upstream;
        # chunk_indices/cu_seqlens are undefined here but never traced.
        i_n, i_t = (
            tl.load(chunk_indices + i_t * 2).to(tl.int32),
            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
        )
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
    else:
        bos, eos = i_b * T, i_b * T + T
    T = eos - bos  # no-op in the fixed-length path

    b_dA_qk = tl.zeros([BT, BT], dtype=tl.float32)
    b_dA_qb = tl.zeros([BT, BT], dtype=tl.float32)

    p_A_qb = tl.make_block_ptr(
        A_qb + (bos * H + i_h) * BT,
        (T, BT),
        (H * BT, 1),
        (i_t * BT, 0),
        (BT, BT),
        (1, 0),
    )

    b_A_qb = tl.load(p_A_qb, boundary_check=(0, 1))
    # causal mask: keep only the lower triangle (row index >= column index)
    b_A_qb = tl.where(
        tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :], b_A_qb, 0.0
    ).to(b_A_qb.dtype)

    for i_v in range(tl.cdiv(V, BV)):
        p_do = tl.make_block_ptr(
            do + (bos * H + i_h) * V,
            (T, V),
            (H * V, 1),
            (i_t * BT, i_v * BV),
            (BT, BV),
            (1, 0),
        )
        # v and v_new are addressed transposed ([V, T]) so the dots below
        # directly yield [BT, BT] score gradients.
        p_v = tl.make_block_ptr(
            v + (bos * H + i_h) * V,
            (V, T),
            (1, H * V),
            (i_v * BV, i_t * BT),
            (BV, BT),
            (0, 1),
        )
        p_v_new = tl.make_block_ptr(
            v_new + (bos * H + i_h) * V,
            (V, T),
            (1, H * V),
            (i_v * BV, i_t * BT),
            (BV, BT),
            (0, 1),
        )
        p_dv_new = tl.make_block_ptr(
            dv_new + (bos * H + i_h) * V,
            (T, V),
            (H * V, 1),
            (i_t * BT, i_v * BV),
            (BT, BV),
            (1, 0),
        )
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_do = tl.load(p_do, boundary_check=(0, 1))
        b_v_new = tl.load(p_v_new, boundary_check=(0, 1))
        # Accumulate score gradients: dA = do @ v^T (per value tile).
        b_dA_qk += tl.dot(b_do, b_v)
        b_dA_qb += tl.dot(b_do, b_v_new)
        # dv_new = A_qb^T @ do (masked A_qb) — consumed by the recurrent path.
        b_dv_new = tl.dot(tl.trans(b_A_qb), b_do)
        # for recurrent
        tl.store(
            p_dv_new, b_dv_new.to(p_dv_new.dtype.element_ty), boundary_check=(0, 1)
        )

    p_dA_qk = tl.make_block_ptr(
        dA_qk + (bos * H + i_h) * BT,
        (T, BT),
        (H * BT, 1),
        (i_t * BT, 0),
        (BT, BT),
        (1, 0),
    )
    p_dA_qb = tl.make_block_ptr(
        dA_qb + (bos * H + i_h) * BT,
        (T, BT),
        (H * BT, 1),
        (i_t * BT, 0),
        (BT, BT),
        (1, 0),
    )
    # Re-apply the causal mask and the attention scale before writing out.
    m_s = tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :]
    b_dA_qk = tl.where(m_s, b_dA_qk * scale, 0.0)
    tl.store(p_dA_qk, b_dA_qk.to(p_dA_qk.dtype.element_ty), boundary_check=(0, 1))
    b_dA_qb = tl.where(m_s, b_dA_qb * scale, 0.0)
    tl.store(p_dA_qb, b_dA_qb.to(p_dA_qb.dtype.element_ty), boundary_check=(0, 1))
140
+
141
+
142
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8, 16, 32]
        for num_stages in [2, 3, 4]
    ],
    key=["BT", "BK", "BV"],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit
def chunk_dplr_bwd_o_kernel(
    v,         # values, [B*T, H, V]
    v_new,     # transformed values from the forward pass
    h,         # per-chunk state checkpoints, [B*NT, H, K, V]
    do,        # gradient of the output
    dh,        # gradient of the per-chunk state
    w,         # low-rank input projection
    dv,        # gradient of v (read here to form dw)
    gk,        # cumulative log-decay per key dim
    k,         # keys
    b,         # low-rank vectors
    T,         # sequence length (runtime value)
    dq,        # output: gradient of q (inter-chunk part)
    dk,        # output: gradient of k (inter-chunk part)
    dw,        # output: gradient of w
    db,        # output: gradient of b (inter-chunk part)
    dgk_last,  # output: per-chunk gradient of the last-step decay, [B*NT, H, K]
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
):
    # One program per (key tile, chunk, batch*head); reduces over value tiles.
    i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H

    if False:
        # Dead branch: disabled variable-length path from upstream;
        # chunk_indices/cu_seqlens are undefined here but never traced.
        i_tg = i_t
        i_n, i_t = (
            tl.load(chunk_indices + i_t * 2).to(tl.int32),
            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
        )
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t  # global chunk index into h/dh/dgk_last
        bos, eos = i_b * T, i_b * T + T

    # offset calculation: rebase every pointer to this (batch, head) slice
    v += (bos * H + i_h) * V
    v_new += (bos * H + i_h) * V
    do += (bos * H + i_h) * V
    h += (i_tg * H + i_h) * K * V
    dh += (i_tg * H + i_h) * K * V
    dk += (bos * H + i_h) * K
    k += (bos * H + i_h) * K
    db += (bos * H + i_h) * K
    b += (bos * H + i_h) * K
    dw += (bos * H + i_h) * K
    dv += (bos * H + i_h) * V
    dq += (bos * H + i_h) * K
    w += (bos * H + i_h) * K

    dgk_last += (i_tg * H + i_h) * K
    gk += (bos * H + i_h) * K

    stride_qk = H * K
    stride_vo = H * V

    b_dq = tl.zeros([BT, BK], dtype=tl.float32)
    b_dk = tl.zeros([BT, BK], dtype=tl.float32)
    b_dw = tl.zeros([BT, BK], dtype=tl.float32)
    b_db = tl.zeros([BT, BK], dtype=tl.float32)
    b_dgk_last = tl.zeros([BK], dtype=tl.float32)

    for i_v in range(tl.cdiv(V, BV)):
        p_v = tl.make_block_ptr(
            v, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
        )
        p_v_new = tl.make_block_ptr(
            v_new, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
        )
        p_do = tl.make_block_ptr(
            do, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
        )
        # h/dh are addressed transposed ([V, K]) so the dots below land in [BT, BK].
        p_h = tl.make_block_ptr(
            h, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)
        )
        p_dh = tl.make_block_ptr(
            dh, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)
        )
        # [BT, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_v_new = tl.load(p_v_new, boundary_check=(0, 1))
        b_do = tl.load(p_do, boundary_check=(0, 1))
        # [BV, BK]
        b_h = tl.load(p_h, boundary_check=(0, 1))
        b_dh = tl.load(p_dh, boundary_check=(0, 1))
        # Decay gradient picks up the state/state-grad inner product.
        b_dgk_last += tl.sum((b_h * b_dh).to(tl.float32), axis=0)

        # [BT, BV] @ [BV, BK] -> [BT, BK]
        b_dq += tl.dot(b_do, b_h.to(b_do.dtype))
        # [BT, BV] @ [BV, BK] -> [BT, BK]
        b_dk += tl.dot(b_v, b_dh.to(b_v.dtype))
        b_db += tl.dot(b_v_new, b_dh.to(b_v_new.dtype))
        p_dv = tl.make_block_ptr(
            dv, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
        )
        b_dv = tl.load(p_dv, boundary_check=(0, 1))
        b_dw += tl.dot(b_dv.to(b_v.dtype), b_h.to(b_v.dtype))

    # Scale the decay gradient by exp(gk) at the chunk's last timestep, then
    # add the contributions flowing through k and b.
    m_k = (i_k * BK + tl.arange(0, BK)) < K
    last_idx = min(i_t * BT + BT, T) - 1
    b_gk_last = tl.load(
        gk + last_idx * stride_qk + i_k * BK + tl.arange(0, BK),
        mask=m_k,
        other=float("-inf"),
    )
    b_dgk_last *= exp(b_gk_last)
    p_k = tl.make_block_ptr(
        k, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
    )
    p_b = tl.make_block_ptr(
        b, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
    )
    b_k = tl.load(p_k, boundary_check=(0, 1))
    b_b = tl.load(p_b, boundary_check=(0, 1))
    b_dgk_last += tl.sum(b_k * b_dk, axis=0)
    b_dgk_last += tl.sum(b_b * b_db, axis=0)
    tl.store(dgk_last + tl.arange(0, BK) + i_k * BK, b_dgk_last, mask=m_k)

    p_dw = tl.make_block_ptr(
        dw, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
    )
    p_dk = tl.make_block_ptr(
        dk, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
    )
    p_db = tl.make_block_ptr(
        db, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
    )
    p_dq = tl.make_block_ptr(
        dq, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
    )
    tl.store(p_dw, b_dw.to(p_dw.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_db, b_db.to(p_db.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
295
+
296
+
297
@triton.autotune(
    configs=[
        triton.Config({"BK": BK, "BV": BV}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8, 16, 32]
        for num_stages in [2, 3, 4]
        for BK in BK_LIST
        for BV in BK_LIST
    ],
    key=["BT"],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit
def chunk_dplr_bwd_kernel_dv(
    A_qk,   # intra-chunk q·k attention scores, [B*T, H, BT]
    kg,     # decayed keys, [B*T, H, K]
    do,     # gradient of the output
    dh,     # gradient of the per-chunk state, [B*NT, H, K, V]
    T,      # sequence length (runtime value)
    dv,     # output: gradient of v
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
):
    # One program per (value tile, chunk, batch*head); dv combines the
    # inter-chunk path (kg @ dh) and the intra-chunk path (A_qk^T @ do).
    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    if False:
        # Dead branch: disabled variable-length path from upstream;
        # chunk_indices/cu_seqlens are undefined here but never traced.
        i_tg = i_t
        i_n, i_t = (
            tl.load(chunk_indices + i_t * 2).to(tl.int32),
            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
        )
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t  # global chunk index into dh
        bos, eos = i_b * T, i_b * T + T

    b_dv = tl.zeros([BT, BV], dtype=tl.float32)

    # offset calculation: rebase every pointer to this (batch, head) slice
    A_qk += (bos * H + i_h) * BT
    do += (bos * H + i_h) * V
    dv += (bos * H + i_h) * V
    kg += (bos * H + i_h) * K
    dh += (i_tg * H + i_h) * K * V

    stride_qk = H * K
    stride_vo = H * V
    stride_A = H * BT

    # Inter-chunk contribution: dv += kg @ dh, reduced over key tiles.
    for i_k in range(tl.cdiv(K, BK)):
        p_dh = tl.make_block_ptr(
            dh, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)
        )
        p_kg = tl.make_block_ptr(
            kg, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
        )
        b_dh = tl.load(p_dh, boundary_check=(0, 1))
        b_kg = tl.load(p_kg, boundary_check=(0, 1))
        b_dv += tl.dot(b_kg, b_dh.to(b_kg.dtype))

    # Intra-chunk contribution: load A_qk transposed and mask it to the upper
    # triangle (the transpose of the causal lower triangle), then dv += A^T @ do.
    p_Aqk = tl.make_block_ptr(
        A_qk, (BT, T), (1, stride_A), (0, i_t * BT), (BT, BT), (0, 1)
    )
    b_A = tl.where(
        tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :],
        tl.load(p_Aqk, boundary_check=(0, 1)),
        0,
    )
    p_do = tl.make_block_ptr(
        do, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
    )
    p_dv = tl.make_block_ptr(
        dv, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
    )
    b_do = tl.load(p_do, boundary_check=(0, 1))
    b_dv += tl.dot(b_A.to(b_do.dtype), b_do)
    tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
@@ -0,0 +1,137 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+
5
+ import triton
6
+ import triton.language as tl
7
+
8
+ from ..triton_kernel.utils import check_shared_mem, use_cuda_graph
9
+
10
+
11
+ BK_LIST = [32, 64, 128] if check_shared_mem() else [16, 32]
12
+
13
+
14
@triton.autotune(
    configs=[
        triton.Config({"BK": BK, "BV": BV}, num_warps=num_warps, num_stages=num_stages)
        for BK in BK_LIST
        for BV in BK_LIST
        for num_warps in [2, 4, 8, 16, 32]
        for num_stages in [2, 3, 4]
    ],
    key=["BT"],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=["T"])
def chunk_dplr_fwd_kernel_o(
    qg,     # decayed queries, [B*T, H, K]
    v,      # values, [B*T, H, V]
    v_new,  # transformed values from the state-recurrence kernel
    A_qk,   # intra-chunk q·k attention scores, [B*T, H, BT]
    A_qb,   # intra-chunk q·b attention scores, [B*T, H, BT]
    h,      # per-chunk state checkpoints, [B*NT, H, K, V]
    T,      # sequence length (runtime value, not specialized)
    o,      # output, written here
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
):
    # One program per (value tile, chunk, batch*head). The output combines the
    # inter-chunk path (qg @ h) with two masked intra-chunk paths
    # (A_qk @ v and A_qb @ v_new).
    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H

    if False:
        # Dead branch: disabled variable-length path from upstream;
        # chunk_indices/cu_seqlens are undefined here but never traced.
        i_tg = i_t
        i_n, i_t = (
            tl.load(chunk_indices + i_t * 2).to(tl.int32),
            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
        )
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t  # global chunk index into h
        bos, eos = i_b * T, i_b * T + T

    # Inter-chunk contribution: o += qg @ h, reduced over key tiles.
    b_o = tl.zeros([BT, BV], dtype=tl.float32)
    for i_k in range(tl.cdiv(K, BK)):
        p_qg = tl.make_block_ptr(
            qg + (bos * H + i_h) * K,
            (T, K),
            (H * K, 1),
            (i_t * BT, i_k * BK),
            (BT, BK),
            (1, 0),
        )
        p_h = tl.make_block_ptr(
            h + (i_tg * H + i_h) * K * V,
            (K, V),
            (V, 1),
            (i_k * BK, i_v * BV),
            (BK, BV),
            (1, 0),
        )
        b_qg = tl.load(p_qg, boundary_check=(0, 1))
        b_h = tl.load(p_h, boundary_check=(0, 1))
        b_o += tl.dot(b_qg, b_h)

    p_Aqk = tl.make_block_ptr(
        A_qk + (bos * H + i_h) * BT,
        (T, BT),
        (H * BT, 1),
        (i_t * BT, 0),
        (BT, BT),
        (1, 0),
    )
    p_Aqb = tl.make_block_ptr(
        A_qb + (bos * H + i_h) * BT,
        (T, BT),
        (H * BT, 1),
        (i_t * BT, 0),
        (BT, BT),
        (1, 0),
    )
    p_v = tl.make_block_ptr(
        v + (bos * H + i_h) * V,
        (T, V),
        (H * V, 1),
        (i_t * BT, i_v * BV),
        (BT, BV),
        (1, 0),
    )
    p_v_new = tl.make_block_ptr(
        v_new + (bos * H + i_h) * V,
        (T, V),
        (H * V, 1),
        (i_t * BT, i_v * BV),
        (BT, BV),
        (1, 0),
    )
    p_o = tl.make_block_ptr(
        o + (bos * H + i_h) * V,
        (T, V),
        (H * V, 1),
        (i_t * BT, i_v * BV),
        (BT, BV),
        (1, 0),
    )

    # Causal mask: each row attends only to itself and earlier positions.
    m_s = tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :]
    b_Aqk = tl.load(p_Aqk, boundary_check=(0, 1))
    b_Aqb = tl.load(p_Aqb, boundary_check=(0, 1))
    b_Aqk = tl.where(m_s, b_Aqk, 0)
    b_Aqb = tl.where(m_s, b_Aqb, 0)
    b_v = tl.load(p_v, boundary_check=(0, 1))
    b_v_new = tl.load(p_v_new, boundary_check=(0, 1))
    b_o = (
        b_o
        + tl.dot(b_Aqk.to(b_v.dtype), b_v)
        + tl.dot(b_Aqb.to(b_v_new.dtype), b_v_new)
    )
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))