rwkv-ops 0.6.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rwkv_ops/__init__.py +45 -0
- rwkv_ops/mhc_kernel/__init__.py +50 -0
- rwkv_ops/mhc_kernel/common_kernel/include/mhc_types.h +66 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_post_op.cuh +197 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_pre_op.cuh +212 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/rmsnorm.cuh +152 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/sinkhorn_knopp.cuh +158 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/stream_aggregate.cuh +141 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/stream_distribute.cuh +111 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/stream_mix.cuh +164 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/type_conversions.cuh +52 -0
- rwkv_ops/mhc_kernel/jax_kernel/CMakeLists.txt +47 -0
- rwkv_ops/mhc_kernel/jax_kernel/mhu_ffi.cu +652 -0
- rwkv_ops/mhc_kernel/jax_kernel/mhu_jax.py +939 -0
- rwkv_ops/mhc_kernel/native_keras_op.py +193 -0
- rwkv_ops/mhc_kernel/torch_kernel/mhc_cuda.cu +207 -0
- rwkv_ops/mhc_kernel/torch_kernel/mhc_op.cpp +296 -0
- rwkv_ops/mhc_kernel/torch_kernel/mhc_torch.py +306 -0
- rwkv_ops/rwkv6_kernel/__init__.py +120 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/gpu_ops.cpp +44 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernel_helpers.h +64 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernels.h +56 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/pybind11_kernel_helpers.h +41 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/rwkv_kernels.cu +512 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/gpu_ops.cpp +44 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernel_helpers.h +64 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernels.h +56 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/pybind11_kernel_helpers.h +41 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/rwkv_kernels.hip +514 -0
- rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py +722 -0
- rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py +90 -0
- rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_cuda.cu +397 -0
- rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_op.cpp +93 -0
- rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py +305 -0
- rwkv_ops/rwkv7_kernel/__init__.py +113 -0
- rwkv_ops/rwkv7_kernel/get_jax_devices_info.py +220 -0
- rwkv_ops/rwkv7_kernel/get_torch_devices_info.py +250 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel/CMakeLists.txt +42 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_ffi.cu +399 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_jax.py +311 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/CMakeLists.txt +42 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_ffi.cu +172 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_jax.py +190 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/__init__.py +9 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_bwd.py +95 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_fwd.py +60 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_bwd.py +78 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_fwd.py +80 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_bwd.py +150 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_fwd.py +45 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/cumsum.py +34 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_bwd.py +61 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_fwd.py +86 -0
- rwkv_ops/rwkv7_kernel/jax_op.py +382 -0
- rwkv_ops/rwkv7_kernel/mlx_op.py +118 -0
- rwkv_ops/rwkv7_kernel/native_keras_op.py +108 -0
- rwkv_ops/rwkv7_kernel/tf_eager_kernel.py +155 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_cuda.cu +235 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_op.cpp +63 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_torch.py +233 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_cuda.cu +101 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_op.cpp +56 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_torch.py +112 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/__init__.py +13 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_bwd.py +96 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_fwd.py +64 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_bwd.py +74 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_fwd.py +75 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_bwd.py +148 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_fwd.py +44 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/cumsum.py +31 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_bwd.py +63 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_fwd.py +79 -0
- rwkv_ops/rwkv7_kernel/torch_op.py +504 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/__init__.py +34 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_bwd.py +328 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_fwd.py +186 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_bwd.py +157 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_fwd.py +160 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_bwd.py +382 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_fwd.py +137 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/cumsum.py +86 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/utils.py +20 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_bwd.py +193 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_fwd.py +326 -0
- rwkv_ops-0.6.1.dist-info/METADATA +495 -0
- rwkv_ops-0.6.1.dist-info/RECORD +89 -0
- rwkv_ops-0.6.1.dist-info/WHEEL +4 -0
- rwkv_ops-0.6.1.dist-info/licenses/LICENSE.txt +201 -0
|
@@ -0,0 +1,311 @@
|
|
|
1
|
+
"""
|
|
2
|
+
JAX 版 RWKV7 wkv kernel + generalized_delta_rule
|
|
3
|
+
延迟编译 CUDA 扩展,接口与 Torch 版本 1:1 对齐
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
import pathlib
|
|
8
|
+
import subprocess
|
|
9
|
+
import ctypes
|
|
10
|
+
import jax
|
|
11
|
+
import jax.numpy as jnp
|
|
12
|
+
from typing import Optional, Tuple, Union
|
|
13
|
+
from jax.ad_checkpoint import checkpoint_policies as cp
|
|
14
|
+
|
|
15
|
+
# Number of time steps per chunk for the chunked WKV7 kernel; sequence lengths
# passed to generalized_delta_rule must be a multiple of it.
CHUNK_LEN = 16

# ---------- Lazy compilation: build outputs live next to this source file ----------
# Absolute directory of this file (rwkv_ops/rwkv7_kernel/jax_cuda_kernel).
_CURRENT_DIR = pathlib.Path(__file__).parent.absolute()
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def get_jax_generalized_delta_rule(HEAD_SIZE=64):
    """Compile (lazily), load and wrap the WKV7 CUDA kernel for one head size.

    On first use the CUDA extension is built with CMake into
    ``build_{HEAD_SIZE}`` next to this file.  The resulting ``wkv7.so`` is
    loaded with ``ctypes``, and its ``Wkv7Fwd`` / ``Wkv7Bwd`` /
    ``Wkv7Inference`` handlers are registered as JAX FFI targets on CUDA.

    Args:
        HEAD_SIZE: head dimension ``K`` the kernel is specialised for.  It is
            passed to nvcc as ``-D_C_=HEAD_SIZE``.

    Returns:
        ``[generalized_delta_rule, generalized_delta_rule_inference]``: the
        differentiable entry point and the inference-only entry point.

    Raises:
        RuntimeError: if ``jax.ffi.include_dir()`` is empty or the build does
            not produce ``wkv7.so``.
        subprocess.CalledProcessError: if a CMake step fails.
    """
    _BUILD_DIR = _CURRENT_DIR / f"build_{HEAD_SIZE}"
    _SO_PATH = _BUILD_DIR / "wkv7.so"

    # FFI target names are global to the process.  Suffix them with HEAD_SIZE
    # so that calling this factory for several head sizes registers distinct
    # handlers instead of re-registering one name with a different library.
    _FWD_TARGET = f"wkv7_fwd_{HEAD_SIZE}"
    _BWD_TARGET = f"wkv7_bwd_{HEAD_SIZE}"
    _INFERENCE_TARGET = f"wkv7_inference_{HEAD_SIZE}"

    def _ensure_compiled() -> pathlib.Path:
        """Build the CUDA extension on first call and return the path of the .so.

        An existing ``build_{HEAD_SIZE}/wkv7.so`` is reused without rebuilding.
        """
        if _SO_PATH.exists():
            return _SO_PATH

        print("[rwkv7_jax] First use – compiling CUDA kernel…")
        src_dir = _CURRENT_DIR
        build_dir = _BUILD_DIR
        build_dir.mkdir(parents=True, exist_ok=True)

        # XLA FFI header directory shipped with the installed JAX.
        xla_include_dir = jax.ffi.include_dir()
        if not xla_include_dir:
            raise RuntimeError("jax.ffi.include_dir() 返回空,请检查 JAX >= 0.4.31")

        # nvcc flags are pinned here so that numerics do not depend on the
        # user's environment.
        cuda_flags = [
            "-ftz=true",  # flush sub-normals to zero
            "-prec-div=false",  # faster, approximate division
            "-prec-sqrt=false",  # faster, approximate square root
            "--use_fast_math",
            "-O3",
            "-Xptxas=-O3",
            "-res-usage",
            "--extra-device-vectorization",
            # Define the head size exactly once.  An additional hard-coded
            # -D_C_=64 would be a conflicting redefinition for HEAD_SIZE != 64.
            f"-D_C_={HEAD_SIZE}",
            f"-D_CHUNK_LEN_={CHUNK_LEN}",
        ]

        # 1. Configure.
        cmake_args = [
            "cmake",
            "-S",
            str(src_dir),
            "-B",
            str(build_dir),
            "-DCMAKE_BUILD_TYPE=Release",
            f"-DCMAKE_INSTALL_PREFIX={_CURRENT_DIR}",
            f"-DXLA_INCLUDE_DIR={xla_include_dir}",  # forwarded to CMake
            f"-DCMAKE_CUDA_FLAGS={' '.join(cuda_flags)}",
        ]
        subprocess.check_call(cmake_args)

        # 2. Build.
        subprocess.check_call(["cmake", "--build", str(build_dir), "-j"])

        # 3. Install (CMAKE_INSTALL_PREFIX points at this source directory).
        subprocess.check_call(["cmake", "--install", str(build_dir)])

        if not _SO_PATH.exists():
            raise RuntimeError("Compilation failed – wkv7.so not found.")

        print("[rwkv7_jax] Compilation finished – output at", _SO_PATH)
        return _SO_PATH

    # Load the shared library and register its handlers as CUDA FFI targets.
    # str() avoids relying on path-like support in ctypes.CDLL, which is only
    # documented from Python 3.12.
    _lib = ctypes.CDLL(str(_ensure_compiled()))
    jax.ffi.register_ffi_target(
        _FWD_TARGET, jax.ffi.pycapsule(_lib.Wkv7Fwd), platform="CUDA"
    )
    jax.ffi.register_ffi_target(
        _BWD_TARGET, jax.ffi.pycapsule(_lib.Wkv7Bwd), platform="CUDA"
    )
    jax.ffi.register_ffi_target(
        _INFERENCE_TARGET, jax.ffi.pycapsule(_lib.Wkv7Inference), platform="CUDA"
    )

    # ---------- Utilities ----------
    def _transpose_head(x: jnp.ndarray, head_first: bool) -> jnp.ndarray:
        """Cast ``x`` to bfloat16 and, if ``head_first``, swap axes 1 and 2.

        With ``head_first=True`` this turns (B, H, T, K) into (B, T, H, K);
        otherwise the cast input is returned unchanged.
        """
        x = jnp.asarray(x, dtype=jnp.bfloat16)
        if head_first:
            return jnp.transpose(x, (0, 2, 1, 3))
        return x

    # ---------- Forward + backward kernels ----------
    def _wkv7_kernel(
        w: jnp.ndarray,
        q: jnp.ndarray,
        k: jnp.ndarray,
        v: jnp.ndarray,
        a: jnp.ndarray,
        b: jnp.ndarray,
        h0: jnp.ndarray,
    ):
        """Call the forward FFI handler (``Wkv7Fwd``).

        The argument order is intended to match the handler declared in
        wkv7_ffi.cu: ``w, q, k, v, a, b, h0 -> y, s, sa``.

        Returns:
            y:  (B, T, H, K) output in ``q.dtype``.
            s:  (B, H, T // CHUNK_LEN, K, K) float32 per-chunk states.
            sa: (B, T, H, K) float32 intermediate saved for the backward pass.
        """
        B, T, H, K = q.shape
        dtype = q.dtype
        chunk_num = int(T // CHUNK_LEN)
        out_type = jax.ShapeDtypeStruct((B, T, H, K), dtype)
        s_type = jax.ShapeDtypeStruct((B, H, chunk_num, K, K), jnp.float32)
        sa_type = jax.ShapeDtypeStruct((B, T, H, K), jnp.float32)

        y, s, sa = jax.ffi.ffi_call(
            _FWD_TARGET, (out_type, s_type, sa_type), vmap_method="broadcast_all"
        )(w, q, k, v, a, b, h0)

        return y, s, sa

    @jax.custom_vjp
    def wk7_kernel(
        w: jnp.ndarray,
        q: jnp.ndarray,
        k: jnp.ndarray,
        v: jnp.ndarray,
        a: jnp.ndarray,
        b: jnp.ndarray,
        h0: jnp.ndarray,
    ):
        """Differentiable WKV7 forward pass returning ``(y, final_state)``.

        ``final_state`` is the last per-chunk state with its last two axes
        swapped.
        """
        y, s, sa = _wkv7_kernel(w, q, k, v, a, b, h0)
        final_state = s[:, :, -1]
        return (y, jnp.transpose(final_state, [0, 1, 3, 2]))

    # custom_vjp forward rule: same outputs, plus residuals for _bwd.
    def _fwd(
        w: jnp.ndarray,
        q: jnp.ndarray,
        k: jnp.ndarray,
        v: jnp.ndarray,
        a: jnp.ndarray,
        b: jnp.ndarray,
        h0: jnp.ndarray,
    ):
        y, s, sa = _wkv7_kernel(w, q, k, v, a, b, h0)
        final_state = s[:, :, -1]
        return (y, jnp.transpose(final_state, [0, 1, 3, 2])), (w, q, k, v, a, b, s, sa)

    def _wkv7_bwd_kernel(w, q, k, v, a, b, dy, s, sa, dht):
        """Call the backward FFI handler.

        Each gradient has its primal's shape and dtype.  Returns
        ``(dw, dq, dk, dv, da, db, dh0)``, matching wk7_kernel's argument order.
        """
        dh0_type = jax.ShapeDtypeStruct(dht.shape, dht.dtype)
        dw_type = jax.ShapeDtypeStruct(w.shape, w.dtype)
        dq_type = jax.ShapeDtypeStruct(q.shape, q.dtype)
        dk_type = jax.ShapeDtypeStruct(k.shape, k.dtype)
        dv_type = jax.ShapeDtypeStruct(v.shape, v.dtype)
        da_type = jax.ShapeDtypeStruct(a.shape, a.dtype)
        db_type = jax.ShapeDtypeStruct(b.shape, b.dtype)

        dh0, dw, dq, dk, dv, da, db = jax.ffi.ffi_call(
            _BWD_TARGET,
            (dh0_type, dw_type, dq_type, dk_type, dv_type, da_type, db_type),
            vmap_method="broadcast_all",
        )(w, q, k, v, a, b, dy, s, sa, dht)

        return dw, dq, dk, dv, da, db, dh0

    # custom_vjp backward rule.
    def _bwd(res, grads):
        w, q, k, v, a, b, s, sa = res
        dy, dht = grads
        # The backward handler takes bf16 dy.
        dy = jnp.asarray(dy, jnp.bfloat16)
        return _wkv7_bwd_kernel(w, q, k, v, a, b, dy, s, sa, dht)

    wk7_kernel.defvjp(_fwd, _bwd)

    def generalized_delta_rule(
        r: jnp.ndarray,
        w: jnp.ndarray,
        k: jnp.ndarray,
        v: jnp.ndarray,
        a: jnp.ndarray,
        b: jnp.ndarray,
        initial_state: Optional[jnp.ndarray] = None,
        output_final_state: bool = True,
        head_first: bool = False,
    ) -> Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:
        """Generalized delta rule (training path, differentiable).

        Args:
            r, w, k, v, a, b: inputs of shape (B, T, H, K), or (B, H, T, K)
                when ``head_first=True``.  They are cast to bfloat16.  T must
                be a multiple of CHUNK_LEN.
            initial_state: optional (B, H, K, K) state; zeros when None.
            output_final_state: also return the final state.
            head_first: whether the inputs have the head axis before time.

        Returns:
            out: (B, T, H, K) in ``r``'s dtype.  It is not transposed back
                when ``head_first=True``.
            last_state: (B, H, K, K) float32, only when
                ``output_final_state=True``.

        Raises:
            ValueError: if T is not divisible by CHUNK_LEN.
        """
        # Normalize everything to (B, T, H, K).
        dtype = r.dtype
        r = _transpose_head(r, head_first)
        w = _transpose_head(w, head_first)
        k = _transpose_head(k, head_first)
        v = _transpose_head(v, head_first)
        a = _transpose_head(a, head_first)
        b = _transpose_head(b, head_first)

        B, T, H, K = r.shape
        if T % CHUNK_LEN:
            raise ValueError(
                f"Sequence length T={T} must be divisible by chunk_len={CHUNK_LEN}"
            )

        if initial_state is None:
            h0 = jnp.zeros((B, H, K, K), jnp.float32)
        else:
            h0 = jnp.asarray(initial_state, jnp.float32)

        out, last_state = jax.checkpoint(
            wk7_kernel, policy=cp.save_anything_except_these_names(())
        )(w, r, k, v, a, b, h0)
        out = jnp.asarray(out, dtype)  # keep the caller's dtype

        if output_final_state:
            return out, last_state
        return out

    def _wkv7_inference_kernel(
        w: jnp.ndarray,
        q: jnp.ndarray,
        k: jnp.ndarray,
        v: jnp.ndarray,
        a: jnp.ndarray,
        b: jnp.ndarray,
        h0: jnp.ndarray,
    ):
        """Call the inference FFI handler, which produces no sa or per-chunk states.

        Returns:
            y: (B, T, H, K) in ``q.dtype``.
            s: (B, H, K, K) float32 final state.  NOTE(review): unlike
                wk7_kernel, no transpose is applied here.  Presumably the
                handler already emits the same layout; verify.
        """
        B, T, H, K = q.shape
        dtype = q.dtype
        out_type = jax.ShapeDtypeStruct((B, T, H, K), dtype)
        # Only the final state is returned, not the per-chunk history.
        s_type = jax.ShapeDtypeStruct((B, H, K, K), jnp.float32)

        y, s = jax.ffi.ffi_call(
            _INFERENCE_TARGET, (out_type, s_type), vmap_method="broadcast_all"
        )(w, q, k, v, a, b, h0)

        return y, s

    # -------------------- Public inference API --------------------
    def generalized_delta_rule_inference(
        r: jnp.ndarray,
        w: jnp.ndarray,
        k: jnp.ndarray,
        v: jnp.ndarray,
        a: jnp.ndarray,
        b: jnp.ndarray,
        output_final_state: bool = True,
        initial_state: Optional[jnp.ndarray] = None,
        head_first: bool = False,
    ) -> Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:
        """Inference-only generalized delta rule (no custom VJP, no checkpoint).

        Args:
            r, w, k, v, a, b: inputs of shape (B, T, H, K), or (B, H, T, K)
                when ``head_first=True``.  They are cast to bfloat16.
            output_final_state: also return the final state.
            initial_state: optional (B, H, K, K) state; zeros when None.
            head_first: whether the inputs have the head axis before time.

        Returns:
            ``(out, final_state)`` when ``output_final_state`` is True,
            otherwise just ``out``.  ``out`` is (B, T, H, K) in ``r``'s dtype;
            ``final_state`` is (B, H, K, K) float32.
        """
        dtype = r.dtype
        r = _transpose_head(r, head_first)
        w = _transpose_head(w, head_first)
        k = _transpose_head(k, head_first)
        v = _transpose_head(v, head_first)
        a = _transpose_head(a, head_first)
        b = _transpose_head(b, head_first)

        B, T, H, K = r.shape

        if initial_state is None:
            h0 = jnp.zeros((B, H, K, K), jnp.float32)
        else:
            h0 = jnp.asarray(initial_state, jnp.float32)

        # No checkpointing: inference keeps no intermediates.
        out, final_state = _wkv7_inference_kernel(w, r, k, v, a, b, h0)
        out = jnp.asarray(out, dtype)
        # Parenthesized on purpose.  The unparenthesized form
        # `out, final_state if cond else out` returns the tuple (out, out)
        # when cond is False.
        return (out, final_state) if output_final_state else out

    # Return both entry points; callers pick the one they need.
    return [generalized_delta_rule, generalized_delta_rule_inference]
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
cmake_minimum_required(VERSION 3.18)
project(wkv7_single_step LANGUAGES CXX CUDA)

find_package(CUDAToolkit REQUIRED)

# ---------- 1. XLA FFI header directory ----------
# Honour an explicit -DXLA_INCLUDE_DIR=... from the caller, for example a Python
# build helper running the interpreter that actually has JAX installed.  Only
# when it is not given, ask a discovered Python interpreter for
# jax.ffi.include_dir().
if(NOT XLA_INCLUDE_DIR)
  find_package(Python3 REQUIRED COMPONENTS Interpreter)
  execute_process(
    COMMAND "${Python3_EXECUTABLE}" -c "import jax; print(jax.ffi.include_dir())"
    OUTPUT_VARIABLE XLA_INCLUDE_DIR
    RESULT_VARIABLE XLA_INCLUDE_QUERY_RESULT
    OUTPUT_STRIP_TRAILING_WHITESPACE
  )
  # An interpreter that failed (e.g. one without JAX) must not leave partial
  # output behind as the include path.
  if(NOT XLA_INCLUDE_QUERY_RESULT EQUAL 0)
    set(XLA_INCLUDE_DIR "")
  endif()
endif()
if(NOT XLA_INCLUDE_DIR)
  message(FATAL_ERROR "Cannot get XLA include dir from jax.ffi")
endif()
message(STATUS "XLA include directory: ${XLA_INCLUDE_DIR}")

# ---------- 2. Shared library ----------
# NOTE: wkv7_single_step_ffi.cu uses the _C_ macro (head size), which is not
# defined here.  It must be supplied, e.g. -DCMAKE_CUDA_FLAGS=-D_C_=64.
add_library(wkv7_single_step SHARED wkv7_single_step_ffi.cu)

# 2-1. Header search path for xla/ffi/api/ffi.h.
target_include_directories(wkv7_single_step PRIVATE ${XLA_INCLUDE_DIR})

# 2-2. Link the CUDA runtime.
target_link_libraries(wkv7_single_step PRIVATE CUDA::cudart)

# 2-3. C++17 / CUDA 17 language standard.
target_compile_features(wkv7_single_step PUBLIC cxx_std_17)
set_target_properties(wkv7_single_step PROPERTIES
  CUDA_STANDARD 17
  CUDA_SEPARABLE_COMPILATION ON
  POSITION_INDEPENDENT_CODE ON
  PREFIX ""  # drop the default "lib" prefix
)

# ---------- 3. Install ----------
# Install the library straight into the source directory so ctypes.CDLL can load it.
install(TARGETS wkv7_single_step
  LIBRARY DESTINATION "${CMAKE_SOURCE_DIR}"
  RUNTIME DESTINATION "${CMAKE_SOURCE_DIR}")  # RUNTIME is the destination used on Windows
|
|
@@ -0,0 +1,172 @@
|
|
|
1
|
+
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <xla/ffi/api/ffi.h>

#include <cstdint>
#include <string>
#include <vector>
|
|
6
|
+
|
|
7
|
+
namespace ffi = xla::ffi;
using bf = __nv_bfloat16;

/* -------------------- Device-side helpers -------------------- */

// Widen a bfloat16 value to float.
__device__ inline float to_float(const bf &value) {
    const float widened = __bfloat162float(value);
    return widened;
}

// Narrow a float to bfloat16 using the round-to-nearest intrinsic.
__device__ inline bf to_bf(const float &value) {
    const bf narrowed = __float2bfloat16_rn(value);
    return narrowed;
}

// Restrict-qualified bf16 pointer type used for kernel input arguments.
typedef bf *__restrict__ F_;
|
|
18
|
+
|
|
19
|
+
/* -------------------- Forward kernel (single step) -------------------- */
// One RWKV7 recurrence step.
// Launch layout (see the host launcher): grid = (H, B), blockDim.x = C.
// Thread `row` of block (head, batch) owns row `row` of the C x C state:
//   sa         = sum_col a[col] * S[row][col]
//   S[row][col] = S[row][col] * w[col] + sa * b[col] + k[col] * v[row]
//   y[row]     = sum_col S[row][col] * q[col]       (using the updated S)
// with w = exp(-exp(w_raw)).  Step tensors are indexed as (B, H, C) and the
// h0 / s states as (B, H, C, C).  Uses 5 * C floats of static shared memory.
template<int C>
__launch_bounds__(C, 2)
__global__ void forward_kernel_single_step(
    int B, int H,
    F_ w_, F_ q_, F_ k_, F_ v_, F_ a_, F_ b_,
    bf *y_, float *s_, float *h0_)
{
    const int head = blockIdx.x;
    const int batch = blockIdx.y;
    const int row = threadIdx.x;

    __shared__ float sh_q[C], sh_k[C], sh_w[C], sh_a[C], sh_b[C];

    // Start of this thread's state row inside the (B, H, C, C) tensors.
    const int64_t row_offset = ((int64_t)batch * H + head) * C * C + row * C;
    // This thread's element inside the (B, H, C) step tensors.
    const int64_t vec_offset = (int64_t)batch * H * C + head * C + row;

    float state[C] = {0};
#pragma unroll
    for (int col = 0; col < C; ++col) {
        state[col] = h0_[row_offset + col];
    }

    // Stage the per-step vectors in shared memory so every thread can read
    // all C entries.
    __syncthreads();
    sh_q[row] = to_float(q_[vec_offset]);
    sh_w[row] = __expf(-__expf(to_float(w_[vec_offset])));
    sh_k[row] = to_float(k_[vec_offset]);
    sh_a[row] = to_float(a_[vec_offset]);
    sh_b[row] = to_float(b_[vec_offset]);
    __syncthreads();

    float sa = 0.f;
#pragma unroll
    for (int col = 0; col < C; ++col) {
        sa += sh_a[col] * state[col];
    }

    const float v_row = to_float(v_[vec_offset]);
    float out = 0.f;
#pragma unroll
    for (int col = 0; col < C; ++col) {
        state[col] = state[col] * sh_w[col] + sa * sh_b[col] + sh_k[col] * v_row;
        out += state[col] * sh_q[col];
    }
    y_[vec_offset] = to_bf(out);

    // Final state uses the same (B, H, C, C) row layout as h0.
#pragma unroll
    for (int col = 0; col < C; ++col) {
        s_[row_offset + col] = state[col];
    }
}
|
|
66
|
+
|
|
67
|
+
/* -------------------- Backward kernel (single step, unverified) -------------------- */
// Launch layout mirrors the forward kernel: grid = (H, B), blockDim.x = C.
//
// NOTE(review): this kernel is not launched and has no FFI handler in this
// file, and it does not look like a correct VJP of forward_kernel_single_step:
//   * da_ and db_ are parameters but are never written.  a and b are not even
//     inputs, so their gradients (and the sa * b term) cannot be formed.
//   * The state reversal subtracts k[row] * v[col] and divides by w[row].
//     That is the transpose of the forward update's k[col] * v[row] / w[col],
//     yet s_ is read with the same row offset the forward kernel writes.
//     The sa * b[col] term is not subtracted at all.
//   * dq is accumulated from the reverted (pre-step) state, whereas the
//     forward y uses the updated state.
// Fixing this needs a different signature (a, b and h0 or the pre-step
// state as inputs, plus a dh0 output).  Behaviour is kept as-is here.
template<int C>
__launch_bounds__(C, 2)
__global__ void backward_kernel_single_step(
    int B, int H,
    F_ w_, F_ q_, F_ k_, F_ v_, F_ dy_,
    float *s_, float *dht_, bf *dw_, bf *dq_, bf *dk_, bf *dv_, bf *da_, bf *db_)
{
    const int head = blockIdx.x;
    const int batch = blockIdx.y;
    const int row = threadIdx.x;

    // Row offset into the (B, H, C, C) state tensors, element offset into the
    // (B, H, C) step tensors.
    const int64_t row_offset = ((int64_t)batch * H + head) * C * C + row * C;
    const int64_t vec_offset = (int64_t)batch * H * C + head * C + row;

    float stateT[C] = {0}, dstate[C] = {0};
#pragma unroll
    for (int col = 0; col < C; ++col) {
        dstate[col] = dht_[row_offset + col];
    }

    __shared__ float sh_w[C], sh_q[C], sh_k[C], sh_v[C], sh_dy[C];

    __syncthreads();
    sh_q[row] = to_float(q_[vec_offset]);
    // w = exp(w_fac) with w_fac = -exp(w_raw); w_fac is reused for dw below.
    const float w_fac = -__expf(to_float(w_[vec_offset]));
    sh_w[row] = __expf(w_fac);
    sh_k[row] = to_float(k_[vec_offset]);
    sh_v[row] = to_float(v_[vec_offset]);
    sh_dy[row] = to_float(dy_[vec_offset]);
    __syncthreads();

#pragma unroll
    for (int col = 0; col < C; ++col) {
        stateT[col] = s_[row_offset + col];
    }

    float dq_acc = 0.f, dw_acc = 0.f, dk_acc = 0.f, dv_acc = 0.f;
    // 1e-6 guards the reciprocal against w underflowing to zero.
    const float inv_w = 1.0f / (sh_w[row] + 1e-6f);

#pragma unroll
    for (int col = 0; col < C; ++col) {
        stateT[col] = (stateT[col] - sh_k[row] * sh_v[col]) * inv_w;
        dstate[col] += sh_dy[row] * sh_q[col];

        dq_acc += stateT[col] * sh_dy[col];
        dw_acc += dstate[col] * stateT[col];
        dk_acc += dstate[col] * sh_v[col];
        dv_acc += dstate[col] * sh_k[col];
    }

    dq_[vec_offset] = to_bf(dq_acc);
    // Chain rule through w = exp(-exp(w_raw)): dw_raw = dw * w * (-exp(w_raw)).
    dw_[vec_offset] = to_bf(dw_acc * sh_w[row] * w_fac);
    dk_[vec_offset] = to_bf(dk_acc);
    dv_[vec_offset] = to_bf(dv_acc);
}
|
|
118
|
+
|
|
119
|
+
/* -------------------- Host launcher -------------------- */
// Validates shapes and launches forward_kernel_single_step<_C_> on `stream`.
//
// Expected layouts (C == _C_, the head size fixed at compile time):
//   w, q, k, v, a, b : (B, H, C)    bf16
//   h0               : (B, H, C, C) f32   initial state
//   y  (result)      : (B, H, C)    bf16
//   s  (result)      : (B, H, C, C) f32   state after this step
// The grid is (H, B) with C threads per block.  Returns InvalidArgument on a
// shape mismatch, because the kernel would otherwise index out of bounds.
// Returns Internal if the launch reports a CUDA error.
static ffi::Error WKV7SingleStepFwdHost(
    cudaStream_t stream,
    ffi::Buffer<ffi::BF16> w,
    ffi::Buffer<ffi::BF16> q,
    ffi::Buffer<ffi::BF16> k,
    ffi::Buffer<ffi::BF16> v,
    ffi::Buffer<ffi::BF16> a,
    ffi::Buffer<ffi::BF16> b,
    ffi::Buffer<ffi::F32> h0,
    ffi::ResultBuffer<ffi::BF16> y,
    ffi::ResultBuffer<ffi::F32> s)
{
    constexpr int C = _C_;  // head size, supplied at compile time via -D_C_=<n>

    // The kernel hard-codes C threads per block and (B, H, C) indexing, so the
    // last dimension must equal C.
    auto dims = w.dimensions();
    if (dims.size() != 3 || dims[2] != C) {
        return ffi::Error::InvalidArgument(
            "wkv7_single_step: w must have shape (B, H, " + std::to_string(C) + ")");
    }
    const int64_t B = dims[0];
    const int64_t H = dims[1];
    const size_t vec_count = static_cast<size_t>(B * H * C);
    const size_t state_count = vec_count * static_cast<size_t>(C);

    if (q.element_count() != vec_count || k.element_count() != vec_count ||
        v.element_count() != vec_count || a.element_count() != vec_count ||
        b.element_count() != vec_count || y->element_count() != vec_count) {
        return ffi::Error::InvalidArgument(
            "wkv7_single_step: q, k, v, a, b and y must have the same shape as w");
    }
    if (h0.element_count() != state_count || s->element_count() != state_count) {
        return ffi::Error::InvalidArgument(
            "wkv7_single_step: h0 and s must have shape (B, H, C, C)");
    }

    // Nothing to compute.  A grid with a zero dimension would be rejected as
    // an invalid launch configuration.
    if (B == 0 || H == 0) {
        return ffi::Error::Success();
    }

    dim3 block(C);
    dim3 grid(static_cast<unsigned>(H), static_cast<unsigned>(B));

    forward_kernel_single_step<C><<<grid, block, 0, stream>>>(
        static_cast<int>(B), static_cast<int>(H),
        reinterpret_cast<bf *>(w.typed_data()),
        reinterpret_cast<bf *>(q.typed_data()),
        reinterpret_cast<bf *>(k.typed_data()),
        reinterpret_cast<bf *>(v.typed_data()),
        reinterpret_cast<bf *>(a.typed_data()),
        reinterpret_cast<bf *>(b.typed_data()),
        reinterpret_cast<bf *>(y->typed_data()),
        s->typed_data(),
        h0.typed_data());

    // Catches launch-configuration errors; execution faults surface later on
    // the stream.
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess)
        return ffi::Error::Internal(
            std::string("CUDA forward_kernel_single_step error: ") + cudaGetErrorString(err));
    return ffi::Error::Success();
}

/* -------------------- FFI symbol registration -------------------- */
// Exports the `Wkv7SingleStepFwd` handler.  The binding order below must match
// the parameter order of WKV7SingleStepFwdHost.
XLA_FFI_DEFINE_HANDLER_SYMBOL(
    Wkv7SingleStepFwd, WKV7SingleStepFwdHost,
    ffi::Ffi::Bind()
        .Ctx<ffi::PlatformStream<cudaStream_t>>()
        .Arg<ffi::Buffer<ffi::BF16>>()  // w
        .Arg<ffi::Buffer<ffi::BF16>>()  // q
        .Arg<ffi::Buffer<ffi::BF16>>()  // k
        .Arg<ffi::Buffer<ffi::BF16>>()  // v
        .Arg<ffi::Buffer<ffi::BF16>>()  // a
        .Arg<ffi::Buffer<ffi::BF16>>()  // b
        .Arg<ffi::Buffer<ffi::F32>>()   // h0
        .Ret<ffi::Buffer<ffi::BF16>>()  // y
        .Ret<ffi::Buffer<ffi::F32>>(),  // s
    {ffi::Traits::kCmdBufferCompatible});
|