rwkv-ops 0.6.1 (py3-none-any.whl)
This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
- rwkv_ops/__init__.py +45 -0
- rwkv_ops/mhc_kernel/__init__.py +50 -0
- rwkv_ops/mhc_kernel/common_kernel/include/mhc_types.h +66 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_post_op.cuh +197 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_pre_op.cuh +212 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/rmsnorm.cuh +152 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/sinkhorn_knopp.cuh +158 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/stream_aggregate.cuh +141 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/stream_distribute.cuh +111 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/stream_mix.cuh +164 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/type_conversions.cuh +52 -0
- rwkv_ops/mhc_kernel/jax_kernel/CMakeLists.txt +47 -0
- rwkv_ops/mhc_kernel/jax_kernel/mhu_ffi.cu +652 -0
- rwkv_ops/mhc_kernel/jax_kernel/mhu_jax.py +939 -0
- rwkv_ops/mhc_kernel/native_keras_op.py +193 -0
- rwkv_ops/mhc_kernel/torch_kernel/mhc_cuda.cu +207 -0
- rwkv_ops/mhc_kernel/torch_kernel/mhc_op.cpp +296 -0
- rwkv_ops/mhc_kernel/torch_kernel/mhc_torch.py +306 -0
- rwkv_ops/rwkv6_kernel/__init__.py +120 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/gpu_ops.cpp +44 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernel_helpers.h +64 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernels.h +56 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/pybind11_kernel_helpers.h +41 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/rwkv_kernels.cu +512 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/gpu_ops.cpp +44 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernel_helpers.h +64 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernels.h +56 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/pybind11_kernel_helpers.h +41 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/rwkv_kernels.hip +514 -0
- rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py +722 -0
- rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py +90 -0
- rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_cuda.cu +397 -0
- rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_op.cpp +93 -0
- rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py +305 -0
- rwkv_ops/rwkv7_kernel/__init__.py +113 -0
- rwkv_ops/rwkv7_kernel/get_jax_devices_info.py +220 -0
- rwkv_ops/rwkv7_kernel/get_torch_devices_info.py +250 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel/CMakeLists.txt +42 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_ffi.cu +399 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_jax.py +311 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/CMakeLists.txt +42 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_ffi.cu +172 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_jax.py +190 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/__init__.py +9 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_bwd.py +95 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_fwd.py +60 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_bwd.py +78 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_fwd.py +80 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_bwd.py +150 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_fwd.py +45 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/cumsum.py +34 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_bwd.py +61 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_fwd.py +86 -0
- rwkv_ops/rwkv7_kernel/jax_op.py +382 -0
- rwkv_ops/rwkv7_kernel/mlx_op.py +118 -0
- rwkv_ops/rwkv7_kernel/native_keras_op.py +108 -0
- rwkv_ops/rwkv7_kernel/tf_eager_kernel.py +155 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_cuda.cu +235 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_op.cpp +63 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_torch.py +233 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_cuda.cu +101 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_op.cpp +56 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_torch.py +112 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/__init__.py +13 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_bwd.py +96 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_fwd.py +64 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_bwd.py +74 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_fwd.py +75 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_bwd.py +148 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_fwd.py +44 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/cumsum.py +31 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_bwd.py +63 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_fwd.py +79 -0
- rwkv_ops/rwkv7_kernel/torch_op.py +504 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/__init__.py +34 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_bwd.py +328 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_fwd.py +186 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_bwd.py +157 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_fwd.py +160 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_bwd.py +382 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_fwd.py +137 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/cumsum.py +86 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/utils.py +20 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_bwd.py +193 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_fwd.py +326 -0
- rwkv_ops-0.6.1.dist-info/METADATA +495 -0
- rwkv_ops-0.6.1.dist-info/RECORD +89 -0
- rwkv_ops-0.6.1.dist-info/WHEEL +4 -0
- rwkv_ops-0.6.1.dist-info/licenses/LICENSE.txt +201 -0

rwkv_ops/rwkv7_kernel/jax_kernel/cumsum.py
@@ -0,0 +1,34 @@
from ..triton_kernel.cumsum import *
import jax_triton as jt
import jax
import triton


def chunk_rwkv6_fwd_cumsum(
    g: jax.Array,
    chunk_size: int,
) -> jax.Array:
    B, T, H, S = g.shape
    BT = chunk_size
    NT = triton.cdiv(T, BT)

    out_shapes = [
        jax.ShapeDtypeStruct(g.shape, "float32"),
        jax.ShapeDtypeStruct(g.shape, "float32"),
    ]

    def grid(meta):
        return (triton.cdiv(meta["S"], meta["BS"]), NT, B * H)

    gi, ge = jt.triton_call(
        g,
        T,
        H=H,
        S=S,
        BT=BT,
        grid=grid,
        kernel=chunk_rwkv6_fwd_cumsum_kernel,
        out_shape=out_shapes,
    )

    return gi, ge

rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_bwd.py
@@ -0,0 +1,61 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025,Qingwen Lin

from typing import Tuple

import jax_triton as jt
import jax
import triton


from ..get_torch_devices_info import check_shared_mem
from ..triton_kernel.wy_fast_bwd import *


def chunk_dplr_bwd_wy(
    A_ab_inv: jax.Array,
    A_ak: jax.Array,
    v: jax.Array,
    ag: jax.Array,
    dw: jax.Array,
    du: jax.Array,
    dv0: jax.Array,
    chunk_size: int = 16,
) -> Tuple[jax.Array, jax.Array, jax.Array]:
    B, T, H, K, V = *dw.shape, du.shape[-1]
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))

    NT = triton.cdiv(T, BT)
    BK = min(triton.next_power_of_2(K), 64)
    BV = (
        min(triton.next_power_of_2(V), 64)
        if check_shared_mem()
        else min(triton.next_power_of_2(V), 32)
    )
    grid = (NT, B * H)
    out_shapes = [
        jax.ShapeDtypeStruct(A_ak.shape, "float32"),
        jax.ShapeDtypeStruct(A_ab_inv.shape, "float32"),
        jax.ShapeDtypeStruct(v.shape, v.dtype),
        jax.ShapeDtypeStruct(ag.shape, ag.dtype),
    ]
    dA_ak, dA_ab, dv, dag = jt.triton_call(
        A_ab_inv,
        A_ak,
        ag,
        v,
        dw,
        du,
        dv0,
        T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
        grid=grid,
        kernel=prepare_wy_repr_bwd_kernel,
        out_shape=out_shapes,
    )
    return dA_ab, dA_ak, dv, dag

rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_fwd.py
@@ -0,0 +1,86 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025,Qingwen Lin

from typing import Tuple

import jax_triton as jt
import jax
import triton


from ..triton_kernel.wy_fast_fwd import *


def wu_fwd(
    ag: jax.Array,
    v: jax.Array,
    A_ak: jax.Array,
    A_ab_inv: jax.Array,
    chunk_size: int,
) -> Tuple[jax.Array, jax.Array]:
    B, T, H, K, V = *ag.shape, v.shape[-1]
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))

    NT = triton.cdiv(T, BT)
    BK = min(triton.next_power_of_2(K), 64)
    BV = min(triton.next_power_of_2(V), 64)

    out_shapes = [
        jax.ShapeDtypeStruct(v.shape, v.dtype),
        jax.ShapeDtypeStruct(ag.shape, ag.dtype),
    ]
    grid = (NT, B * H)
    w, u = jt.triton_call(
        ag,
        v,
        A_ab_inv,
        A_ak,
        T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
        grid=grid,
        kernel=wu_fwd_kernel,
        out_shape=out_shapes,
    )
    return w, u


def prepare_wy_repr_fwd(
    ag: jax.Array,
    v: jax.Array,
    A_ak: jax.Array,
    A_ab: jax.Array,
    chunk_size: int = 64,
) -> Tuple[jax.Array, jax.Array, jax.Array]:
    B, T, H, _ = ag.shape
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))

    NT = triton.cdiv(T, BT)
    BC = min(BT, 32)
    fwd_fn = (
        prepare_wy_repr_fwd_kernel_chunk64
        if BT == 64
        else prepare_wy_repr_fwd_kernel_chunk32
    )
    grid = (NT, B * H)
    A_ab_inv = jt.triton_call(
        A_ab,
        T,
        H=H,
        BT=BT,
        BC=BC,
        grid=grid,
        kernel=fwd_fn,
        out_shape=jax.ShapeDtypeStruct(A_ab.shape, A_ab.dtype),
    )
    w, u = wu_fwd(ag=ag, v=v, A_ak=A_ak, A_ab_inv=A_ab_inv, chunk_size=BT)
    return w, u, A_ab_inv


fwd_prepare_wy_repr = prepare_wy_repr_fwd

fwd_wu = wu_fwd

rwkv_ops/rwkv7_kernel/jax_op.py
@@ -0,0 +1,382 @@
import jax
import jax.numpy as jnp
import triton
from .jax_kernel.chunk_A_bwd import chunk_dplr_bwd_dqk_intra
from .jax_kernel.chunk_A_fwd import chunk_dplr_fwd_intra
from .jax_kernel.chunk_h_bwd import chunk_dplr_bwd_dhu
from .jax_kernel.chunk_h_fwd import chunk_dplr_fwd_h
from .jax_kernel.chunk_o_bwd import (
    chunk_dplr_bwd_dAu,
    chunk_dplr_bwd_dv,
    chunk_dplr_bwd_o,
)
from .jax_kernel.chunk_o_fwd import chunk_dplr_fwd_o
from .jax_kernel.wy_fast_bwd import chunk_dplr_bwd_wy
from .jax_kernel.wy_fast_fwd import prepare_wy_repr_fwd
from .jax_kernel.cumsum import chunk_rwkv6_fwd_cumsum
from jax.ad_checkpoint import checkpoint_policies as cp

CHUNKSIZE = 16


def chunk_dplr_fwd(
    q: jax.Array,
    k: jax.Array,
    v: jax.Array,
    a: jax.Array,
    b: jax.Array,
    gk: jax.Array,
    scale: float,
    initial_state: jax.Array,
    output_final_state: bool,
    chunk_size: int = 16,
):
    T = q.shape[1]
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))

    gi, ge = chunk_rwkv6_fwd_cumsum(gk, BT)

    A_ab, A_qk, A_ak, A_qb, qg, kg, ag, bg = chunk_dplr_fwd_intra(
        q=q,
        k=k,
        a=a,
        b=b,
        gi=gi,
        ge=ge,
        scale=scale,
        chunk_size=BT,
    )

    del ge

    # A_ab, A_ak, gi, ge torch.float32
    # A_qk, A_qb, qg, kg, ag, bg, dtype=q.dtype, eg: bf16
    w, u, _ = prepare_wy_repr_fwd(ag=ag, A_ab=A_ab, A_ak=A_ak, v=v, chunk_size=BT)

    del A_ab, A_ak
    h, v_new, final_state = chunk_dplr_fwd_h(
        kg=kg,
        bg=bg,
        v=v,
        w=w,
        u=u,
        gk=gi,
        initial_state=initial_state,
        output_final_state=output_final_state,
        chunk_size=BT,
    )

    del u, kg, bg, gi

    o = chunk_dplr_fwd_o(
        qg=qg, v=v, v_new=v_new, A_qk=A_qk, A_qb=A_qb, h=h, chunk_size=BT
    )
    del v_new, h, A_qk, A_qb

    return o, final_state


def chunk_dplr_delta_rule_fwd(
    q: jax.Array,
    k: jax.Array,
    v: jax.Array,
    a: jax.Array,
    b: jax.Array,
    gk: jax.Array,
    scale=None,
    initial_state=None,
    output_final_state: bool = True,
):
    assert q.dtype == k.dtype == v.dtype
    # assert q.dtype != torch.float32, "ChunkDeltaRuleFunction does not support float32. Please use bfloat16."
    # gk = gk.float()

    scale = k.shape[-1] ** -0.5 if scale is None else scale
    chunk_size = CHUNKSIZE

    o, final_state = chunk_dplr_fwd(
        q=q,
        k=k,
        v=v,
        a=a,
        b=b,
        gk=gk,
        scale=scale,
        initial_state=initial_state,
        output_final_state=output_final_state,
        chunk_size=chunk_size,
    )
    return o, final_state


def cal_log_w(w: jax.Array) -> jax.Array:
    return -jnp.exp(w)


@jax.custom_vjp
def chunk_dplr(
    r: jax.Array,
    k: jax.Array,
    v: jax.Array,
    a: jax.Array,
    b: jax.Array,
    gk: jax.Array,
    initial_state: jax.Array = None,
):
    return chunk_dplr_delta_rule_fwd(
        q=r,
        k=k,
        v=v,
        a=a,
        b=b,
        gk=gk,
        scale=1,
        initial_state=initial_state,
        output_final_state=True,
    )


def chunk_dplr_fwd_jax(
    r: jax.Array,
    k: jax.Array,
    v: jax.Array,
    a: jax.Array,
    b: jax.Array,
    gk: jax.Array,
    initial_state: jax.Array = None,
):
    o, state = chunk_dplr_delta_rule_fwd(
        q=r,
        k=k,
        v=v,
        a=a,
        b=b,
        gk=gk,
        scale=1,
        initial_state=initial_state,
        output_final_state=True,
    )
    cache = (r, k, v, a, b, gk, initial_state)
    return (o, state), cache


def chunk_dplr_bwd(
    q: jax.Array,
    k: jax.Array,
    v: jax.Array,
    a: jax.Array,
    b: jax.Array,
    gk: jax.Array,
    initial_state,
    scale,
    do: jax.Array,
    dht: jax.Array,
    chunk_size: int = CHUNKSIZE,
):
    # DTYPE = do.dtype
    BT = chunk_size
    scale = scale
    # if do != None:
    # do = do, q.dtype)
    # if dht != None:
    # dht = dht, q.dtype)

    # ******* start recomputing everything, otherwise i believe the gpu memory will be exhausted *******
    gi, ge = chunk_rwkv6_fwd_cumsum(gk, BT)

    A_ab, A_qk, A_ak, A_qb, qg, kg, ag, bg = chunk_dplr_fwd_intra(
        q=q,
        k=k,
        a=a,
        b=b,
        gi=gi,
        ge=ge,
        scale=scale,
        chunk_size=BT,
    )
    w, u, A_ab_inv = prepare_wy_repr_fwd(
        ag=ag, A_ab=A_ab, A_ak=A_ak, v=v, chunk_size=BT
    )
    del A_ab
    h, v_new, _ = chunk_dplr_fwd_h(
        kg=kg, bg=bg, v=v, w=w, u=u, gk=gi, initial_state=initial_state, chunk_size=BT
    )
    del u
    # ******* end of recomputation *******
    # A_ak, A_ab_inv, gi, ge torch.float32
    # A_qk, A_qb, qg, kg, ag, bg, v_new dtype=q.dtype, eg: bf16

    dv_new_intra, dA_qk, dA_qb = chunk_dplr_bwd_dAu(
        v=v, v_new=v_new, do=do, A_qb=A_qb, scale=scale, chunk_size=BT
    )

    dh, dh0, dv_new = chunk_dplr_bwd_dhu(
        qg=qg,
        bg=bg,
        w=w,
        gk=gi,
        h0=initial_state,
        dht=dht,
        do=do,
        dv=dv_new_intra,
        chunk_size=BT,
    )

    dv = chunk_dplr_bwd_dv(A_qk=A_qk, kg=kg, do=do, dh=dh, chunk_size=BT)
    del A_qk
    dqg, dkg, dw, dbg, dgk_last = chunk_dplr_bwd_o(
        k=kg,
        b=bg,
        v=v,
        v_new=v_new,
        do=do,
        h=h,
        dh=dh,
        dv=dv_new,
        w=w,
        gk=gi,
        chunk_size=BT,
        scale=scale,
    )
    del v_new
    dA_ab, dA_ak, dv, dag = chunk_dplr_bwd_wy(
        A_ab_inv=A_ab_inv,
        A_ak=A_ak,
        v=v,
        ag=ag,
        dw=dw,
        du=dv_new,
        dv0=dv,
        chunk_size=BT,
    )
    del A_ak

    dq, dk, da, db, dgk = chunk_dplr_bwd_dqk_intra(
        q=q,
        k=k,
        a=a,
        b=b,
        gi=gi,
        ge=ge,
        dAqk=dA_qk,
        dAqb=dA_qb,
        dAak=dA_ak,
        dAab=dA_ab,
        dgk_last=dgk_last,
        dqg=dqg,
        dkg=dkg,
        dag=dag,
        dbg=dbg,
        chunk_size=BT,
        scale=scale,
    )
    return (
        jnp.asarray(dq, q.dtype),
        jnp.asarray(dk, k.dtype),
        jnp.asarray(dv, v.dtype),
        jnp.asarray(da, a.dtype),
        jnp.asarray(db, b.dtype),
        jnp.asarray(dgk, gk.dtype),
        None if initial_state is None else jnp.asarray(dh0, initial_state.dtype),
    )


def chunk_dplr_bwd_jax(res, g):
    q, k, v, a, b, gk, initial_state = res
    do, dht = g
    return chunk_dplr_bwd(
        q,
        k,
        v,
        a,
        b,
        gk,
        initial_state,
        scale=1,
        do=do,
        dht=dht,
    )


chunk_dplr.defvjp(chunk_dplr_fwd_jax, chunk_dplr_bwd_jax)


def transpose_head(x, head_first):
    # x = jnp.asarray(x,"bfloat16")
    if head_first:
        return jnp.transpose(x, (0, 2, 1, 3))
    else:
        return x


def generalized_delta_rule(
    r: jax.Array,
    w: jax.Array,
    k: jax.Array,
    v: jax.Array,
    a: jax.Array,
    b: jax.Array,
    initial_state: jax.Array = None,
    output_final_state: bool = True,
    head_first: bool = False,
):
    r"""
    Main interface function for chunked delta rule attention.

    分块 Delta Rule 注意力机制的主要接口函数。

    Args:
        q (jax.Array):
            queries of shape `[B, T, H, K]`
        k (jax.Array):
            keys of shape `[B, T, H, K]`
        v (jax.Array):
            values of shape `[B, T, H, V]`
        a (jax.Array):
            activations of shape `[B, T, H, K]`
        b (jax.Array):
            betas of shape `[B, T, H, K]`
        gk (jax.Array):
            gk of shape `[B, T, H, K]` decay term in log space!
        initial_state (Optional[jax.Array]):
            Initial state of shape `[N, H, K, V]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
        head_first (Optional[bool]):
            Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
            Default: `False`.

    Returns:
        o (jax.Array):
            Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
        final_state (jax.Array):
            Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
    """
    DTYPE = r.dtype
    r = transpose_head(r, head_first)
    k = transpose_head(k, head_first)
    v = transpose_head(v, head_first)
    a = transpose_head(a, head_first)
    b = transpose_head(b, head_first)

    if w is not None:
        log_w = cal_log_w(w)
    else:
        assert log_w is not None, "Either w or log_w must be provided!"
    log_w = transpose_head(log_w, head_first)
    o, final_state = jax.checkpoint(
        chunk_dplr, policy=cp.save_anything_except_these_names(())
    )(
        r=r,
        k=k,
        v=v,
        a=a,
        b=b,
        gk=log_w,
        initial_state=initial_state,
    )
    if output_final_state:
        return jnp.asarray(o, DTYPE), final_state
    return jnp.asarray(o, DTYPE)
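
A minimal usage sketch of the JAX entry point above, based only on the signature and docstring shown in this hunk. It assumes a CUDA GPU with triton and jax_triton installed (this path dispatches to the Triton kernels), and it imports straight from the module path in this diff since the re-exports in rwkv_ops/__init__.py are not shown here; shapes, dtypes, and random inputs are illustrative only:

import jax
import jax.numpy as jnp

# Import path taken from this diff; rwkv_ops/__init__.py may expose it differently.
from rwkv_ops.rwkv7_kernel.jax_op import generalized_delta_rule

B, T, H, K, V = 1, 64, 2, 64, 64            # illustrative sizes only
keys = jax.random.split(jax.random.PRNGKey(0), 5)

r = jax.random.normal(keys[0], (B, T, H, K), dtype=jnp.bfloat16)
k = jax.random.normal(keys[1], (B, T, H, K), dtype=jnp.bfloat16)
v = jax.random.normal(keys[2], (B, T, H, V), dtype=jnp.bfloat16)
# a ("activations") and b ("betas") follow the docstring above; random values
# only serve a shape/dtype smoke test, a real model supplies structured tensors.
a = jax.random.normal(keys[3], (B, T, H, K), dtype=jnp.bfloat16)
b = jax.random.normal(keys[4], (B, T, H, K), dtype=jnp.bfloat16)
# w is the raw decay input; the wrapper maps it to log space via cal_log_w, i.e. -exp(w).
w = -jnp.ones((B, T, H, K), dtype=jnp.bfloat16)

o, final_state = generalized_delta_rule(r, w, k, v, a, b, output_final_state=True)
print(o.shape, final_state.shape)           # (B, T, H, V), (B, H, K, V) per the docstring
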

rwkv_ops/rwkv7_kernel/mlx_op.py
@@ -0,0 +1,118 @@
# copy from https://github.com/ml-explore/mlx-lm/pull/580

import mlx.core as mx


def _make_wkv7_kernel():
    if not mx.metal.is_available():
        return None
    source = """
    auto n = thread_position_in_grid.z;
    auto b_idx = n / H;
    auto h_idx = n % H;
    constexpr int n_per_t = D / 32;
    // [B, T, H, D]
    auto r_ = r + b_idx * T * H * D + h_idx * D;
    auto w_ = w + b_idx * T * H * D + h_idx * D;
    auto k_ = k + b_idx * T * H * D + h_idx * D;
    auto v_ = v + b_idx * T * H * D + h_idx * D;
    auto a_ = a + b_idx * T * H * D + h_idx * D;
    auto b_ = b + b_idx * T * H * D + h_idx * D;
    y += b_idx * T * H * D + h_idx * D;
    auto dk_idx = thread_position_in_threadgroup.x;
    auto dv_idx = thread_position_in_grid.y;
    // state_in, state_out: [B, H, D, D]
    auto i_state = state_in + (n * D + dv_idx) * D;
    auto o_state = state_out + (n * D + dv_idx) * D;
    float state[n_per_t];
    for (int i = 0; i < n_per_t; ++i) {
        auto s_idx = n_per_t * dk_idx + i;
        state[i] = static_cast<float>(i_state[s_idx]);
    }
    for (int t = 0; t < T; ++t) {
        float sa = 0.0f;
        for (int i = 0; i < n_per_t; ++i) {
            auto s_idx = n_per_t * dk_idx + i;
            sa += state[i] * a_[s_idx];
            state[i] = state[i] * w_[s_idx];
        }
        sa = simd_sum(sa);
        float out = 0.0f;
        for (int i = 0; i < n_per_t; ++i) {
            auto s_idx = n_per_t * dk_idx + i;
            state[i] = state[i] + k_[s_idx] * v_[dv_idx] + sa * b_[s_idx];
            out += state[i] * r_[s_idx];
        }
        out = simd_sum(out);
        if (thread_index_in_simdgroup == 0) {
            y[dv_idx] = static_cast<InT>(out);
        }
        // Increment data pointers to next time step
        r_ += H * D;
        w_ += H * D;
        k_ += H * D;
        v_ += H * D;
        a_ += H * D;
        b_ += H * D;
        y += H * D;
    }
    for (int i = 0; i < n_per_t; ++i) {
        auto s_idx = n_per_t * dk_idx + i;
        o_state[s_idx] = static_cast<InT>(state[i]);
    }
    """
    inputs = ["r", "w", "k", "v", "a", "b", "state_in", "T"]
    return mx.fast.metal_kernel(
        name="wkv7_kernel",
        input_names=inputs,
        output_names=["y", "state_out"],
        source=source,
    )


_wkv7_kernel = _make_wkv7_kernel()


def transpose_head(x, head_first: bool = True):
    if head_first:
        return mx.transpose(x, (0, 2, 1, 3))
    return x


def generalized_delta_rule(
    r,
    w,
    k,
    v,
    a,
    b,
    initial_state=None,
    output_final_state: bool = True,
    head_first: bool = False,
):
    state = initial_state

    r = transpose_head(r, head_first)
    k = transpose_head(k, head_first)
    v = transpose_head(v, head_first)
    a = transpose_head(a, head_first)
    b = transpose_head(b, head_first)

    B, T, H, D = r.shape
    input_dtype = r.dtype

    y, out_state = _wkv7_kernel(
        inputs=[r, w, k, v, a, b, state, T],
        template=[
            ("InT", input_dtype),
            ("H", H),
            ("D", D),
        ],
        grid=(32, D, B * H),
        threadgroup=(32, 4, 1),
        output_shapes=[(B, T, H, D), state.shape],
        output_dtypes=[input_dtype, input_dtype],
    )
    if output_final_state:
        return y, out_state
    return y
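
A corresponding sketch for the Metal backend above. It assumes Apple silicon with Metal available (otherwise _make_wkv7_kernel returns None), a head size D that is a multiple of 32 (the kernel keeps D / 32 state entries per simdgroup lane), and an explicit initial state, since the wrapper reads state.shape; the import path is taken from this diff and all sizes and random inputs are illustrative only:

import mlx.core as mx

# Import path taken from this diff; rwkv_ops/__init__.py may expose it differently.
from rwkv_ops.rwkv7_kernel.mlx_op import generalized_delta_rule

B, T, H, D = 1, 32, 2, 64                   # D must be divisible by 32 for this kernel
dtype = mx.float32

def rand(shape):
    return mx.random.normal(shape).astype(dtype)

r, k, v, a, b = (rand((B, T, H, D)) for _ in range(5))
# The Metal source multiplies the state by w directly each step,
# so w here is a per-channel decay factor in (0, 1).
w = mx.exp(-mx.exp(rand((B, T, H, D))))
# state_in is indexed as [B, H, D, D]; the wrapper reads state.shape, so it must not be None.
state0 = mx.zeros((B, H, D, D), dtype=dtype)

y, final_state = generalized_delta_rule(r, w, k, v, a, b, initial_state=state0)
mx.eval(y, final_state)
print(y.shape, final_state.shape)           # (B, T, H, D), (B, H, D, D)
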