rwkv_ops-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rwkv-ops might be problematic.

Files changed (43)
  1. rwkv_ops/__init__.py +26 -0
  2. rwkv_ops/rwkv7_kernel/__init__.py +153 -0
  3. rwkv_ops/rwkv7_kernel/get_jax_devices_info.py +221 -0
  4. rwkv_ops/rwkv7_kernel/get_torch_devices_info.py +250 -0
  5. rwkv_ops/rwkv7_kernel/jax_kernel/__init__.py +9 -0
  6. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_bwd.py +95 -0
  7. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_fwd.py +60 -0
  8. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_bwd.py +78 -0
  9. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_fwd.py +80 -0
  10. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_bwd.py +150 -0
  11. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_fwd.py +45 -0
  12. rwkv_ops/rwkv7_kernel/jax_kernel/cumsum.py +34 -0
  13. rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_bwd.py +61 -0
  14. rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_fwd.py +86 -0
  15. rwkv_ops/rwkv7_kernel/jax_op.py +382 -0
  16. rwkv_ops/rwkv7_kernel/native_keras_op.py +95 -0
  17. rwkv_ops/rwkv7_kernel/torch_kernel/__init__.py +13 -0
  18. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_bwd.py +96 -0
  19. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_fwd.py +64 -0
  20. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_bwd.py +74 -0
  21. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_fwd.py +75 -0
  22. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_bwd.py +148 -0
  23. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_fwd.py +44 -0
  24. rwkv_ops/rwkv7_kernel/torch_kernel/cumsum.py +31 -0
  25. rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_bwd.py +63 -0
  26. rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_fwd.py +79 -0
  27. rwkv_ops/rwkv7_kernel/torch_op.py +523 -0
  28. rwkv_ops/rwkv7_kernel/triton_kernel/__init__.py +34 -0
  29. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_bwd.py +328 -0
  30. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_fwd.py +186 -0
  31. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_bwd.py +157 -0
  32. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_fwd.py +160 -0
  33. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_bwd.py +382 -0
  34. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_fwd.py +137 -0
  35. rwkv_ops/rwkv7_kernel/triton_kernel/cumsum.py +86 -0
  36. rwkv_ops/rwkv7_kernel/triton_kernel/utils.py +20 -0
  37. rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_bwd.py +193 -0
  38. rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_fwd.py +326 -0
  39. rwkv_ops-0.1.0.dist-info/LICENSE.txt +201 -0
  40. rwkv_ops-0.1.0.dist-info/METADATA +118 -0
  41. rwkv_ops-0.1.0.dist-info/RECORD +43 -0
  42. rwkv_ops-0.1.0.dist-info/WHEEL +5 -0
  43. rwkv_ops-0.1.0.dist-info/top_level.txt +1 -0
rwkv_ops/rwkv7_kernel/jax_op.py
@@ -0,0 +1,382 @@
+ import jax
+ import jax.numpy as jnp
+ import triton
+ from .jax_kernel.chunk_A_bwd import chunk_dplr_bwd_dqk_intra
+ from .jax_kernel.chunk_A_fwd import chunk_dplr_fwd_intra
+ from .jax_kernel.chunk_h_bwd import chunk_dplr_bwd_dhu
+ from .jax_kernel.chunk_h_fwd import chunk_dplr_fwd_h
+ from .jax_kernel.chunk_o_bwd import (
+     chunk_dplr_bwd_dAu,
+     chunk_dplr_bwd_dv,
+     chunk_dplr_bwd_o,
+ )
+ from .jax_kernel.chunk_o_fwd import chunk_dplr_fwd_o
+ from .jax_kernel.wy_fast_bwd import chunk_dplr_bwd_wy
+ from .jax_kernel.wy_fast_fwd import prepare_wy_repr_fwd
+ from .jax_kernel.cumsum import chunk_rwkv6_fwd_cumsum
+
+ CHUNKSIZE = 16
+
+
+ def chunk_dplr_fwd(
+     q: jax.Array,
+     k: jax.Array,
+     v: jax.Array,
+     a: jax.Array,
+     b: jax.Array,
+     gk: jax.Array,
+     scale: float,
+     initial_state: jax.Array,
+     output_final_state: bool,
+     chunk_size: int = 16,
+ ):
+     T = q.shape[1]
+     BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
+
+     gi, ge = chunk_rwkv6_fwd_cumsum(gk, BT)
+
+     A_ab, A_qk, A_ak, A_qb, qg, kg, ag, bg = chunk_dplr_fwd_intra(
+         q=q,
+         k=k,
+         a=a,
+         b=b,
+         gi=gi,
+         ge=ge,
+         scale=scale,
+         chunk_size=BT,
+     )
+
+     del ge
+
+     # A_ab, A_ak, gi, ge torch.float32
+     # A_qk, A_qb, qg, kg, ag, bg, dtype=q.dtype, eg: bf16
+     w, u, _ = prepare_wy_repr_fwd(ag=ag, A_ab=A_ab, A_ak=A_ak, v=v, chunk_size=BT)
+
+     del A_ab, A_ak
+     h, v_new, final_state = chunk_dplr_fwd_h(
+         kg=kg,
+         bg=bg,
+         v=v,
+         w=w,
+         u=u,
+         gk=gi,
+         initial_state=initial_state,
+         output_final_state=output_final_state,
+         chunk_size=BT,
+     )
+
+     del u, kg, bg, gi
+
+     o = chunk_dplr_fwd_o(
+         qg=qg, v=v, v_new=v_new, A_qk=A_qk, A_qb=A_qb, h=h, chunk_size=BT
+     )
+     del v_new, h, A_qk, A_qb
+
+     return o, final_state
+
+
+ def chunk_dplr_delta_rule_fwd(
+     q: jax.Array,
+     k: jax.Array,
+     v: jax.Array,
+     a: jax.Array,
+     b: jax.Array,
+     gk: jax.Array,
+     scale=None,
+     initial_state=None,
+     output_final_state: bool = True,
+ ):
+     assert q.dtype == k.dtype == v.dtype
+     # assert q.dtype != torch.float32, "ChunkDeltaRuleFunction does not support float32. Please use bfloat16."
+     # gk = gk.float()
+
+     scale = k.shape[-1] ** -0.5 if scale is None else scale
+     chunk_size = CHUNKSIZE
+
+     o, final_state = chunk_dplr_fwd(
+         q=q,
+         k=k,
+         v=v,
+         a=a,
+         b=b,
+         gk=gk,
+         scale=scale,
+         initial_state=initial_state,
+         output_final_state=output_final_state,
+         chunk_size=chunk_size,
+     )
+     return o, final_state
+
+
+ def cal_log_w(w: jax.Array) -> jax.Array:
+     return -jnp.exp(w)
+
+
+ @jax.custom_vjp
+ def chunk_dplr(
+     r: jax.Array,
+     k: jax.Array,
+     v: jax.Array,
+     a: jax.Array,
+     b: jax.Array,
+     gk: jax.Array,
+     initial_state: jax.Array = None,
+ ):
+     return chunk_dplr_delta_rule_fwd(
+         q=r,
+         k=k,
+         v=v,
+         a=a,
+         b=b,
+         gk=gk,
+         scale=1,
+         initial_state=initial_state,
+         output_final_state=True,
+     )
+
+
+ def chunk_dplr_fwd_jax(
+     r: jax.Array,
+     k: jax.Array,
+     v: jax.Array,
+     a: jax.Array,
+     b: jax.Array,
+     gk: jax.Array,
+     initial_state: jax.Array = None,
+ ):
+     o, state = chunk_dplr_delta_rule_fwd(
+         q=r,
+         k=k,
+         v=v,
+         a=a,
+         b=b,
+         gk=gk,
+         scale=1,
+         initial_state=initial_state,
+         output_final_state=True,
+     )
+     cache = (r, k, v, a, b, gk, initial_state)
+     return [o, state], cache
+
+
+ def chunk_dplr_bwd(
+     q: jax.Array,
+     k: jax.Array,
+     v: jax.Array,
+     a: jax.Array,
+     b: jax.Array,
+     gk: jax.Array,
+     initial_state,
+     scale,
+     do: jax.Array,
+     dht: jax.Array,
+     chunk_size: int = CHUNKSIZE,
+ ):
+     # DTYPE = do.dtype
+     BT = chunk_size
+     scale = scale
+     # if do != None:
+     #     do = do, q.dtype)
+     # if dht != None:
+     #     dht = dht, q.dtype)
+
+     # ******* start recomputing everything, otherwise i believe the gpu memory will be exhausted *******
+     gi, ge = chunk_rwkv6_fwd_cumsum(gk, BT)
+
+     A_ab, A_qk, A_ak, A_qb, qg, kg, ag, bg = chunk_dplr_fwd_intra(
+         q=q,
+         k=k,
+         a=a,
+         b=b,
+         gi=gi,
+         ge=ge,
+         scale=scale,
+         chunk_size=BT,
+     )
+     w, u, A_ab_inv = prepare_wy_repr_fwd(
+         ag=ag, A_ab=A_ab, A_ak=A_ak, v=v, chunk_size=BT
+     )
+     del A_ab
+     h, v_new, _ = chunk_dplr_fwd_h(
+         kg=kg, bg=bg, v=v, w=w, u=u, gk=gi, initial_state=initial_state, chunk_size=BT
+     )
+     del u
+     # ******* end of recomputation *******
+     # A_ak, A_ab_inv, gi, ge torch.float32
+     # A_qk, A_qb, qg, kg, ag, bg, v_new dtype=q.dtype, eg: bf16
+
+     dv_new_intra, dA_qk, dA_qb = chunk_dplr_bwd_dAu(
+         v=v, v_new=v_new, do=do, A_qb=A_qb, scale=scale, chunk_size=BT
+     )
+
+     dh, dh0, dv_new = chunk_dplr_bwd_dhu(
+         qg=qg,
+         bg=bg,
+         w=w,
+         gk=gi,
+         h0=initial_state,
+         dht=dht,
+         do=do,
+         dv=dv_new_intra,
+         chunk_size=BT,
+     )
+
+     dv = chunk_dplr_bwd_dv(A_qk=A_qk, kg=kg, do=do, dh=dh, chunk_size=BT)
+     del A_qk
+
+     dqg, dkg, dw, dbg, dgk_last = chunk_dplr_bwd_o(
+         k=kg,
+         b=bg,
+         v=v,
+         v_new=v_new,
+         do=do,
+         h=h,
+         dh=dh,
+         dv=dv_new,
+         w=w,
+         gk=gi,
+         chunk_size=BT,
+         scale=scale,
+     )
+     del v_new
+
+     dA_ab, dA_ak, dv, dag = chunk_dplr_bwd_wy(
+         A_ab_inv=A_ab_inv,
+         A_ak=A_ak,
+         v=v,
+         ag=ag,
+         dw=dw,
+         du=dv_new,
+         dv0=dv,
+         chunk_size=BT,
+     )
+     del A_ak
+
+     dq, dk, da, db, dgk = chunk_dplr_bwd_dqk_intra(
+         q=q,
+         k=k,
+         a=a,
+         b=b,
+         gi=gi,
+         ge=ge,
+         dAqk=dA_qk,
+         dAqb=dA_qb,
+         dAak=dA_ak,
+         dAab=dA_ab,
+         dgk_last=dgk_last,
+         dqg=dqg,
+         dkg=dkg,
+         dag=dag,
+         dbg=dbg,
+         chunk_size=BT,
+         scale=scale,
+     )
+     return (
+         jnp.asarray(dq, q.dtype),
+         jnp.asarray(dk, k.dtype),
+         jnp.asarray(dv, v.dtype),
+         jnp.asarray(da, a.dtype),
+         jnp.asarray(db, b.dtype),
+         jnp.asarray(dgk, gk.dtype),
+         None if initial_state is None else jnp.asarray(dh0, initial_state.dtype),
+     )
+
+
+ def chunk_dplr_bwd_jax(res, g):
+     q, k, v, a, b, gk, initial_state = res
+     do, dht = g
+     return chunk_dplr_bwd(
+         q,
+         k,
+         v,
+         a,
+         b,
+         gk,
+         initial_state,
+         scale=1,
+         do=do,
+         dht=dht,
+     )
+
+
+ chunk_dplr.defvjp(chunk_dplr_fwd_jax, chunk_dplr_bwd_jax)
+
+
+ def transpose_head(x, head_first):
+     # x = jnp.asarray(x,"bfloat16")
+     if head_first:
+         return jnp.transpose(x, (0, 2, 1, 3))
+     else:
+         return x
+
+
+ # @partial(jax.jit, static_argnames=['initial_state',"output_final_state","head_first","use_chunk"])
+ def generalized_delta_rule(
+     r: jax.Array,
+     w: jax.Array,
+     k: jax.Array,
+     v: jax.Array,
+     a: jax.Array,
+     b: jax.Array,
+     initial_state: jax.Array = None,
+     output_final_state: bool = True,
+     head_first: bool = False,
+ ):
+     r"""
+     Main interface function for chunked delta rule attention.
+
+     Args:
+         q (jax.Array):
+             queries of shape `[B, T, H, K]`
+         k (jax.Array):
+             keys of shape `[B, T, H, K]`
+         v (jax.Array):
+             values of shape `[B, T, H, V]`
+         a (jax.Array):
+             activations of shape `[B, T, H, K]`
+         b (jax.Array):
+             betas of shape `[B, T, H, K]`
+         gk (jax.Array):
+             gk of shape `[B, T, H, K]`, decay term in log space!
+         initial_state (Optional[jax.Array]):
+             Initial state of shape `[N, H, K, V]` for `N` input sequences.
+             For equal-length input sequences, `N` equals the batch size `B`.
+             Default: `None`.
+         output_final_state (Optional[bool]):
+             Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
+         head_first (Optional[bool]):
+             Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
+             Default: `False`.
+
+     Returns:
+         o (jax.Array):
+             Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
+         final_state (jax.Array):
+             Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
+     """
+     DTYPE = r.dtype
+     r = transpose_head(r, head_first)
+     k = transpose_head(k, head_first)
+     v = transpose_head(v, head_first)
+     a = transpose_head(a, head_first)
+     b = transpose_head(b, head_first)
+
+     if w is not None:
+         log_w = cal_log_w(w)
+     else:
+         assert log_w is not None, "Either w or log_w must be provided!"
+     log_w = transpose_head(log_w, head_first)
+     o, final_state = chunk_dplr(
+         r=r,
+         k=k,
+         v=v,
+         a=a,
+         b=b,
+         gk=log_w,
+         initial_state=initial_state,
+     )
+     if output_final_state:
+         return jnp.asarray(o, DTYPE), final_state
+     return jnp.asarray(o, DTYPE)
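The file above is the JAX entry point: `generalized_delta_rule` maps the decay input through `log_w = -exp(w)`, dispatches into the chunked Triton-backed kernels via `chunk_dplr`, and registers a custom VJP for the backward pass. A minimal usage sketch follows; it assumes a GPU with Triton available, that the module is importable under the path shown in the file list, and uses arbitrary toy shapes and random inputs:

import jax
import jax.numpy as jnp
# Module path taken from the file list above; direct importability is an assumption.
from rwkv_ops.rwkv7_kernel.jax_op import generalized_delta_rule

B, T, H, K, V = 1, 32, 2, 64, 64                                   # toy sizes
keys = jax.random.split(jax.random.PRNGKey(0), 6)
r = jax.random.normal(keys[0], (B, T, H, K), dtype=jnp.bfloat16)
w = jax.random.normal(keys[1], (B, T, H, K), dtype=jnp.bfloat16)   # decay input; mapped to log_w = -exp(w) internally
k = jax.random.normal(keys[2], (B, T, H, K), dtype=jnp.bfloat16)
v = jax.random.normal(keys[3], (B, T, H, V), dtype=jnp.bfloat16)
a = jax.random.normal(keys[4], (B, T, H, K), dtype=jnp.bfloat16)
b = jax.random.normal(keys[5], (B, T, H, K), dtype=jnp.bfloat16)

o, final_state = generalized_delta_rule(
    r, w, k, v, a, b,
    initial_state=None,
    output_final_state=True,
    head_first=False,          # inputs are laid out as [B, T, H, ...]
)
print(o.shape)                 # (1, 32, 2, 64), i.e. [B, T, H, V]
print(final_state.shape)       # (1, 2, 64, 64), i.e. [B, H, K, V] per the docstring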
rwkv_ops/rwkv7_kernel/native_keras_op.py
@@ -0,0 +1,95 @@
+ import keras
+ from keras import ops
+
+
+ def transpose_head(x, head_first):
+     """
+     Transpose the input tensor.
+
+     Args:
+         x: input tensor.
+         head_first: boolean that decides whether to transpose.
+
+     Returns:
+         The transposed tensor if head_first is True, otherwise the original tensor.
+     """
+     x = ops.cast(x, "float32")
+     if head_first:
+         return ops.transpose(x, (0, 2, 1, 3))
+     else:
+         return x
+
+
+ def generalized_delta_rule(
+     r,
+     w,
+     k,
+     v,
+     a,
+     b,
+     initial_state=None,
+     output_final_state: bool = True,
+     head_first: bool = False,
+ ):
+     """
+     Implements the generalized delta rule.
+
+     Args:
+         r: input tensor.
+         w: weight tensor.
+         k, v, a, b: remaining input tensors.
+         initial_state: initial state tensor.
+         output_final_state: whether to return the final state.
+         head_first: whether the head dimension comes first in the computation.
+
+     Returns:
+         The output, plus the final state depending on output_final_state.
+     """
+     DTYPE = r.dtype
+     B, T, H, N = ops.shape(r)
+     r = transpose_head(r, head_first)
+
+     k = transpose_head(k, head_first)
+
+     v = transpose_head(v, head_first)
+     a = transpose_head(a, head_first)
+     b = transpose_head(b, head_first)
+     w = transpose_head(w, head_first)
+     w = ops.exp(-ops.exp(w))
+
+     if initial_state is not None:
+         state = initial_state
+         if ops.shape(state)[0] == 1:
+             state = ops.broadcast_to(state, (B, H, N, N))
+     else:
+         state = ops.zeros((B, H, N, N), dtype="float32")
+     out = ops.zeros((B, T, H, N), dtype=r.dtype)
+
+     def step(t, inputs):
+         """
+         Run the computation for a single time step.
+
+         Args:
+             t: current time step.
+             inputs: list holding the current state and output.
+
+         Returns:
+             The updated state and output.
+         """
+         state, out = inputs
+         kk = ops.reshape(k[:, t, :], (B, H, 1, N))
+         rr = ops.reshape(r[:, t, :], (B, H, N, 1))
+         vv = ops.reshape(v[:, t, :], (B, H, N, 1))
+         aa = ops.reshape(a[:, t, :], (B, H, N, 1))
+         bb = ops.reshape(b[:, t, :], (B, H, 1, N))
+         state = state * w[:, t, :, None, :] + state @ aa @ bb + vv @ kk
+         out = ops.slice_update(
+             out, [0, t, 0, 0], ops.reshape((state @ rr), (B, 1, H, N))
+         )
+         return [state, out]
+
+     state, out = ops.fori_loop(0, T, step, [state, out])
+
+     if output_final_state:
+         return ops.cast(out, DTYPE), state
+     return ops.cast(out, DTYPE)
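This Keras file is the pure-`keras.ops` fallback: inside `ops.fori_loop` it applies the per-timestep recurrence state = state * w_t + state @ a_t b_t^T + v_t k_t^T followed by o_t = state @ r_t, so it runs on any backend without Triton. A small sanity-check sketch, again assuming the module path from the file list is directly importable and using arbitrary random inputs:

import keras
from keras import ops
# Module path taken from the file list above; direct importability is an assumption.
from rwkv_ops.rwkv7_kernel.native_keras_op import generalized_delta_rule

B, T, H, N = 1, 8, 2, 4        # tiny sizes; note this fallback assumes K == V == N
r, w, k, v, a, b = [keras.random.normal((B, T, H, N), seed=i) for i in range(6)]

out, state = generalized_delta_rule(
    r, w, k, v, a, b,
    initial_state=None,        # defaults to zeros of shape (B, H, N, N)
    output_final_state=True,
    head_first=False,
)
print(ops.shape(out))          # (1, 8, 2, 4)
print(ops.shape(state))        # (1, 2, 4, 4)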
rwkv_ops/rwkv7_kernel/torch_kernel/__init__.py
@@ -0,0 +1,13 @@
+ from ..torch_kernel.chunk_A_fwd import *
+ from ..torch_kernel.chunk_A_bwd import *
+
+ # ---------- chunk_h ----------
+ from ..torch_kernel.chunk_h_fwd import *
+ from ..torch_kernel.chunk_h_bwd import *
+
+ # ---------- chunk_o ----------
+ from ..torch_kernel.chunk_o_fwd import *
+ from ..torch_kernel.chunk_o_bwd import *
+ from ..torch_kernel.cumsum import *
+ from ..torch_kernel.wy_fast_fwd import *
+ from ..torch_kernel.wy_fast_bwd import *
rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_bwd.py
@@ -0,0 +1,96 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+
+ import torch
+ import triton
+ from ..triton_kernel.chunk_A_bwd import *
+ from ..triton_kernel.utils import is_gather_supported
+ from ..get_torch_devices_info import check_shared_mem
+
+
+ def chunk_dplr_bwd_dqk_intra(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     a: torch.Tensor,
+     b: torch.Tensor,
+     gi: torch.Tensor,
+     ge: torch.Tensor,
+     dAqk: torch.Tensor,
+     dAqb: torch.Tensor,
+     dAak: torch.Tensor,
+     dAab: torch.Tensor,
+     dqg: torch.Tensor,
+     dkg: torch.Tensor,
+     dag: torch.Tensor,
+     dbg: torch.Tensor,
+     dgk_last: torch.Tensor,
+     scale: float = 1.0,
+     chunk_size: int = 16,
+ ):
+     B, T, H, K = q.shape
+     BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
+     BK = (
+         min(64, triton.next_power_of_2(K))
+         if check_shared_mem()
+         else min(32, triton.next_power_of_2(K))
+     )
+
+     NT = triton.cdiv(T, BT)
+     NK = triton.cdiv(K, BK)
+     grid = (NK, NT, B * H)
+
+     dq = torch.empty_like(q)
+     dk = torch.empty_like(k)
+     da = torch.empty_like(a)
+     db = torch.empty_like(b)
+     dgk = torch.empty_like(gi, dtype=torch.float)
+     dgk_offset = torch.empty_like(gi, dtype=torch.float)
+
+     chunk_dplr_bwd_kernel_intra[grid](
+         q=q,
+         k=k,
+         a=a,
+         b=b,
+         gi=gi,
+         ge=ge,
+         dAqk=dAqk,
+         dAqb=dAqb,
+         dAak=dAak,
+         dAab=dAab,
+         dq=dq,
+         dk=dk,
+         dgk=dgk,
+         dgk_offset=dgk_offset,
+         dqg=dqg,
+         dkg=dkg,
+         dag=dag,
+         dbg=dbg,
+         da=da,
+         db=db,
+         scale=scale,
+         T=T,
+         H=H,
+         K=K,
+         BT=BT,
+         BC=BT,
+         BK=BK,
+         GATHER_SUPPORTED=is_gather_supported,
+     )
+
+     dgk_output = torch.empty_like(dgk)
+
+     def grid(meta):
+         return (NT, triton.cdiv(K, meta["BK"]), B * H)
+
+     chunk_dplr_bwd_dgk_kernel[grid](
+         dgk=dgk,
+         dgk_offset=dgk_offset,
+         dgk_last=dgk_last,
+         dgk_output=dgk_output,
+         T=T,
+         H=H,
+         K=K,
+         BT=BT,
+     )
+     return dq, dk, da, db, dgk_output
rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_fwd.py
@@ -0,0 +1,64 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+
+ import torch
+ import triton
+
+ from ..triton_kernel.utils import is_gather_supported
+
+ from ..triton_kernel.chunk_A_fwd import *
+
+
+ def chunk_dplr_fwd_intra(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     a: torch.Tensor,
+     b: torch.Tensor,
+     gi: torch.Tensor,
+     ge: torch.Tensor,
+     scale: float,
+     chunk_size: int,
+ ):
+     B, T, H, K = k.shape
+     BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
+
+     NT = triton.cdiv(T, BT)
+
+     Aqk = q.new_empty(B, T, H, BT, dtype=q.dtype)
+     Aqb = q.new_empty(B, T, H, BT, dtype=q.dtype)
+     # involving matrix inverse and it'd be better to use float here.
+     Aab = q.new_empty(B, T, H, BT, dtype=torch.float)
+     Aak = q.new_empty(B, T, H, BT, dtype=torch.float)
+
+     grid = (NT, B, H)
+     BK = triton.next_power_of_2(K)
+     qg = torch.empty_like(q)
+     kg = torch.empty_like(k, dtype=q.dtype)
+     ag = torch.empty_like(a, dtype=q.dtype)
+     bg = torch.empty_like(b, dtype=q.dtype)
+     chunk_dplr_fwd_A_kernel_intra_sub_intra[grid](
+         q=q,
+         k=k,
+         a=a,
+         b=b,
+         gi=gi,
+         ge=ge,
+         Aqk=Aqk,
+         Aqb=Aqb,
+         Aab=Aab,
+         Aak=Aak,
+         qg=qg,
+         kg=kg,
+         ag=ag,
+         bg=bg,
+         scale=scale,
+         T=T,
+         H=H,
+         K=K,
+         BT=BT,
+         BC=BT,
+         BK=BK,
+         GATHER_SUPPORTED=is_gather_supported,
+     )
+     return Aab, Aqk, Aak, Aqb, qg, kg, ag, bg
rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_bwd.py
@@ -0,0 +1,74 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+ from typing import Optional, Tuple
+
+ import torch
+ import triton
+
+ from ..get_torch_devices_info import check_shared_mem
+ from ..triton_kernel.chunk_h_bwd import *
+
+
+ def chunk_dplr_bwd_dhu(
+     qg: torch.Tensor,
+     bg: torch.Tensor,
+     w: torch.Tensor,
+     gk: torch.Tensor,
+     h0: torch.Tensor,
+     dht: Optional[torch.Tensor],
+     do: torch.Tensor,
+     dv: torch.Tensor,
+     chunk_size: int = 64,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+     B, T, H, K, V = *qg.shape, do.shape[-1]
+     BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
+     BK = triton.next_power_of_2(K)
+     assert BK <= 256, (
+         "current kernel does not support head dimension being larger than 256."
+     )
+     # H100
+     if check_shared_mem("hopper", qg.device.index):
+         BV = 64
+         BC = 64 if K <= 128 else 32
+     elif check_shared_mem("ampere", qg.device.index):  # A100
+         BV = 32
+         BC = 32
+     else:  # Etc: 4090
+         BV = 16
+         BC = 16
+
+     N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
+
+     BC = min(BT, BC)
+     NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
+     assert NK == 1, (
+         "NK > 1 is not supported because it involves time-consuming synchronization"
+     )
+
+     dh = qg.new_empty(B, NT, H, K, V)
+     dh0 = torch.empty_like(h0, dtype=torch.float32) if h0 is not None else None
+     dv2 = torch.zeros_like(dv)
+
+     grid = (NK, NV, N * H)
+     chunk_dplr_bwd_kernel_dhu[grid](
+         qg=qg,
+         bg=bg,
+         w=w,
+         gk=gk,
+         dht=dht,
+         dh0=dh0,
+         do=do,
+         dh=dh,
+         dv=dv,
+         dv2=dv2,
+         T=T,
+         H=H,
+         K=K,
+         V=V,
+         BT=BT,
+         BC=BC,
+         BK=BK,
+         BV=BV,
+     )
+     return dh, dh0, dv2