rwkv-ops 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of rwkv-ops might be problematic.

Files changed (43)
  1. rwkv_ops/__init__.py +26 -0
  2. rwkv_ops/rwkv7_kernel/__init__.py +153 -0
  3. rwkv_ops/rwkv7_kernel/get_jax_devices_info.py +221 -0
  4. rwkv_ops/rwkv7_kernel/get_torch_devices_info.py +250 -0
  5. rwkv_ops/rwkv7_kernel/jax_kernel/__init__.py +9 -0
  6. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_bwd.py +95 -0
  7. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_fwd.py +60 -0
  8. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_bwd.py +78 -0
  9. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_fwd.py +80 -0
  10. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_bwd.py +150 -0
  11. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_fwd.py +45 -0
  12. rwkv_ops/rwkv7_kernel/jax_kernel/cumsum.py +34 -0
  13. rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_bwd.py +61 -0
  14. rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_fwd.py +86 -0
  15. rwkv_ops/rwkv7_kernel/jax_op.py +382 -0
  16. rwkv_ops/rwkv7_kernel/native_keras_op.py +95 -0
  17. rwkv_ops/rwkv7_kernel/torch_kernel/__init__.py +13 -0
  18. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_bwd.py +96 -0
  19. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_fwd.py +64 -0
  20. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_bwd.py +74 -0
  21. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_fwd.py +75 -0
  22. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_bwd.py +148 -0
  23. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_fwd.py +44 -0
  24. rwkv_ops/rwkv7_kernel/torch_kernel/cumsum.py +31 -0
  25. rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_bwd.py +63 -0
  26. rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_fwd.py +79 -0
  27. rwkv_ops/rwkv7_kernel/torch_op.py +523 -0
  28. rwkv_ops/rwkv7_kernel/triton_kernel/__init__.py +34 -0
  29. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_bwd.py +328 -0
  30. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_fwd.py +186 -0
  31. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_bwd.py +157 -0
  32. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_fwd.py +160 -0
  33. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_bwd.py +382 -0
  34. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_fwd.py +137 -0
  35. rwkv_ops/rwkv7_kernel/triton_kernel/cumsum.py +86 -0
  36. rwkv_ops/rwkv7_kernel/triton_kernel/utils.py +20 -0
  37. rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_bwd.py +193 -0
  38. rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_fwd.py +326 -0
  39. rwkv_ops-0.1.0.dist-info/LICENSE.txt +201 -0
  40. rwkv_ops-0.1.0.dist-info/METADATA +118 -0
  41. rwkv_ops-0.1.0.dist-info/RECORD +43 -0
  42. rwkv_ops-0.1.0.dist-info/WHEEL +5 -0
  43. rwkv_ops-0.1.0.dist-info/top_level.txt +1 -0
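
The listing above implies a multi-backend layout: device probes (get_jax_devices_info.py, get_torch_devices_info.py), shared Triton kernels under triton_kernel/, framework-native fallbacks under jax_kernel/ and torch_kernel/, a Keras-level reference in native_keras_op.py, and jax_op.py / torch_op.py as the per-framework entry points. The sketch below only illustrates how such a layout is typically wired together; the helper name select_rwkv7_backend is hypothetical, and nothing beyond the module paths is taken from the wheel.

import importlib

def select_rwkv7_backend(framework: str, on_gpu: bool):
    # Hypothetical dispatch helper, not the package's actual API.
    if framework == "torch":
        # torch_op.py presumably drives the Triton kernels when a GPU is present;
        # otherwise the portable Keras implementation is the safe fallback.
        name = "torch_op" if on_gpu else "native_keras_op"
    elif framework == "jax":
        name = "jax_op"
    else:
        raise ValueError(f"unsupported framework: {framework!r}")
    return importlib.import_module(f"rwkv_ops.rwkv7_kernel.{name}")
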
rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_bwd.py
@@ -0,0 +1,328 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+
+import triton
+import triton.language as tl
+
+from ..triton_kernel.utils import exp, gather, use_cuda_graph
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
+        for num_warps in [2, 4, 8, 16, 32]
+        for num_stages in [2, 3, 4]
+    ],
+    key=["BK", "BT", "K"],
+    use_cuda_graph=use_cuda_graph,
+)
+@triton.jit(do_not_specialize=["T"])
+def chunk_dplr_bwd_kernel_intra(
+    q,
+    k,
+    a,
+    b,
+    gi,
+    ge,
+    dAqk,
+    dAqb,
+    dAak,
+    dAab,
+    dqg,
+    dkg,
+    dag,
+    dbg,
+    T,
+    dq,
+    dk,
+    da,
+    db,
+    dgk,
+    dgk_offset,
+    scale: tl.constexpr,
+    H: tl.constexpr,
+    K: tl.constexpr,
+    BT: tl.constexpr,
+    BC: tl.constexpr,
+    BK: tl.constexpr,
+    GATHER_SUPPORTED: tl.constexpr,
+):
+    i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+    i_b, i_h = i_bh // H, i_bh % H
+    if False:
+        i_n, i_t = (
+            tl.load(chunk_indices + i_t * 2).to(tl.int32),
+            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
+        )
+        bos, eos = (
+            tl.load(cu_seqlens + i_n).to(tl.int32),
+            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
+        )
+        T = eos - bos
+    else:
+        bos, eos = (i_b * T).to(tl.int32), (i_b * T + T).to(tl.int32)
+
+    if i_t * BT >= T:
+        return
+
+    # offset calculation
+    ge += (bos * H + i_h) * K
+    gi += (bos * H + i_h) * K
+    q += (bos * H + i_h) * K
+    a += (bos * H + i_h) * K
+    b += (bos * H + i_h) * K
+    k += (bos * H + i_h) * K
+    dq += (bos * H + i_h) * K
+    dk += (bos * H + i_h) * K
+    da += (bos * H + i_h) * K
+    db += (bos * H + i_h) * K
+    dqg += (bos * H + i_h) * K
+    dag += (bos * H + i_h) * K
+    dkg += (bos * H + i_h) * K
+    dbg += (bos * H + i_h) * K
+    dgk += (bos * H + i_h) * K
+    dgk_offset += (bos * H + i_h) * K
+    dAqk += (bos * H + i_h) * BT
+    dAqb += (bos * H + i_h) * BT
+    dAak += (bos * H + i_h) * BT
+    dAab += (bos * H + i_h) * BT
+
+    stride_qk = H * K
+    stride_A = H * BT
+
+    p_ge = tl.make_block_ptr(
+        ge, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)
+    )
+    p_gi = tl.make_block_ptr(
+        gi, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)
+    )
+    # [BC, BK]
+    b_ge = tl.load(p_ge, boundary_check=(0, 1))
+    b_gi = tl.load(p_gi, boundary_check=(0, 1))
+    b_dq = tl.zeros([BC, BK], dtype=tl.float32)
+    b_da = tl.zeros([BC, BK], dtype=tl.float32)
+    b_dk = tl.zeros([BC, BK], dtype=tl.float32)
+    b_db = tl.zeros([BC, BK], dtype=tl.float32)
+    # intra chunk gradient calculation
+    p_dAqk = tl.make_block_ptr(
+        dAqk, (T, BT), (stride_A, 1), (i_t * BT, 0), (BC, BC), (1, 0)
+    )
+    p_dAab = tl.make_block_ptr(
+        dAab, (T, BT), (stride_A, 1), (i_t * BT, 0), (BC, BC), (1, 0)
+    )
+    p_dAqb = tl.make_block_ptr(
+        dAqb, (T, BT), (stride_A, 1), (i_t * BT, 0), (BC, BC), (1, 0)
+    )
+    p_dAak = tl.make_block_ptr(
+        dAak, (T, BT), (stride_A, 1), (i_t * BT, 0), (BC, BC), (1, 0)
+    )
+    o_i = tl.arange(0, BC)
+    p_k = tl.make_block_ptr(
+        k, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)
+    )
+    p_b = tl.make_block_ptr(
+        b, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)
+    )
+    p_a = tl.make_block_ptr(
+        a, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)
+    )
+    p_q = tl.make_block_ptr(
+        q, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)
+    )
+    b_k = tl.load(p_k, boundary_check=(0, 1))
+    b_b = tl.load(p_b, boundary_check=(0, 1))
+    b_q = tl.load(p_q, boundary_check=(0, 1))
+    b_a = tl.load(p_a, boundary_check=(0, 1))
+    b_dAqk = tl.load(p_dAqk, boundary_check=(0, 1))
+    b_dAab = tl.load(p_dAab, boundary_check=(0, 1))
+    b_dAqb = tl.load(p_dAqb, boundary_check=(0, 1))
+    b_dAak = tl.load(p_dAak, boundary_check=(0, 1))
+
+    # inter chunk gradient calculation
+    o_k = i_k * BK + tl.arange(0, BK)
+    m_k = o_k < K
+    # intra chunk gradient calculation
+    for j in range(0, min(BC, T - i_t * BT)):
+        # trick to index the block
+        if GATHER_SUPPORTED:
+            row_idx = tl.full([1, BK], j, dtype=tl.int16)
+            col_idx = tl.full([BC, 1], j, dtype=tl.int16)
+            row_idx_bc = tl.full([1, BC], j, dtype=tl.int16)
+            # [1, BK]
+            b_kj = gather(b_k, row_idx, axis=0)
+            b_bj = gather(b_b, row_idx, axis=0)
+            b_gij = gather(b_gi, row_idx, axis=0)
+            b_gej = gather(b_ge, row_idx, axis=0)
+            b_qj = gather(b_q, row_idx, axis=0)
+            b_aj = gather(b_a, row_idx, axis=0)
+            # [BC, 1]
+            b_dAqk_j = gather(b_dAqk, col_idx, axis=1)
+            b_dAab_j = gather(b_dAab, col_idx, axis=1)
+            b_dAqb_j = gather(b_dAqb, col_idx, axis=1)
+            b_dAak_j = gather(b_dAak, col_idx, axis=1)
+            # [1, BC] -> [BC, 1]
+            b_dA_qk_j = tl.sum(gather(b_dAqk, row_idx_bc, axis=0), 0)[:, None]
+            b_dA_qk_j = tl.sum(gather(b_dAqk, row_idx_bc, axis=0), 0)[:, None]
+            b_dA_ab_j = tl.sum(gather(b_dAab, row_idx_bc, axis=0), 0)[:, None]
+            b_dA_qb_j = tl.sum(gather(b_dAqb, row_idx_bc, axis=0), 0)[:, None]
+            b_dA_ak_j = tl.sum(gather(b_dAak, row_idx_bc, axis=0), 0)[:, None]
+        else:
+            mask_idx = tl.arange(0, BC) == j
+            b_kj = tl.sum(tl.where(mask_idx[:, None], b_k, 0), 0)[None, :]
+            b_bj = tl.sum(tl.where(mask_idx[:, None], b_b, 0), 0)[None, :]
+            b_gij = tl.sum(tl.where(mask_idx[:, None], b_gi, 0), 0)[None, :]
+            b_gej = tl.sum(tl.where(mask_idx[:, None], b_ge, 0), 0)[None, :]
+            b_dAqk_j = tl.sum(tl.where(mask_idx[None, :], b_dAqk, 0), 1)[:, None]
+            b_dAab_j = tl.sum(tl.where(mask_idx[None, :], b_dAab, 0), 1)[:, None]
+            b_dAqb_j = tl.sum(tl.where(mask_idx[None, :], b_dAqb, 0), 1)[:, None]
+            b_dAak_j = tl.sum(tl.where(mask_idx[None, :], b_dAak, 0), 1)[:, None]
+            b_dA_qk_j = tl.sum(tl.where(mask_idx[:, None], b_dAqk, 0), 0)[:, None]
+            b_dA_ab_j = tl.sum(tl.where(mask_idx[:, None], b_dAab, 0), 0)[:, None]
+            b_dA_qb_j = tl.sum(tl.where(mask_idx[:, None], b_dAqb, 0), 0)[:, None]
+            b_dA_ak_j = tl.sum(tl.where(mask_idx[:, None], b_dAak, 0), 0)[:, None]
+            # [1, BK] b_qj, b_aj
+            b_qj = tl.sum(tl.where(mask_idx[:, None], b_q, 0), 0)[None, :]
+            b_aj = tl.sum(tl.where(mask_idx[:, None], b_a, 0), 0)[None, :]
+
+        m_e = o_i[:, None] > j
+        m_i = o_i[:, None] >= j
+        tmp1 = exp(b_gi - b_gij)
+        tmp2 = exp(b_ge - b_gij)
+        b_dq += tl.where(m_i, b_dAqk_j * b_kj * tmp1, 0.0)
+        b_dq += tl.where(m_i, b_dAqb_j * b_bj * tmp1, 0.0)
+        b_da += tl.where(m_e, b_dAab_j * b_bj * tmp2, 0.0)
+        b_da += tl.where(m_e, b_dAak_j * b_kj * tmp2, 0.0)
+
+        m_i = o_i[:, None] <= j
+        m_e = o_i[:, None] < j
+        tmp1 = exp(b_gij - b_gi)
+        tmp2 = exp(b_gej - b_gi)
+        b_dk += tl.where(m_i, b_dA_qk_j * b_qj * tmp1, 0.0)
+        b_dk += tl.where(m_e, b_dA_ak_j * b_aj * tmp2, 0.0)
+        b_db += tl.where(m_i, b_dA_qb_j * b_qj * tmp1, 0.0)
+        b_db += tl.where(m_e, b_dA_ab_j * b_aj * tmp2, 0.0)
+
+    # post processing
+    p_dq = tl.make_block_ptr(
+        dq, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)
+    )
+    p_dk = tl.make_block_ptr(
+        dk, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)
+    )
+    p_da = tl.make_block_ptr(
+        da, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)
+    )
+    p_db = tl.make_block_ptr(
+        db, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)
+    )
+    p_dgk = tl.make_block_ptr(
+        dgk, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)
+    )
+    p_dgk_offset = tl.make_block_ptr(
+        dgk_offset, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)
+    )
+    p_dqg = tl.make_block_ptr(
+        dqg, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)
+    )
+    p_dkg = tl.make_block_ptr(
+        dkg, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)
+    )
+    p_dag = tl.make_block_ptr(
+        dag, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)
+    )
+    p_dbg = tl.make_block_ptr(
+        dbg, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)
+    )
+    p_gn = gi + (min(i_t * BT + BT, T) - 1) * stride_qk + o_k
+    p_gn = tl.max_contiguous(tl.multiple_of(p_gn, BK), BK)
+    b_gn = tl.load(p_gn, mask=m_k, other=0)
+    b_da += tl.load(p_dag, boundary_check=(0, 1)) * exp(b_ge)
+    b_dq += tl.load(p_dqg, boundary_check=(0, 1)) * exp(b_gi) * scale
+    tmp = exp(b_gn[None, :] - b_gi)
+    b_dk += tl.load(p_dkg, boundary_check=(0, 1)).to(tl.float32) * tmp
+    b_db += tl.load(p_dbg, boundary_check=(0, 1)).to(tl.float32) * tmp
+    tl.store(p_dq, (b_dq).to(p_dq.dtype.element_ty), boundary_check=(0, 1))
+    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
+    tl.store(p_da, b_da.to(p_da.dtype.element_ty), boundary_check=(0, 1))
+    tl.store(p_db, b_db.to(p_db.dtype.element_ty), boundary_check=(0, 1))
+    b_dgk = (b_dq * b_q + b_da * b_a - b_dk * b_k - b_db * b_b).to(tl.float32)
+    b_dgk_offset = b_da * b_a
+    tl.store(p_dgk, b_dgk.to(p_dgk.dtype.element_ty), boundary_check=(0, 1))
+    tl.store(
+        p_dgk_offset,
+        b_dgk_offset.to(p_dgk_offset.dtype.element_ty),
+        boundary_check=(0, 1),
+    )
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({"BK": BK}, num_warps=num_warps, num_stages=num_stages)
+        for num_warps in [2, 4, 8, 16, 32]
+        for num_stages in [2, 3, 4]
+        for BK in [32, 64]
+    ],
+    key=["BK", "BT", "K"],
+    use_cuda_graph=use_cuda_graph,
+)
+@triton.jit(do_not_specialize=["T"])
+def chunk_dplr_bwd_dgk_kernel(
+    dgk,
+    dgk_offset,
+    dgk_last,
+    T,
+    dgk_output,
+    H: tl.constexpr,
+    K: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+):
+    i_t, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+    i_b, i_h = i_bh // H, i_bh % H
+    if False:
+        i_tg = i_t
+        i_n, i_t = (
+            tl.load(chunk_indices + i_t * 2).to(tl.int32),
+            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
+        )
+        bos, eos = (
+            tl.load(cu_seqlens + i_n).to(tl.int32),
+            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
+        )
+        T = eos - bos
+        NT = tl.cdiv(T, BT)
+    else:
+        NT = tl.cdiv(T, BT)
+        i_tg = (i_b * NT + i_t).to(tl.int32)
+        bos, eos = (i_b * T).to(tl.int32), (i_b * T + T).to(tl.int32)
+
+    stride_qk = H * K
+    dgk += (bos * H + i_h) * K
+    dgk_offset += (bos * H + i_h) * K
+    dgk_last += (i_tg * H + i_h) * K
+    dgk_output += (bos * H + i_h) * K
+    p_dgk_last = dgk_last + tl.arange(0, BK) + i_k * BK
+    m_k = tl.arange(0, BK) + i_k * BK < K
+    b_dgk_last = tl.load(p_dgk_last, mask=m_k, other=0)
+    p_dgk_offset = tl.make_block_ptr(
+        dgk_offset, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
+    )
+    p_dgk = tl.make_block_ptr(
+        dgk, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
+    )
+    b_dgk = tl.load(p_dgk, boundary_check=(0, 1))
+    b_dgk_offset = tl.load(p_dgk_offset, boundary_check=(0, 1))
+    # m_inv_cumsum = (tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :]).to(tl.float32)
+    # b_dgk_cumsum = tl.dot(m_inv_cumsum, b_dgk, allow_tf32=False)
+    b_dgk_cumsum = tl.cumsum(b_dgk, 0, reverse=True)
+    b_dgk_cumsum += b_dgk_last[None, :]
+    b_dgk_cumsum -= b_dgk_offset
+    p_dgk_output = tl.make_block_ptr(
+        dgk_output, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
+    )
+    tl.store(
+        p_dgk_output,
+        b_dgk_cumsum.to(p_dgk_output.dtype.element_ty),
+        boundary_check=(0, 1),
+    )
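
In chunk_dplr_bwd_dgk_kernel above, the per-token decay gradient is assembled from three pieces: a reverse (suffix) cumulative sum of dgk within the chunk, the chunk-level term dgk_last, and a per-token correction dgk_offset. The PyTorch fragment below restates that post-processing step purely as a readability aid; the [B, T, H, K] layout, the [B, NT, H, K] shape of dgk_last, and T being divisible by BT are assumptions, not something this file specifies.

import torch

def dgk_postprocess_reference(dgk, dgk_offset, dgk_last, BT):
    # dgk, dgk_offset: [B, T, H, K]; dgk_last: [B, NT, H, K]; assumes T % BT == 0.
    out = torch.empty_like(dgk)
    T = dgk.shape[1]
    for start in range(0, T, BT):
        blk = dgk[:, start:start + BT]                    # one chunk, [B, BT, H, K]
        # inclusive suffix sum over the chunk's time axis, i.e. tl.cumsum(..., reverse=True)
        suffix = blk.flip(dims=[1]).cumsum(dim=1).flip(dims=[1])
        suffix = suffix + dgk_last[:, start // BT, None]  # add the chunk-level term
        out[:, start:start + BT] = suffix - dgk_offset[:, start:start + BT]
    return out
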
rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_fwd.py
@@ -0,0 +1,186 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+
+import triton
+import triton.language as tl
+
+from ..triton_kernel.utils import exp, gather, use_cuda_graph
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
+        for num_warps in [2, 4, 8, 16, 32]
+        for num_stages in [2, 3, 4]
+    ],
+    key=["BK", "BT"],
+    use_cuda_graph=use_cuda_graph,
+)
+@triton.jit(do_not_specialize=["T"])
+def chunk_dplr_fwd_A_kernel_intra_sub_intra(
+    q,
+    k,
+    a,
+    b,
+    gi,
+    ge,
+    T,
+    qg,
+    kg,
+    ag,
+    bg,
+    Aqk,
+    Aqb,
+    Aab,
+    Aak,
+    scale: tl.constexpr,
+    H: tl.constexpr,
+    K: tl.constexpr,
+    BT: tl.constexpr,
+    BC: tl.constexpr,
+    BK: tl.constexpr,
+    GATHER_SUPPORTED: tl.constexpr,
+):
+    i_t, i_b, i_h = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+
+    if False:
+        i_n, i_t = (
+            tl.load(chunk_indices + i_t * 2).to(tl.int32),
+            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
+        )
+        bos, eos = (
+            tl.load(cu_seqlens + i_n).to(tl.int32),
+            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
+        )
+        T = eos - bos
+    else:
+        bos, eos = i_b * T, i_b * T + T
+
+    if i_t * BT >= T:
+        return
+
+    o_i = tl.arange(0, BC)
+    o_k = tl.arange(0, BK)
+    m_k = o_k < K
+    m_A = (i_t * BT + tl.arange(0, BC)) < T
+    last_idx = min((i_t + 1) * BT, T) - 1
+    o_A = (bos + i_t * BT + tl.arange(0, BC)) * H * BT + i_h * BT
+    p_q = tl.make_block_ptr(
+        q + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, 0), (BC, BK), (1, 0)
+    )
+    p_k = tl.make_block_ptr(
+        k + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, 0), (BC, BK), (1, 0)
+    )
+    p_a = tl.make_block_ptr(
+        a + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, 0), (BC, BK), (1, 0)
+    )
+    p_b = tl.make_block_ptr(
+        b + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, 0), (BC, BK), (1, 0)
+    )
+    p_gi = tl.make_block_ptr(
+        gi + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, 0), (BC, BK), (1, 0)
+    )
+    p_ge = tl.make_block_ptr(
+        ge + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, 0), (BC, BK), (1, 0)
+    )
+    p_g_last = gi + (bos * H + i_h) * K + last_idx * H * K + tl.arange(0, BK)
+    b_g_last = tl.load(p_g_last, mask=m_k, other=0)
+    p_qg = tl.make_block_ptr(
+        qg + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, 0), (BC, BK), (1, 0)
+    )
+    p_kg = tl.make_block_ptr(
+        kg + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, 0), (BC, BK), (1, 0)
+    )
+    p_ag = tl.make_block_ptr(
+        ag + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, 0), (BC, BK), (1, 0)
+    )
+    p_bg = tl.make_block_ptr(
+        bg + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, 0), (BC, BK), (1, 0)
+    )
+
+    b_q = tl.load(p_q, boundary_check=(0, 1))
+    b_q = b_q * scale
+    b_k = tl.load(p_k, boundary_check=(0, 1))
+    b_a = tl.load(p_a, boundary_check=(0, 1))
+    b_b = tl.load(p_b, boundary_check=(0, 1))
+    b_gi = tl.load(p_gi, boundary_check=(0, 1)).to(tl.float32)
+    b_ge = tl.load(p_ge, boundary_check=(0, 1)).to(tl.float32)
+
+    # deal with decay term.
+    g_exp = exp(b_gi)
+    g_exp_inv = exp(-b_gi + b_g_last[None, :])
+    b_qg = b_q * g_exp
+    b_kg = b_k * g_exp_inv
+    b_bg = b_b * g_exp_inv
+    b_ag = b_a * exp(b_ge)
+    tl.store(
+        p_qg,
+        b_qg.to(p_qg.dtype.element_ty, fp_downcast_rounding="rtne"),
+        boundary_check=(0, 1),
+    )
+    tl.store(
+        p_bg,
+        b_bg.to(p_bg.dtype.element_ty, fp_downcast_rounding="rtne"),
+        boundary_check=(0, 1),
+    )
+    tl.store(
+        p_ag,
+        b_ag.to(p_ag.dtype.element_ty, fp_downcast_rounding="rtne"),
+        boundary_check=(0, 1),
+    )
+    tl.store(
+        p_kg,
+        b_kg.to(p_kg.dtype.element_ty, fp_downcast_rounding="rtne"),
+        boundary_check=(0, 1),
+    )
+    # tl.debug_barrier()
+
+    b_q = b_q.to(b_k.dtype)
+    # inner attn
+    for j in range(0, min(BC, T - i_t * BT)):
+        # a trick to index the j-th row of b_k, b_g, b_b
+        if GATHER_SUPPORTED:
+            row_idx = tl.full([1, BK], j, dtype=tl.int16)
+            # [1, BK]
+            b_k_j = gather(b_k, row_idx, axis=0)
+            b_gk_j = gather(b_gi, row_idx, axis=0)
+            b_b_j = gather(b_b, row_idx, axis=0)
+        else:
+            mask = tl.arange(0, BC) == j
+            b_k_j = tl.sum(tl.where(mask[:, None], b_k, 0), 0)[None, :]
+            b_gk_j = tl.sum(tl.where(mask[:, None], b_gi, 0), 0)[None, :]
+            b_b_j = tl.sum(tl.where(mask[:, None], b_b, 0), 0)[None, :]
+        tmp = exp(b_gi - b_gk_j)
+        b_A_qk = tl.sum(b_q * b_k_j * tmp, 1)
+        m_i = (o_i >= j).to(tl.float32)
+        b_A_qk = b_A_qk * m_i
+        b_A_qb = tl.sum(b_q * b_b_j * tmp, 1)
+        b_A_qb = b_A_qb * m_i
+        tmp2 = exp(b_ge - b_gk_j)
+        b_A_ak = tl.sum(b_a * b_k_j * tmp2, 1)
+        m_i2 = (o_i > j).to(tl.float32)
+        b_A_ak = b_A_ak * m_i2
+        b_A_ab = tl.sum(b_a * b_b_j * tmp2, 1)
+        b_A_ab = b_A_ab * m_i2
+
+        tl.store(
+            Aqk + o_A + j,
+            b_A_qk.to(dtype=Aqk.dtype.element_ty, fp_downcast_rounding="rtne"),
+            mask=m_A,
+        )
+        tl.store(
+            Aqb + o_A + j,
+            b_A_qb.to(dtype=Aqb.dtype.element_ty, fp_downcast_rounding="rtne"),
+            mask=m_A,
+        )
+        tl.store(
+            Aab + o_A + j,
+            b_A_ab.to(dtype=Aqb.dtype.element_ty, fp_downcast_rounding="rtne"),
+            mask=m_A,
+        )
+        tl.store(
+            Aak + o_A + j,
+            b_A_ak.to(dtype=Aqk.dtype.element_ty, fp_downcast_rounding="rtne"),
+            mask=m_A,
+        )
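
Within one chunk, chunk_dplr_fwd_A_kernel_intra_sub_intra above fills four [BT, BT] score matrices; for example Aqk[i, j] = sum_k q[i, k] * scale * k[j, k] * exp(gi[i, k] - gi[j, k]), kept only where i >= j, while the a-based scores use ge and a strict i > j mask (Aqb and Aab follow the same pattern with b in place of k). The fragment below recomputes Aqk and Aak for a single (batch, head, chunk) slice as a readability aid; the contiguous [BT, K] slice is an assumption and the code is not part of the package.

import torch

def intra_chunk_scores_reference(q, k, a, gi, ge, scale):
    # One (batch, head, chunk) slice: q, k, a, gi, ge are [BT, K]; outputs are [BT, BT].
    decay_qk = torch.exp(gi[:, None, :] - gi[None, :, :])   # exp(gi[i] - gi[j]) per feature
    Aqk = ((q * scale)[:, None, :] * k[None, :, :] * decay_qk).sum(-1)
    Aqk = torch.tril(Aqk)                                   # kernel masks with o_i >= j
    decay_ak = torch.exp(ge[:, None, :] - gi[None, :, :])
    Aak = (a[:, None, :] * k[None, :, :] * decay_ak).sum(-1)
    Aak = torch.tril(Aak, diagonal=-1)                      # kernel masks with o_i > j
    return Aqk, Aak
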
rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_bwd.py
@@ -0,0 +1,157 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+
+import triton
+import triton.language as tl
+
+from ..triton_kernel.utils import exp, use_cuda_graph
+
+
+@triton.heuristics(
+    {
+        "USE_FINAL_STATE_GRADIENT": lambda args: args["dht"] is not None,
+        "USE_INITIAL_STATE": lambda args: args["dh0"] is not None,
+    }
+)
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
+        for num_warps in [2, 4, 8, 16, 32]
+        for num_stages in [2, 3, 4]
+    ],
+    key=["BT", "BK", "BV", "V"],
+    use_cuda_graph=use_cuda_graph,
+)
+@triton.jit(do_not_specialize=["T"])
+def chunk_dplr_bwd_kernel_dhu(
+    qg,
+    bg,
+    w,
+    gk,
+    dht,
+    dv,
+    do,
+    T,
+    dh,
+    dh0,
+    dv2,
+    H: tl.constexpr,
+    K: tl.constexpr,
+    V: tl.constexpr,
+    BT: tl.constexpr,
+    BC: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+    USE_FINAL_STATE_GRADIENT: tl.constexpr,
+    USE_INITIAL_STATE: tl.constexpr,
+):
+    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+    i_n, i_h = i_nh // H, i_nh % H
+    if False:
+        bos, eos = (
+            tl.load(cu_seqlens + i_n).to(tl.int32),
+            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
+        )
+        T = eos - bos
+        NT = tl.cdiv(T, BT)
+        boh = tl.load(chunk_offsets + i_n).to(tl.int32)
+    else:
+        bos, eos = i_n * T, i_n * T + T
+        NT = tl.cdiv(T, BT)
+        boh = i_n * NT
+
+    # [BK, BV]
+    b_dh = tl.zeros([BK, BV], dtype=tl.float32)
+    if USE_FINAL_STATE_GRADIENT:
+        p_dht = tl.make_block_ptr(
+            dht + i_nh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)
+        )
+        b_dh += tl.load(p_dht, boundary_check=(0, 1))
+
+    mask_k = tl.arange(0, BK) < K
+    for i_t in range(NT - 1, -1, -1):
+        p_dh = tl.make_block_ptr(
+            dh + ((boh + i_t) * H + i_h) * K * V,
+            (K, V),
+            (V, 1),
+            (i_k * BK, i_v * BV),
+            (BK, BV),
+            (1, 0),
+        )
+        tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
+        b_dh_tmp = tl.zeros([BK, BV], dtype=tl.float32)
+        for i_c in range(tl.cdiv(BT, BC) - 1, -1, -1):
+            p_qg = tl.make_block_ptr(
+                qg + (bos * H + i_h) * K,
+                (K, T),
+                (1, H * K),
+                (i_k * BK, i_t * BT + i_c * BC),
+                (BK, BC),
+                (0, 1),
+            )
+            p_bg = tl.make_block_ptr(
+                bg + (bos * H + i_h) * K,
+                (T, K),
+                (H * K, 1),
+                (i_t * BT + i_c * BC, i_k * BK),
+                (BC, BK),
+                (1, 0),
+            )
+            p_w = tl.make_block_ptr(
+                w + (bos * H + i_h) * K,
+                (K, T),
+                (1, H * K),
+                (i_k * BK, i_t * BT + i_c * BC),
+                (BK, BC),
+                (0, 1),
+            )
+            p_dv = tl.make_block_ptr(
+                dv + (bos * H + i_h) * V,
+                (T, V),
+                (H * V, 1),
+                (i_t * BT + i_c * BC, i_v * BV),
+                (BC, BV),
+                (1, 0),
+            )
+            p_do = tl.make_block_ptr(
+                do + (bos * H + i_h) * V,
+                (T, V),
+                (H * V, 1),
+                (i_t * BT + i_c * BC, i_v * BV),
+                (BC, BV),
+                (1, 0),
+            )
+            p_dv2 = tl.make_block_ptr(
+                dv2 + (bos * H + i_h) * V,
+                (T, V),
+                (H * V, 1),
+                (i_t * BT + i_c * BC, i_v * BV),
+                (BC, BV),
+                (1, 0),
+            )
+            # [BK, BT]
+            b_qg = tl.load(p_qg, boundary_check=(0, 1))
+            # [BT, BK]
+            b_bg = tl.load(p_bg, boundary_check=(0, 1))
+            b_w = tl.load(p_w, boundary_check=(0, 1))
+            # [BT, V]
+            b_do = tl.load(p_do, boundary_check=(0, 1))
+            b_dv = tl.load(p_dv, boundary_check=(0, 1))
+            b_dv2 = b_dv + tl.dot(b_bg, b_dh.to(b_bg.dtype))
+            tl.store(p_dv2, b_dv2.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
+            # [BK, BV]
+            b_dh_tmp += tl.dot(b_qg, b_do.to(b_qg.dtype))
+            b_dh_tmp += tl.dot(b_w, b_dv2.to(b_qg.dtype))
+        last_idx = min((i_t + 1) * BT, T) - 1
+        bg_last = tl.load(
+            gk + ((bos + last_idx) * H + i_h) * K + tl.arange(0, BK), mask=mask_k
+        )
+        b_dh *= exp(bg_last)[:, None]
+        b_dh += b_dh_tmp
+
+    if USE_INITIAL_STATE:
+        p_dh0 = tl.make_block_ptr(
+            dh0 + i_nh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)
+        )
+        tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
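
chunk_dplr_bwd_kernel_dhu above sweeps the chunks in reverse: it stores the incoming state gradient dh for chunk i_t, forms dv2 = dv + bg @ dh, accumulates qg^T @ do + w^T @ dv2, and then carries the state backwards as dh = exp(gk_last) * dh + accumulator. The PyTorch restatement below covers one (batch, head) pair and folds the sub-chunk loop into a single block (it assumes BC == BT); the shapes are readability assumptions, not the package's launcher contract.

import torch

def bwd_dhu_reference(qg, bg, w, gk, do, dv, dht=None):
    # qg, bg, w, gk: [NT, BT, K]; do, dv: [NT, BT, V]; dht (grad of final state): [K, V] or None.
    NT, BT, K = qg.shape
    V = do.shape[-1]
    dh = torch.zeros(NT, K, V)
    dv2 = torch.zeros_like(dv)
    b_dh = dht.clone() if dht is not None else torch.zeros(K, V)
    for i_t in range(NT - 1, -1, -1):
        dh[i_t] = b_dh                                   # gradient w.r.t. the state entering chunk i_t
        dv2[i_t] = dv[i_t] + bg[i_t] @ b_dh              # matches b_dv2 = b_dv + dot(b_bg, b_dh)
        acc = qg[i_t].T @ do[i_t] + w[i_t].T @ dv2[i_t]  # matches the two dot accumulations
        gk_last = gk[i_t, -1]                            # decay at the chunk's last position, [K]
        b_dh = torch.exp(gk_last)[:, None] * b_dh + acc
    return dh, dv2, b_dh                                 # b_dh plays the role of dh0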