rwkv-ops 0.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. rwkv_ops/__init__.py +45 -0
  2. rwkv_ops/mhc_kernel/__init__.py +50 -0
  3. rwkv_ops/mhc_kernel/common_kernel/include/mhc_types.h +66 -0
  4. rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_post_op.cuh +197 -0
  5. rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_pre_op.cuh +212 -0
  6. rwkv_ops/mhc_kernel/common_kernel/kernels/rmsnorm.cuh +152 -0
  7. rwkv_ops/mhc_kernel/common_kernel/kernels/sinkhorn_knopp.cuh +158 -0
  8. rwkv_ops/mhc_kernel/common_kernel/kernels/stream_aggregate.cuh +141 -0
  9. rwkv_ops/mhc_kernel/common_kernel/kernels/stream_distribute.cuh +111 -0
  10. rwkv_ops/mhc_kernel/common_kernel/kernels/stream_mix.cuh +164 -0
  11. rwkv_ops/mhc_kernel/common_kernel/kernels/type_conversions.cuh +52 -0
  12. rwkv_ops/mhc_kernel/jax_kernel/CMakeLists.txt +47 -0
  13. rwkv_ops/mhc_kernel/jax_kernel/mhu_ffi.cu +652 -0
  14. rwkv_ops/mhc_kernel/jax_kernel/mhu_jax.py +939 -0
  15. rwkv_ops/mhc_kernel/native_keras_op.py +193 -0
  16. rwkv_ops/mhc_kernel/torch_kernel/mhc_cuda.cu +207 -0
  17. rwkv_ops/mhc_kernel/torch_kernel/mhc_op.cpp +296 -0
  18. rwkv_ops/mhc_kernel/torch_kernel/mhc_torch.py +306 -0
  19. rwkv_ops/rwkv6_kernel/__init__.py +120 -0
  20. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/gpu_ops.cpp +44 -0
  21. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernel_helpers.h +64 -0
  22. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernels.h +56 -0
  23. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/pybind11_kernel_helpers.h +41 -0
  24. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/rwkv_kernels.cu +512 -0
  25. rwkv_ops/rwkv6_kernel/jax_kernel_hip/gpu_ops.cpp +44 -0
  26. rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernel_helpers.h +64 -0
  27. rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernels.h +56 -0
  28. rwkv_ops/rwkv6_kernel/jax_kernel_hip/pybind11_kernel_helpers.h +41 -0
  29. rwkv_ops/rwkv6_kernel/jax_kernel_hip/rwkv_kernels.hip +514 -0
  30. rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py +722 -0
  31. rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py +90 -0
  32. rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_cuda.cu +397 -0
  33. rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_op.cpp +93 -0
  34. rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py +305 -0
  35. rwkv_ops/rwkv7_kernel/__init__.py +113 -0
  36. rwkv_ops/rwkv7_kernel/get_jax_devices_info.py +220 -0
  37. rwkv_ops/rwkv7_kernel/get_torch_devices_info.py +250 -0
  38. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/CMakeLists.txt +42 -0
  39. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_ffi.cu +399 -0
  40. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_jax.py +311 -0
  41. rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/CMakeLists.txt +42 -0
  42. rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_ffi.cu +172 -0
  43. rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_jax.py +190 -0
  44. rwkv_ops/rwkv7_kernel/jax_kernel/__init__.py +9 -0
  45. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_bwd.py +95 -0
  46. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_fwd.py +60 -0
  47. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_bwd.py +78 -0
  48. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_fwd.py +80 -0
  49. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_bwd.py +150 -0
  50. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_fwd.py +45 -0
  51. rwkv_ops/rwkv7_kernel/jax_kernel/cumsum.py +34 -0
  52. rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_bwd.py +61 -0
  53. rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_fwd.py +86 -0
  54. rwkv_ops/rwkv7_kernel/jax_op.py +382 -0
  55. rwkv_ops/rwkv7_kernel/mlx_op.py +118 -0
  56. rwkv_ops/rwkv7_kernel/native_keras_op.py +108 -0
  57. rwkv_ops/rwkv7_kernel/tf_eager_kernel.py +155 -0
  58. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_cuda.cu +235 -0
  59. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_op.cpp +63 -0
  60. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_torch.py +233 -0
  61. rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_cuda.cu +101 -0
  62. rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_op.cpp +56 -0
  63. rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_torch.py +112 -0
  64. rwkv_ops/rwkv7_kernel/torch_kernel/__init__.py +13 -0
  65. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_bwd.py +96 -0
  66. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_fwd.py +64 -0
  67. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_bwd.py +74 -0
  68. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_fwd.py +75 -0
  69. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_bwd.py +148 -0
  70. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_fwd.py +44 -0
  71. rwkv_ops/rwkv7_kernel/torch_kernel/cumsum.py +31 -0
  72. rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_bwd.py +63 -0
  73. rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_fwd.py +79 -0
  74. rwkv_ops/rwkv7_kernel/torch_op.py +504 -0
  75. rwkv_ops/rwkv7_kernel/triton_kernel/__init__.py +34 -0
  76. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_bwd.py +328 -0
  77. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_fwd.py +186 -0
  78. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_bwd.py +157 -0
  79. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_fwd.py +160 -0
  80. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_bwd.py +382 -0
  81. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_fwd.py +137 -0
  82. rwkv_ops/rwkv7_kernel/triton_kernel/cumsum.py +86 -0
  83. rwkv_ops/rwkv7_kernel/triton_kernel/utils.py +20 -0
  84. rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_bwd.py +193 -0
  85. rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_fwd.py +326 -0
  86. rwkv_ops-0.6.1.dist-info/METADATA +495 -0
  87. rwkv_ops-0.6.1.dist-info/RECORD +89 -0
  88. rwkv_ops-0.6.1.dist-info/WHEEL +4 -0
  89. rwkv_ops-0.6.1.dist-info/licenses/LICENSE.txt +201 -0
rwkv_ops/rwkv7_kernel/jax_kernel/cumsum.py
@@ -0,0 +1,34 @@
+ from ..triton_kernel.cumsum import *
+ import jax_triton as jt
+ import jax
+ import triton
+
+
+ def chunk_rwkv6_fwd_cumsum(
+     g: jax.Array,
+     chunk_size: int,
+ ) -> jax.Array:
+     B, T, H, S = g.shape
+     BT = chunk_size
+     NT = triton.cdiv(T, BT)
+
+     out_shapes = [
+         jax.ShapeDtypeStruct(g.shape, "float32"),
+         jax.ShapeDtypeStruct(g.shape, "float32"),
+     ]
+
+     def grid(meta):
+         return (triton.cdiv(meta["S"], meta["BS"]), NT, B * H)
+
+     gi, ge = jt.triton_call(
+         g,
+         T,
+         H=H,
+         S=S,
+         BT=BT,
+         grid=grid,
+         kernel=chunk_rwkv6_fwd_cumsum_kernel,
+         out_shape=out_shapes,
+     )
+
+     return gi, ge
rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_bwd.py
@@ -0,0 +1,61 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2023-2025,Qingwen Lin
+
+ from typing import Tuple
+
+ import jax_triton as jt
+ import jax
+ import triton
+
+
+ from ..get_torch_devices_info import check_shared_mem
+ from ..triton_kernel.wy_fast_bwd import *
+
+
+ def chunk_dplr_bwd_wy(
+     A_ab_inv: jax.Array,
+     A_ak: jax.Array,
+     v: jax.Array,
+     ag: jax.Array,
+     dw: jax.Array,
+     du: jax.Array,
+     dv0: jax.Array,
+     chunk_size: int = 16,
+ ) -> Tuple[jax.Array, jax.Array, jax.Array, jax.Array]:
+     B, T, H, K, V = *dw.shape, du.shape[-1]
+     BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
+
+     NT = triton.cdiv(T, BT)
+     BK = min(triton.next_power_of_2(K), 64)
+     BV = (
+         min(triton.next_power_of_2(V), 64)
+         if check_shared_mem()
+         else min(triton.next_power_of_2(V), 32)
+     )
+     grid = (NT, B * H)
+     out_shapes = [
+         jax.ShapeDtypeStruct(A_ak.shape, "float32"),
+         jax.ShapeDtypeStruct(A_ab_inv.shape, "float32"),
+         jax.ShapeDtypeStruct(v.shape, v.dtype),
+         jax.ShapeDtypeStruct(ag.shape, ag.dtype),
+     ]
+     dA_ak, dA_ab, dv, dag = jt.triton_call(
+         A_ab_inv,
+         A_ak,
+         ag,
+         v,
+         dw,
+         du,
+         dv0,
+         T,
+         H=H,
+         K=K,
+         V=V,
+         BT=BT,
+         BK=BK,
+         BV=BV,
+         grid=grid,
+         kernel=prepare_wy_repr_bwd_kernel,
+         out_shape=out_shapes,
+     )
+     return dA_ab, dA_ak, dv, dag
rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_fwd.py
@@ -0,0 +1,86 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2023-2025,Qingwen Lin
+
+ from typing import Tuple
+
+ import jax_triton as jt
+ import jax
+ import triton
+
+
+ from ..triton_kernel.wy_fast_fwd import *
+
+
+ def wu_fwd(
+     ag: jax.Array,
+     v: jax.Array,
+     A_ak: jax.Array,
+     A_ab_inv: jax.Array,
+     chunk_size: int,
+ ) -> Tuple[jax.Array, jax.Array]:
+     B, T, H, K, V = *ag.shape, v.shape[-1]
+     BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
+
+     NT = triton.cdiv(T, BT)
+     BK = min(triton.next_power_of_2(K), 64)
+     BV = min(triton.next_power_of_2(V), 64)
+
+     out_shapes = [
+         jax.ShapeDtypeStruct(v.shape, v.dtype),
+         jax.ShapeDtypeStruct(ag.shape, ag.dtype),
+     ]
+     grid = (NT, B * H)
+     w, u = jt.triton_call(
+         ag,
+         v,
+         A_ab_inv,
+         A_ak,
+         T,
+         H=H,
+         K=K,
+         V=V,
+         BT=BT,
+         BK=BK,
+         BV=BV,
+         grid=grid,
+         kernel=wu_fwd_kernel,
+         out_shape=out_shapes,
+     )
+     return w, u
+
+
+ def prepare_wy_repr_fwd(
+     ag: jax.Array,
+     v: jax.Array,
+     A_ak: jax.Array,
+     A_ab: jax.Array,
+     chunk_size: int = 64,
+ ) -> Tuple[jax.Array, jax.Array, jax.Array]:
+     B, T, H, _ = ag.shape
+     BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
+
+     NT = triton.cdiv(T, BT)
+     BC = min(BT, 32)
+     fwd_fn = (
+         prepare_wy_repr_fwd_kernel_chunk64
+         if BT == 64
+         else prepare_wy_repr_fwd_kernel_chunk32
+     )
+     grid = (NT, B * H)
+     A_ab_inv = jt.triton_call(
+         A_ab,
+         T,
+         H=H,
+         BT=BT,
+         BC=BC,
+         grid=grid,
+         kernel=fwd_fn,
+         out_shape=jax.ShapeDtypeStruct(A_ab.shape, A_ab.dtype),
+     )
+     w, u = wu_fwd(ag=ag, v=v, A_ak=A_ak, A_ab_inv=A_ab_inv, chunk_size=BT)
+     return w, u, A_ab_inv
+
+
+ fwd_prepare_wy_repr = prepare_wy_repr_fwd
+
+ fwd_wu = wu_fwd
rwkv_ops/rwkv7_kernel/jax_op.py
@@ -0,0 +1,382 @@
+ import jax
+ import jax.numpy as jnp
+ import triton
+ from .jax_kernel.chunk_A_bwd import chunk_dplr_bwd_dqk_intra
+ from .jax_kernel.chunk_A_fwd import chunk_dplr_fwd_intra
+ from .jax_kernel.chunk_h_bwd import chunk_dplr_bwd_dhu
+ from .jax_kernel.chunk_h_fwd import chunk_dplr_fwd_h
+ from .jax_kernel.chunk_o_bwd import (
+     chunk_dplr_bwd_dAu,
+     chunk_dplr_bwd_dv,
+     chunk_dplr_bwd_o,
+ )
+ from .jax_kernel.chunk_o_fwd import chunk_dplr_fwd_o
+ from .jax_kernel.wy_fast_bwd import chunk_dplr_bwd_wy
+ from .jax_kernel.wy_fast_fwd import prepare_wy_repr_fwd
+ from .jax_kernel.cumsum import chunk_rwkv6_fwd_cumsum
+ from jax.ad_checkpoint import checkpoint_policies as cp
+
+ CHUNKSIZE = 16
+
+
+ def chunk_dplr_fwd(
+     q: jax.Array,
+     k: jax.Array,
+     v: jax.Array,
+     a: jax.Array,
+     b: jax.Array,
+     gk: jax.Array,
+     scale: float,
+     initial_state: jax.Array,
+     output_final_state: bool,
+     chunk_size: int = 16,
+ ):
+     T = q.shape[1]
+     BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
+
+     gi, ge = chunk_rwkv6_fwd_cumsum(gk, BT)
+
+     A_ab, A_qk, A_ak, A_qb, qg, kg, ag, bg = chunk_dplr_fwd_intra(
+         q=q,
+         k=k,
+         a=a,
+         b=b,
+         gi=gi,
+         ge=ge,
+         scale=scale,
+         chunk_size=BT,
+     )
+
+     del ge
+
+     # A_ab, A_ak, gi, ge torch.float32
+     # A_qk, A_qb, qg, kg, ag, bg, dtype=q.dtype, eg: bf16
+     w, u, _ = prepare_wy_repr_fwd(ag=ag, A_ab=A_ab, A_ak=A_ak, v=v, chunk_size=BT)
+
+     del A_ab, A_ak
+     h, v_new, final_state = chunk_dplr_fwd_h(
+         kg=kg,
+         bg=bg,
+         v=v,
+         w=w,
+         u=u,
+         gk=gi,
+         initial_state=initial_state,
+         output_final_state=output_final_state,
+         chunk_size=BT,
+     )
+
+     del u, kg, bg, gi
+
+     o = chunk_dplr_fwd_o(
+         qg=qg, v=v, v_new=v_new, A_qk=A_qk, A_qb=A_qb, h=h, chunk_size=BT
+     )
+     del v_new, h, A_qk, A_qb
+
+     return o, final_state
+
+
+ def chunk_dplr_delta_rule_fwd(
+     q: jax.Array,
+     k: jax.Array,
+     v: jax.Array,
+     a: jax.Array,
+     b: jax.Array,
+     gk: jax.Array,
+     scale=None,
+     initial_state=None,
+     output_final_state: bool = True,
+ ):
+     assert q.dtype == k.dtype == v.dtype
+     # assert q.dtype != torch.float32, "ChunkDeltaRuleFunction does not support float32. Please use bfloat16."
+     # gk = gk.float()
+
+     scale = k.shape[-1] ** -0.5 if scale is None else scale
+     chunk_size = CHUNKSIZE
+
+     o, final_state = chunk_dplr_fwd(
+         q=q,
+         k=k,
+         v=v,
+         a=a,
+         b=b,
+         gk=gk,
+         scale=scale,
+         initial_state=initial_state,
+         output_final_state=output_final_state,
+         chunk_size=chunk_size,
+     )
+     return o, final_state
+
+
+ def cal_log_w(w: jax.Array) -> jax.Array:
+     return -jnp.exp(w)
+
+
+ @jax.custom_vjp
+ def chunk_dplr(
+     r: jax.Array,
+     k: jax.Array,
+     v: jax.Array,
+     a: jax.Array,
+     b: jax.Array,
+     gk: jax.Array,
+     initial_state: jax.Array = None,
+ ):
+     return chunk_dplr_delta_rule_fwd(
+         q=r,
+         k=k,
+         v=v,
+         a=a,
+         b=b,
+         gk=gk,
+         scale=1,
+         initial_state=initial_state,
+         output_final_state=True,
+     )
+
+
+ def chunk_dplr_fwd_jax(
+     r: jax.Array,
+     k: jax.Array,
+     v: jax.Array,
+     a: jax.Array,
+     b: jax.Array,
+     gk: jax.Array,
+     initial_state: jax.Array = None,
+ ):
+     o, state = chunk_dplr_delta_rule_fwd(
+         q=r,
+         k=k,
+         v=v,
+         a=a,
+         b=b,
+         gk=gk,
+         scale=1,
+         initial_state=initial_state,
+         output_final_state=True,
+     )
+     cache = (r, k, v, a, b, gk, initial_state)
+     return (o, state), cache
+
+
+ def chunk_dplr_bwd(
+     q: jax.Array,
+     k: jax.Array,
+     v: jax.Array,
+     a: jax.Array,
+     b: jax.Array,
+     gk: jax.Array,
+     initial_state,
+     scale,
+     do: jax.Array,
+     dht: jax.Array,
+     chunk_size: int = CHUNKSIZE,
+ ):
+     # DTYPE = do.dtype
+     BT = chunk_size
+     scale = scale
+     # if do != None:
+     #     do = do, q.dtype)
+     # if dht != None:
+     #     dht = dht, q.dtype)
+
+     # ******* start recomputing everything, otherwise I believe the gpu memory will be exhausted *******
+     gi, ge = chunk_rwkv6_fwd_cumsum(gk, BT)
+
+     A_ab, A_qk, A_ak, A_qb, qg, kg, ag, bg = chunk_dplr_fwd_intra(
+         q=q,
+         k=k,
+         a=a,
+         b=b,
+         gi=gi,
+         ge=ge,
+         scale=scale,
+         chunk_size=BT,
+     )
+     w, u, A_ab_inv = prepare_wy_repr_fwd(
+         ag=ag, A_ab=A_ab, A_ak=A_ak, v=v, chunk_size=BT
+     )
+     del A_ab
+     h, v_new, _ = chunk_dplr_fwd_h(
+         kg=kg, bg=bg, v=v, w=w, u=u, gk=gi, initial_state=initial_state, chunk_size=BT
+     )
+     del u
+     # ******* end of recomputation *******
+     # A_ak, A_ab_inv, gi, ge torch.float32
+     # A_qk, A_qb, qg, kg, ag, bg, v_new dtype=q.dtype, eg: bf16
+
+     dv_new_intra, dA_qk, dA_qb = chunk_dplr_bwd_dAu(
+         v=v, v_new=v_new, do=do, A_qb=A_qb, scale=scale, chunk_size=BT
+     )
+
+     dh, dh0, dv_new = chunk_dplr_bwd_dhu(
+         qg=qg,
+         bg=bg,
+         w=w,
+         gk=gi,
+         h0=initial_state,
+         dht=dht,
+         do=do,
+         dv=dv_new_intra,
+         chunk_size=BT,
+     )
+
+     dv = chunk_dplr_bwd_dv(A_qk=A_qk, kg=kg, do=do, dh=dh, chunk_size=BT)
+     del A_qk
+     dqg, dkg, dw, dbg, dgk_last = chunk_dplr_bwd_o(
+         k=kg,
+         b=bg,
+         v=v,
+         v_new=v_new,
+         do=do,
+         h=h,
+         dh=dh,
+         dv=dv_new,
+         w=w,
+         gk=gi,
+         chunk_size=BT,
+         scale=scale,
+     )
+     del v_new
+     dA_ab, dA_ak, dv, dag = chunk_dplr_bwd_wy(
+         A_ab_inv=A_ab_inv,
+         A_ak=A_ak,
+         v=v,
+         ag=ag,
+         dw=dw,
+         du=dv_new,
+         dv0=dv,
+         chunk_size=BT,
+     )
+     del A_ak
+
+     dq, dk, da, db, dgk = chunk_dplr_bwd_dqk_intra(
+         q=q,
+         k=k,
+         a=a,
+         b=b,
+         gi=gi,
+         ge=ge,
+         dAqk=dA_qk,
+         dAqb=dA_qb,
+         dAak=dA_ak,
+         dAab=dA_ab,
+         dgk_last=dgk_last,
+         dqg=dqg,
+         dkg=dkg,
+         dag=dag,
+         dbg=dbg,
+         chunk_size=BT,
+         scale=scale,
+     )
+     return (
+         jnp.asarray(dq, q.dtype),
+         jnp.asarray(dk, k.dtype),
+         jnp.asarray(dv, v.dtype),
+         jnp.asarray(da, a.dtype),
+         jnp.asarray(db, b.dtype),
+         jnp.asarray(dgk, gk.dtype),
+         None if initial_state is None else jnp.asarray(dh0, initial_state.dtype),
+     )
+
+
+ def chunk_dplr_bwd_jax(res, g):
+     q, k, v, a, b, gk, initial_state = res
+     do, dht = g
+     return chunk_dplr_bwd(
+         q,
+         k,
+         v,
+         a,
+         b,
+         gk,
+         initial_state,
+         scale=1,
+         do=do,
+         dht=dht,
+     )
+
+
+ chunk_dplr.defvjp(chunk_dplr_fwd_jax, chunk_dplr_bwd_jax)
+
+
+ def transpose_head(x, head_first):
+     # x = jnp.asarray(x,"bfloat16")
+     if head_first:
+         return jnp.transpose(x, (0, 2, 1, 3))
+     else:
+         return x
+
+
+ def generalized_delta_rule(
+     r: jax.Array,
+     w: jax.Array,
+     k: jax.Array,
+     v: jax.Array,
+     a: jax.Array,
+     b: jax.Array,
+     initial_state: jax.Array = None,
+     output_final_state: bool = True,
+     head_first: bool = False,
+ ):
+     r"""
+     Main interface function for chunked delta rule attention.
+
+     Args:
+         r (jax.Array):
+             queries (receptance) of shape `[B, T, H, K]`
+         w (jax.Array):
+             decay term of shape `[B, T, H, K]`; converted to log space internally (`log_w = -exp(w)`)
+         k (jax.Array):
+             keys of shape `[B, T, H, K]`
+         v (jax.Array):
+             values of shape `[B, T, H, V]`
+         a (jax.Array):
+             activations of shape `[B, T, H, K]`
+         b (jax.Array):
+             betas of shape `[B, T, H, K]`
+         initial_state (Optional[jax.Array]):
+             Initial state of shape `[N, H, K, V]` for `N` input sequences.
+             For equal-length input sequences, `N` equals the batch size `B`.
+             Default: `None`.
+         output_final_state (Optional[bool]):
+             Whether to output the final state of shape `[N, H, K, V]`. Default: `True`.
+         head_first (Optional[bool]):
+             Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
+             Default: `False`.
+
+     Returns:
+         o (jax.Array):
+             Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
+         final_state (jax.Array):
+             Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
+     """
+     DTYPE = r.dtype
+     r = transpose_head(r, head_first)
+     k = transpose_head(k, head_first)
+     v = transpose_head(v, head_first)
+     a = transpose_head(a, head_first)
+     b = transpose_head(b, head_first)
+
+     if w is not None:
+         log_w = cal_log_w(w)
+     else:
+         raise ValueError("w must be provided!")
+     log_w = transpose_head(log_w, head_first)
+     o, final_state = jax.checkpoint(
+         chunk_dplr, policy=cp.save_anything_except_these_names(())
+     )(
+         r=r,
+         k=k,
+         v=v,
+         a=a,
+         b=b,
+         gk=log_w,
+         initial_state=initial_state,
+     )
+     if output_final_state:
+         return jnp.asarray(o, DTYPE), final_state
+     return jnp.asarray(o, DTYPE)
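For orientation, here is a minimal, hypothetical usage sketch of the JAX `generalized_delta_rule` above; it is not part of the package. Shapes follow the docstring, the sizes (B=1, T=64, H=2, K=V=64) are arbitrary, and it assumes a CUDA device with `jax_triton` installed so the Triton kernels behind `chunk_dplr` can run.

import jax
import jax.numpy as jnp

# Hypothetical sizes: K is the key/decay width, V the value width.
B, T, H, K, V = 1, 64, 2, 64, 64
keys = jax.random.split(jax.random.PRNGKey(0), 6)
r = jax.random.normal(keys[0], (B, T, H, K), dtype=jnp.bfloat16)
w = jax.random.normal(keys[1], (B, T, H, K), dtype=jnp.bfloat16)  # raw decay; log_w = -exp(w) is taken internally
k = jax.random.normal(keys[2], (B, T, H, K), dtype=jnp.bfloat16)
v = jax.random.normal(keys[3], (B, T, H, V), dtype=jnp.bfloat16)
a = jax.random.normal(keys[4], (B, T, H, K), dtype=jnp.bfloat16)
b = jax.random.normal(keys[5], (B, T, H, K), dtype=jnp.bfloat16)
h0 = jnp.zeros((B, H, K, V), dtype=jnp.float32)  # one recurrent state per sequence

o, final_state = generalized_delta_rule(
    r, w, k, v, a, b,
    initial_state=h0,
    output_final_state=True,
)
print(o.shape, final_state.shape)  # (B, T, H, V) and (B, H, K, V)

Because `chunk_dplr` carries a custom VJP and is wrapped in `jax.checkpoint`, the call above is differentiable with the backward pass recomputing the forward intermediates chunk by chunk.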
rwkv_ops/rwkv7_kernel/mlx_op.py
@@ -0,0 +1,118 @@
+ # copy from https://github.com/ml-explore/mlx-lm/pull/580
+
+ import mlx.core as mx
+
+
+ def _make_wkv7_kernel():
+     if not mx.metal.is_available():
+         return None
+     source = """
+         auto n = thread_position_in_grid.z;
+         auto b_idx = n / H;
+         auto h_idx = n % H;
+         constexpr int n_per_t = D / 32;
+         // [B, T, H, D]
+         auto r_ = r + b_idx * T * H * D + h_idx * D;
+         auto w_ = w + b_idx * T * H * D + h_idx * D;
+         auto k_ = k + b_idx * T * H * D + h_idx * D;
+         auto v_ = v + b_idx * T * H * D + h_idx * D;
+         auto a_ = a + b_idx * T * H * D + h_idx * D;
+         auto b_ = b + b_idx * T * H * D + h_idx * D;
+         y += b_idx * T * H * D + h_idx * D;
+         auto dk_idx = thread_position_in_threadgroup.x;
+         auto dv_idx = thread_position_in_grid.y;
+         // state_in, state_out: [B, H, D, D]
+         auto i_state = state_in + (n * D + dv_idx) * D;
+         auto o_state = state_out + (n * D + dv_idx) * D;
+         float state[n_per_t];
+         for (int i = 0; i < n_per_t; ++i) {
+             auto s_idx = n_per_t * dk_idx + i;
+             state[i] = static_cast<float>(i_state[s_idx]);
+         }
+         for (int t = 0; t < T; ++t) {
+             float sa = 0.0f;
+             for (int i = 0; i < n_per_t; ++i) {
+                 auto s_idx = n_per_t * dk_idx + i;
+                 sa += state[i] * a_[s_idx];
+                 state[i] = state[i] * w_[s_idx];
+             }
+             sa = simd_sum(sa);
+             float out = 0.0f;
+             for (int i = 0; i < n_per_t; ++i) {
+                 auto s_idx = n_per_t * dk_idx + i;
+                 state[i] = state[i] + k_[s_idx] * v_[dv_idx] + sa * b_[s_idx];
+                 out += state[i] * r_[s_idx];
+             }
+             out = simd_sum(out);
+             if (thread_index_in_simdgroup == 0) {
+                 y[dv_idx] = static_cast<InT>(out);
+             }
+             // Increment data pointers to next time step
+             r_ += H * D;
+             w_ += H * D;
+             k_ += H * D;
+             v_ += H * D;
+             a_ += H * D;
+             b_ += H * D;
+             y += H * D;
+         }
+         for (int i = 0; i < n_per_t; ++i) {
+             auto s_idx = n_per_t * dk_idx + i;
+             o_state[s_idx] = static_cast<InT>(state[i]);
+         }
+     """
+     inputs = ["r", "w", "k", "v", "a", "b", "state_in", "T"]
+     return mx.fast.metal_kernel(
+         name="wkv7_kernel",
+         input_names=inputs,
+         output_names=["y", "state_out"],
+         source=source,
+     )
+
+
+ _wkv7_kernel = _make_wkv7_kernel()
+
+
+ def transpose_head(x, head_first: bool = True):
+     if head_first:
+         return mx.transpose(x, (0, 2, 1, 3))
+     return x
+
+
+ def generalized_delta_rule(
+     r,
+     w,
+     k,
+     v,
+     a,
+     b,
+     initial_state=None,
+     output_final_state: bool = True,
+     head_first: bool = False,
+ ):
+     state = initial_state
+
+     r = transpose_head(r, head_first)
+     k = transpose_head(k, head_first)
+     v = transpose_head(v, head_first)
+     a = transpose_head(a, head_first)
+     b = transpose_head(b, head_first)
+
+     B, T, H, D = r.shape
+     input_dtype = r.dtype
+
+     y, out_state = _wkv7_kernel(
+         inputs=[r, w, k, v, a, b, state, T],
+         template=[
+             ("InT", input_dtype),
+             ("H", H),
+             ("D", D),
+         ],
+         grid=(32, D, B * H),
+         threadgroup=(32, 4, 1),
+         output_shapes=[(B, T, H, D), state.shape],
+         output_dtypes=[input_dtype, input_dtype],
+     )
+     if output_final_state:
+         return y, out_state
+     return y
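Similarly, a shape-level, hypothetical sketch of the MLX path above (not from the package). It assumes an Apple-silicon device with Metal available (otherwise `_wkv7_kernel` is `None`), that D is a multiple of 32 (the kernel uses `n_per_t = D / 32`), and that an explicit `initial_state` is passed, since the kernel launch reads `state.shape`. The Metal kernel multiplies the state by `w` directly, so `w` is taken here to be a per-channel decay factor mapped into (0, 1).

import mlx.core as mx

# Hypothetical sizes; D must be divisible by 32.
B, T, H, D = 1, 32, 2, 64
r, k, v, a, b = (mx.random.normal((B, T, H, D)) for _ in range(5))
w = mx.exp(-mx.exp(mx.random.normal((B, T, H, D))))  # decay factor in (0, 1)
state0 = mx.zeros((B, H, D, D))  # explicit initial state, one per (batch, head)

y, final_state = generalized_delta_rule(r, w, k, v, a, b, initial_state=state0)
mx.eval(y, final_state)
print(y.shape, final_state.shape)  # [B, T, H, D] and [B, H, D, D]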