rwkv_ops-0.6.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. rwkv_ops/__init__.py +45 -0
  2. rwkv_ops/mhc_kernel/__init__.py +50 -0
  3. rwkv_ops/mhc_kernel/common_kernel/include/mhc_types.h +66 -0
  4. rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_post_op.cuh +197 -0
  5. rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_pre_op.cuh +212 -0
  6. rwkv_ops/mhc_kernel/common_kernel/kernels/rmsnorm.cuh +152 -0
  7. rwkv_ops/mhc_kernel/common_kernel/kernels/sinkhorn_knopp.cuh +158 -0
  8. rwkv_ops/mhc_kernel/common_kernel/kernels/stream_aggregate.cuh +141 -0
  9. rwkv_ops/mhc_kernel/common_kernel/kernels/stream_distribute.cuh +111 -0
  10. rwkv_ops/mhc_kernel/common_kernel/kernels/stream_mix.cuh +164 -0
  11. rwkv_ops/mhc_kernel/common_kernel/kernels/type_conversions.cuh +52 -0
  12. rwkv_ops/mhc_kernel/jax_kernel/CMakeLists.txt +47 -0
  13. rwkv_ops/mhc_kernel/jax_kernel/mhu_ffi.cu +652 -0
  14. rwkv_ops/mhc_kernel/jax_kernel/mhu_jax.py +939 -0
  15. rwkv_ops/mhc_kernel/native_keras_op.py +193 -0
  16. rwkv_ops/mhc_kernel/torch_kernel/mhc_cuda.cu +207 -0
  17. rwkv_ops/mhc_kernel/torch_kernel/mhc_op.cpp +296 -0
  18. rwkv_ops/mhc_kernel/torch_kernel/mhc_torch.py +306 -0
  19. rwkv_ops/rwkv6_kernel/__init__.py +120 -0
  20. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/gpu_ops.cpp +44 -0
  21. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernel_helpers.h +64 -0
  22. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernels.h +56 -0
  23. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/pybind11_kernel_helpers.h +41 -0
  24. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/rwkv_kernels.cu +512 -0
  25. rwkv_ops/rwkv6_kernel/jax_kernel_hip/gpu_ops.cpp +44 -0
  26. rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernel_helpers.h +64 -0
  27. rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernels.h +56 -0
  28. rwkv_ops/rwkv6_kernel/jax_kernel_hip/pybind11_kernel_helpers.h +41 -0
  29. rwkv_ops/rwkv6_kernel/jax_kernel_hip/rwkv_kernels.hip +514 -0
  30. rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py +722 -0
  31. rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py +90 -0
  32. rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_cuda.cu +397 -0
  33. rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_op.cpp +93 -0
  34. rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py +305 -0
  35. rwkv_ops/rwkv7_kernel/__init__.py +113 -0
  36. rwkv_ops/rwkv7_kernel/get_jax_devices_info.py +220 -0
  37. rwkv_ops/rwkv7_kernel/get_torch_devices_info.py +250 -0
  38. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/CMakeLists.txt +42 -0
  39. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_ffi.cu +399 -0
  40. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_jax.py +311 -0
  41. rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/CMakeLists.txt +42 -0
  42. rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_ffi.cu +172 -0
  43. rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_jax.py +190 -0
  44. rwkv_ops/rwkv7_kernel/jax_kernel/__init__.py +9 -0
  45. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_bwd.py +95 -0
  46. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_fwd.py +60 -0
  47. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_bwd.py +78 -0
  48. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_fwd.py +80 -0
  49. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_bwd.py +150 -0
  50. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_fwd.py +45 -0
  51. rwkv_ops/rwkv7_kernel/jax_kernel/cumsum.py +34 -0
  52. rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_bwd.py +61 -0
  53. rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_fwd.py +86 -0
  54. rwkv_ops/rwkv7_kernel/jax_op.py +382 -0
  55. rwkv_ops/rwkv7_kernel/mlx_op.py +118 -0
  56. rwkv_ops/rwkv7_kernel/native_keras_op.py +108 -0
  57. rwkv_ops/rwkv7_kernel/tf_eager_kernel.py +155 -0
  58. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_cuda.cu +235 -0
  59. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_op.cpp +63 -0
  60. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_torch.py +233 -0
  61. rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_cuda.cu +101 -0
  62. rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_op.cpp +56 -0
  63. rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_torch.py +112 -0
  64. rwkv_ops/rwkv7_kernel/torch_kernel/__init__.py +13 -0
  65. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_bwd.py +96 -0
  66. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_fwd.py +64 -0
  67. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_bwd.py +74 -0
  68. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_fwd.py +75 -0
  69. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_bwd.py +148 -0
  70. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_fwd.py +44 -0
  71. rwkv_ops/rwkv7_kernel/torch_kernel/cumsum.py +31 -0
  72. rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_bwd.py +63 -0
  73. rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_fwd.py +79 -0
  74. rwkv_ops/rwkv7_kernel/torch_op.py +504 -0
  75. rwkv_ops/rwkv7_kernel/triton_kernel/__init__.py +34 -0
  76. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_bwd.py +328 -0
  77. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_fwd.py +186 -0
  78. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_bwd.py +157 -0
  79. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_fwd.py +160 -0
  80. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_bwd.py +382 -0
  81. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_fwd.py +137 -0
  82. rwkv_ops/rwkv7_kernel/triton_kernel/cumsum.py +86 -0
  83. rwkv_ops/rwkv7_kernel/triton_kernel/utils.py +20 -0
  84. rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_bwd.py +193 -0
  85. rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_fwd.py +326 -0
  86. rwkv_ops-0.6.1.dist-info/METADATA +495 -0
  87. rwkv_ops-0.6.1.dist-info/RECORD +89 -0
  88. rwkv_ops-0.6.1.dist-info/WHEEL +4 -0
  89. rwkv_ops-0.6.1.dist-info/licenses/LICENSE.txt +201 -0
rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_jax.py
@@ -0,0 +1,311 @@
+ """
+ JAX version of the RWKV7 wkv kernel + generalized_delta_rule.
+ Compiles the CUDA extension lazily; the interface matches the Torch version 1:1.
+ """
+
+ from __future__ import annotations
+ import pathlib
+ import subprocess
+ import ctypes
+ import jax
+ import jax.numpy as jnp
+ from typing import Optional, Tuple, Union
+ from jax.ad_checkpoint import checkpoint_policies as cp
+
+ CHUNK_LEN = 16  # fixed constant
+ # ---------- lazy compilation (build output goes into the current directory) ----------
+ _CURRENT_DIR = pathlib.Path(
+     __file__
+ ).parent.absolute()  # rwkv_ops/rwkv7_kernel/jax_cuda_kernel
+
+
+ def get_jax_generalized_delta_rule(HEAD_SIZE=64):
+     _BUILD_DIR = _CURRENT_DIR / f"build_{HEAD_SIZE}"
+     _SO_PATH = _CURRENT_DIR / f"build_{HEAD_SIZE}/wkv7.so"
+
+     def _ensure_compiled() -> pathlib.Path:
+         """Compile the CUDA extension on first use; the artifact lands in the source directory."""
+         if _SO_PATH.exists():
+             return _SO_PATH
+
+         print("[rwkv7_jax] First use – compiling CUDA kernel…")
+         src_dir = _CURRENT_DIR
+         build_dir = _BUILD_DIR
+         build_dir.mkdir(exist_ok=True)
+
+         # ---------- key step: get the XLA header path from JAX ----------
+         xla_include_dir = jax.ffi.include_dir()  # the key API for this approach
+         if not xla_include_dir:
+             raise RuntimeError(
+                 "jax.ffi.include_dir() returned an empty path; check that JAX >= 0.4.31"
+             )
+
+         # ---------- key step: hard-code the numerics flags ----------
+         cuda_flags = [
+             "-ftz=true",  # flush subnormals to zero
+             "-prec-div=false",  # faster division, avoids the special-case path
+             "-prec-sqrt=false",  # faster square root
+             "--use_fast_math",  # uniform fast math
+             "-O3",
+             "-Xptxas=-O3",
+             "-res-usage",
+             "--extra-device-vectorization",
+             f"-D_C_={HEAD_SIZE}",
+             f"-D_CHUNK_LEN_={CHUNK_LEN}",
+         ]
+
+         # 1. configure
+         cmake_args = [
+             "cmake",
+             "-S",
+             str(src_dir),
+             "-B",
+             str(build_dir),
+             "-DCMAKE_BUILD_TYPE=Release",
+             f"-DCMAKE_INSTALL_PREFIX={_CURRENT_DIR}",
+             f"-DXLA_INCLUDE_DIR={xla_include_dir}",  # forwarded to CMake
+             f"-DCMAKE_CUDA_FLAGS={' '.join(cuda_flags)}",
+         ]
+         subprocess.check_call(cmake_args)
+
+         # 2. build
+         subprocess.check_call(["cmake", "--build", str(build_dir), "-j"])
+
+         # 3. install (copies the .so into the current directory)
+         subprocess.check_call(["cmake", "--install", str(build_dir)])
+
+         if not _SO_PATH.exists():
+             raise RuntimeError("Compilation failed – wkv7.so not found.")
+
+         print("[rwkv7_jax] Compilation finished – output at", _SO_PATH)
+         return _SO_PATH
+
+     # register the FFI symbols
+     _lib = ctypes.CDLL(_ensure_compiled())
+     jax.ffi.register_ffi_target(
+         "wkv7_fwd", jax.ffi.pycapsule(_lib.Wkv7Fwd), platform="CUDA"
+     )
+     jax.ffi.register_ffi_target(
+         "wkv7_bwd", jax.ffi.pycapsule(_lib.Wkv7Bwd), platform="CUDA"
+     )
+     jax.ffi.register_ffi_target(
+         "wkv7_inference", jax.ffi.pycapsule(_lib.Wkv7Inference), platform="CUDA"
+     )
+
+     # ---------- helpers ----------
+     def _transpose_head(x: jnp.ndarray, head_first: bool) -> jnp.ndarray:
+         """(B, H, T, K) -> (B, T, H, K) when head_first; always casts to bfloat16."""
+         x = jnp.asarray(x, dtype=jnp.bfloat16)
+         if head_first:
+             return jnp.transpose(x, (0, 2, 1, 3))
+         return x
+
+     # ---------- forward + backward kernels ----------
+
+     def _wkv7_kernel(
+         w: jnp.ndarray,
+         q: jnp.ndarray,
+         k: jnp.ndarray,
+         v: jnp.ndarray,
+         a: jnp.ndarray,
+         b: jnp.ndarray,
+         h0: jnp.ndarray,
+     ):
+         """
+         Internal kernel interface.
+         The argument order matches the declaration in wkv7_ffi.cu exactly:
+             w, q, k, v, a, b, h0 -> y, s, sa
+         """
+         B, T, H, K = q.shape
+         dtype = q.dtype
+         chunk_num = int(T // CHUNK_LEN)
+         out_type = jax.ShapeDtypeStruct((B, T, H, K), dtype)
+         s_type = jax.ShapeDtypeStruct((B, H, chunk_num, K, K), jnp.float32)
+         sa_type = jax.ShapeDtypeStruct((B, T, H, K), jnp.float32)
+
+         y, s, sa = jax.ffi.ffi_call(
+             "wkv7_fwd", (out_type, s_type, sa_type), vmap_method="broadcast_all"
+         )(w, q, k, v, a, b, h0)
+
+         return y, s, sa
+
+     @jax.custom_vjp
+     def wkv7_kernel(
+         w: jnp.ndarray,
+         q: jnp.ndarray,
+         k: jnp.ndarray,
+         v: jnp.ndarray,
+         a: jnp.ndarray,
+         b: jnp.ndarray,
+         h0: jnp.ndarray,
+     ):
+         y, s, sa = _wkv7_kernel(w, q, k, v, a, b, h0)
+         final_state = s[:, :, -1]
+         return (y, jnp.transpose(final_state, [0, 1, 3, 2]))
+
+     # forward rule
+     def _fwd(
+         w: jnp.ndarray,
+         q: jnp.ndarray,
+         k: jnp.ndarray,
+         v: jnp.ndarray,
+         a: jnp.ndarray,
+         b: jnp.ndarray,
+         h0: jnp.ndarray,
+     ):
+         y, s, sa = _wkv7_kernel(w, q, k, v, a, b, h0)
+         final_state = s[:, :, -1]
+         return (y, jnp.transpose(final_state, [0, 1, 3, 2])), (w, q, k, v, a, b, s, sa)
+
+     def _wkv7_bwd_kernel(w, q, k, v, a, b, dy, s, sa, dht):
+         dh0_type = jax.ShapeDtypeStruct(dht.shape, dht.dtype)
+         dw_type = jax.ShapeDtypeStruct(w.shape, w.dtype)
+         dq_type = jax.ShapeDtypeStruct(q.shape, q.dtype)
+         dk_type = jax.ShapeDtypeStruct(k.shape, k.dtype)
+         dv_type = jax.ShapeDtypeStruct(v.shape, v.dtype)
+         da_type = jax.ShapeDtypeStruct(a.shape, a.dtype)
+         db_type = jax.ShapeDtypeStruct(b.shape, b.dtype)
+
+         dh0, dw, dq, dk, dv, da, db = jax.ffi.ffi_call(
+             "wkv7_bwd",
+             (dh0_type, dw_type, dq_type, dk_type, dv_type, da_type, db_type),
+             vmap_method="broadcast_all",
+         )(w, q, k, v, a, b, dy, s, sa, dht)
+
+         return dw, dq, dk, dv, da, db, dh0
+
+     # backward rule
+     def _bwd(res, grads):
+         w, q, k, v, a, b, s, sa = res
+         dy, dht = grads
+         dy = jnp.asarray(dy, jnp.bfloat16)
+         # call the backward kernel
+         return _wkv7_bwd_kernel(w, q, k, v, a, b, dy, s, sa, dht)
+
+     wkv7_kernel.defvjp(_fwd, _bwd)
+
+     def generalized_delta_rule(
+         r: jnp.ndarray,
+         w: jnp.ndarray,
+         k: jnp.ndarray,
+         v: jnp.ndarray,
+         a: jnp.ndarray,
+         b: jnp.ndarray,
+         initial_state: Optional[jnp.ndarray] = None,
+         output_final_state: bool = True,
+         head_first: bool = False,
+     ) -> Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:
+         """
+         Generalized delta rule; the interface matches the Torch implementation exactly.
+         Args:
+             r, w, k, v, a, b: input tensors of shape (B, T, H, K), or (B, H, T, K) when head_first=True
+             initial_state: optional (B, H, K, K) initial state; zero-initialized when None
+             output_final_state: whether to also return the last state
+             head_first: whether the head dimension comes before the time dimension
+         Note:
+             T must be divisible by the module constant CHUNK_LEN (16).
+         Returns:
+             out: (B, T, H, K), same dtype as the input
+             last_state: (B, H, K, K) when output_final_state=True
+         """
+         # normalize everything to (B, T, H, K)
+         dtype = r.dtype
+         r = _transpose_head(r, head_first)
+         w = _transpose_head(w, head_first)
+         k = _transpose_head(k, head_first)
+         v = _transpose_head(v, head_first)
+         a = _transpose_head(a, head_first)
+         b = _transpose_head(b, head_first)
+
+         B, T, H, K = r.shape
+         if T % CHUNK_LEN:
+             raise ValueError(
+                 f"Sequence length T={T} must be divisible by chunk_len={CHUNK_LEN}"
+             )
+
+         # handle the initial state
+         if initial_state is None:
+             h0 = jnp.zeros((B, H, K, K), jnp.float32)
+         else:
+             h0 = jnp.asarray(initial_state, jnp.float32)
+
+         # invoke the kernel
+
+         out, last_state = jax.checkpoint(
+             wkv7_kernel, policy=cp.save_anything_except_these_names(())
+         )(w, r, k, v, a, b, h0)
+         out = jnp.asarray(out, dtype)  # keep the output dtype equal to the input dtype
+
+         if output_final_state:
+             return out, last_state
+         return out
+
+     def _wkv7_inference_kernel(
+         w: jnp.ndarray,
+         q: jnp.ndarray,
+         k: jnp.ndarray,
+         v: jnp.ndarray,
+         a: jnp.ndarray,
+         b: jnp.ndarray,
+         h0: jnp.ndarray,
+     ):
+         """
+         Inference-only kernel; does not save sa or the intermediate s.
+         Returns: y (B, T, H, K), final_state (B, H, K, K)
+         """
+         B, T, H, K = q.shape
+         dtype = q.dtype
+         out_type = jax.ShapeDtypeStruct((B, T, H, K), dtype)
+         # key point: only the final state is returned, not the per-chunk history
+         s_type = jax.ShapeDtypeStruct((B, H, K, K), jnp.float32)
+
+         y, s = jax.ffi.ffi_call(
+             "wkv7_inference", (out_type, s_type), vmap_method="broadcast_all"
+         )(w, q, k, v, a, b, h0)
+
+         return y, s
+
+     # -------------------- public inference API --------------------
+     def generalized_delta_rule_inference(
+         r: jnp.ndarray,
+         w: jnp.ndarray,
+         k: jnp.ndarray,
+         v: jnp.ndarray,
+         a: jnp.ndarray,
+         b: jnp.ndarray,
+         output_final_state: bool = True,
+         initial_state: Optional[jnp.ndarray] = None,
+         head_first: bool = False,
+     ) -> Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:
+         """
+         Inference-only version of the generalized delta rule.
+
+         Args:
+             r, w, k, v, a, b: input tensors of shape (B, T, H, K) or (B, H, T, K)
+             initial_state: (B, H, K, K) initial state; zero-initialized when None
+             head_first: whether the head dimension comes before the time dimension
+         Returns:
+             out: (B, T, H, K) output, same dtype as the input
+             final_state: (B, H, K, K), the final state only
+         """
+         dtype = r.dtype
+         r = _transpose_head(r, head_first)
+         w = _transpose_head(w, head_first)
+         k = _transpose_head(k, head_first)
+         v = _transpose_head(v, head_first)
+         a = _transpose_head(a, head_first)
+         b = _transpose_head(b, head_first)
+
+         B, T, H, K = r.shape
+
+         # handle the initial state
+         if initial_state is None:
+             h0 = jnp.zeros((B, H, K, K), jnp.float32)
+         else:
+             h0 = jnp.asarray(initial_state, jnp.float32)
+
+         # no checkpoint needed: inference keeps no intermediate values
+         out, final_state = _wkv7_inference_kernel(w, r, k, v, a, b, h0)
+         out = jnp.asarray(out, dtype)
+         if output_final_state:
+             return out, final_state
+         return out
+
+     # return both functions; callers pick the one they need
+     return [generalized_delta_rule, generalized_delta_rule_inference]
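
A minimal usage sketch of the two entry points returned above; the import path, shapes, and random inputs are illustrative assumptions, not part of the package docs (a CUDA device, nvcc and cmake on PATH, and JAX >= 0.4.31 are required):

    import jax
    import jax.numpy as jnp
    from rwkv_ops.rwkv7_kernel.jax_cuda_kernel.wkv7_jax import get_jax_generalized_delta_rule

    # The first call triggers the CMake build and the FFI registration for HEAD_SIZE=64.
    generalized_delta_rule, generalized_delta_rule_inference = get_jax_generalized_delta_rule(64)

    B, T, H, K = 1, 32, 2, 64  # T must be a multiple of CHUNK_LEN (16)
    keys = jax.random.split(jax.random.PRNGKey(0), 6)
    r, w, k, v, a, b = (jax.random.normal(key, (B, T, H, K), jnp.bfloat16) for key in keys)

    out, last_state = generalized_delta_rule(r, w, k, v, a, b)           # training path
    y, final_state = generalized_delta_rule_inference(r, w, k, v, a, b)  # inference path

    # Gradients flow through the custom VJP, which dispatches to "wkv7_bwd".
    def loss(r):
        out, _ = generalized_delta_rule(r, w, k, v, a, b)
        return jnp.sum(jnp.asarray(out, jnp.float32) ** 2)

    dr = jax.grad(loss)(r)

The random tensors here only exercise shapes and dispatch; real RWKV7 inputs satisfy additional constraints on a and b.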
rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/CMakeLists.txt
@@ -0,0 +1,42 @@
+ cmake_minimum_required(VERSION 3.18)
+ project(wkv7_single_step LANGUAGES CXX CUDA)
+
+ find_package(CUDAToolkit REQUIRED)
+
+ # ---------- 1. Find Python ----------
+ find_package(Python3 REQUIRED COMPONENTS Interpreter)
+
+ # ---------- 2. Get the XLA header path ----------
+ execute_process(
+     COMMAND "${Python3_EXECUTABLE}" -c "import jax; print(jax.ffi.include_dir())"
+     OUTPUT_VARIABLE XLA_INCLUDE_DIR
+     OUTPUT_STRIP_TRAILING_WHITESPACE
+ )
+ if(NOT XLA_INCLUDE_DIR)
+     message(FATAL_ERROR "Cannot get XLA include dir from jax.ffi")
+ endif()
+ message(STATUS "XLA include directory: ${XLA_INCLUDE_DIR}")
+
+ # ---------- 3. Build the shared library ----------
+ add_library(wkv7_single_step SHARED wkv7_single_step_ffi.cu)
+
+ # 3-1. Header search path
+ target_include_directories(wkv7_single_step PRIVATE ${XLA_INCLUDE_DIR})
+
+ # 3-2. Link the CUDA runtime
+ target_link_libraries(wkv7_single_step PRIVATE CUDA::cudart)
+
+ # 3-3. Key settings: C++17 / CUDA standard 17
+ target_compile_features(wkv7_single_step PUBLIC cxx_std_17)
+ set_target_properties(wkv7_single_step PROPERTIES
+     CUDA_STANDARD 17
+     CUDA_SEPARABLE_COMPILATION ON
+     POSITION_INDEPENDENT_CODE ON
+     PREFIX ""  # drop the default "lib" prefix
+ )
+
+ # ---------- 4. Install ----------
+ # Install the .so straight into the source directory so ctypes.CDLL can load it.
+ install(TARGETS wkv7_single_step
+     LIBRARY DESTINATION "${CMAKE_SOURCE_DIR}"
+     RUNTIME DESTINATION "${CMAKE_SOURCE_DIR}")  # RUNTIME covers Windows
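
For reference, the same build can be driven by hand in the style of _ensure_compiled() from wkv7_jax.py; the paths and the -D_C_ value below are assumptions (this CMakeLists queries jax for the XLA include directory on its own, so nothing else needs to be passed):

    import pathlib
    import subprocess

    src = pathlib.Path("rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single")  # assumed checkout location
    build = src / "build"
    subprocess.check_call([
        "cmake", "-S", str(src), "-B", str(build),
        "-DCMAKE_BUILD_TYPE=Release",
        "-DCMAKE_CUDA_FLAGS=-D_C_=64",  # wkv7_single_step_ffi.cu expects _C_ from the compile flags
    ])
    subprocess.check_call(["cmake", "--build", str(build), "-j"])
    subprocess.check_call(["cmake", "--install", str(build)])  # installs wkv7_single_step.so beside the sources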
rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_ffi.cu
@@ -0,0 +1,172 @@
+ #include <cuda_bf16.h>
+ #include <cuda_runtime.h>
+ #include <xla/ffi/api/ffi.h>
+ #include <string>
+ #include <vector>
+ #include <cstdint>
+
+ namespace ffi = xla::ffi;
+ using bf = __nv_bfloat16;
+
+ /* -------------------- device-side helpers -------------------- */
+ __device__ inline float to_float(const bf &u) {
+     return __bfloat162float(u);
+ }
+ __device__ inline bf to_bf(const float &u) {
+     return __float2bfloat16_rn(u);
+ }
+ typedef bf *__restrict__ F_;
+
+ /* -------------------- forward kernel -------------------- */
+ template<int C>
+ __launch_bounds__(C, 2)
+ __global__ void forward_kernel_single_step(
+     int B, int H,
+     F_ w_, F_ q_, F_ k_, F_ v_, F_ a_, F_ b_,
+     bf *y_, float *s_, float *h0_)
+ {
+     int bb = blockIdx.y, hh = blockIdx.x, i = threadIdx.x;
+     float state[C] = {0};
+     __shared__ float q[C], k[C], w[C], a[C], b[C];
+
+     // load the initial state (B, H, C, C); thread i owns row i
+     int64_t h0_base = ((int64_t)bb * H + hh) * C * C + i * C;
+     #pragma unroll
+     for (int j = 0; j < C; ++j) state[j] = h0_[h0_base + j];
+
+     // single-step index into the (B, H, C) inputs
+     int64_t ind = (int64_t)bb * H * C + hh * C + i;
+
+     __syncthreads();
+     q[i] = to_float(q_[ind]);
+     w[i] = __expf(-__expf(to_float(w_[ind])));
+     k[i] = to_float(k_[ind]);
+     a[i] = to_float(a_[ind]);
+     b[i] = to_float(b_[ind]);
+     __syncthreads();
+
+     float sa = 0.f;
+     #pragma unroll
+     for (int j = 0; j < C; ++j) sa += a[j] * state[j];
+
+     float v_val = to_float(v_[ind]);
+     float y = 0.f;
+     #pragma unroll
+     for (int j = 0; j < C; ++j) {
+         float &s = state[j];
+         s = s * w[j] + sa * b[j] + k[j] * v_val;
+         y += s * q[j];
+     }
+     y_[ind] = to_bf(y);
+
+     // write back the final state
+     int64_t s_base = ((int64_t)bb * H + hh) * C * C + i * C;
+     #pragma unroll
+     for (int j = 0; j < C; ++j) s_[s_base + j] = state[j];
+ }
+
+ /* -------------------- backward kernel -------------------- */
+ // NOTE: da_ and db_ are accepted but never written by this kernel, and
+ // da_val/db_val below are never accumulated; no host binding for the
+ // backward pass is defined in this file.
+ template<int C>
+ __launch_bounds__(C, 2)
+ __global__ void backward_kernel_single_step(
+     int B, int H,
+     F_ w_, F_ q_, F_ k_, F_ v_, F_ dy_,
+     float *s_, float *dht_, bf *dw_, bf *dq_, bf *dk_, bf *dv_, bf *da_, bf *db_)
+ {
+     int bb = blockIdx.y, hh = blockIdx.x, i = threadIdx.x;
+     float stateT[C] = {0}, dstate[C] = {0};
+
+     int64_t dht_base = ((int64_t)bb * H + hh) * C * C + i * C;
+     #pragma unroll
+     for (int j = 0; j < C; ++j) dstate[j] = dht_[dht_base + j];
+
+     __shared__ float w[C], q[C], k[C], v[C], dy[C];
+     int64_t ind = (int64_t)bb * H * C + hh * C + i;
+
+     __syncthreads();
+     q[i] = to_float(q_[ind]);
+     float wi_fac = -__expf(to_float(w_[ind]));
+     w[i] = __expf(wi_fac);
+     k[i] = to_float(k_[ind]);
+     v[i] = to_float(v_[ind]);
+     dy[i] = to_float(dy_[ind]);
+     __syncthreads();
+
+     // load stateT from s_ (a float4 optimization could be added here)
+     int64_t s_base = ((int64_t)bb * H + hh) * C * C + i * C;
+     #pragma unroll
+     for (int j = 0; j < C; ++j) stateT[j] = s_[s_base + j];
+
+     float dq_val = 0.f, dw_val = 0.f, dk_val = 0.f, dv_val = 0.f, da_val = 0.f, db_val = 0.f;
+     float iwi = 1.0f / (w[i] + 1e-6f);
+
+     #pragma unroll
+     for (int j = 0; j < C; ++j) {
+         // reconstruct the previous state; the sa * b rank-one term of the
+         // forward update is not subtracted here
+         stateT[j] = (stateT[j] - k[i] * v[j]) * iwi;
+         dstate[j] += dy[i] * q[j];
+
+         dq_val += stateT[j] * dy[j];
+         dw_val += dstate[j] * stateT[j];
+         dk_val += dstate[j] * v[j];
+         dv_val += dstate[j] * k[j];
+     }
+
+     dq_[ind] = to_bf(dq_val);
+     dw_[ind] = to_bf(dw_val * w[i] * wi_fac);
+     dk_[ind] = to_bf(dk_val);
+     dv_[ind] = to_bf(dv_val);
+ }
+
+ /* -------------------- host functions -------------------- */
+ static ffi::Error WKV7SingleStepFwdHost(
+     cudaStream_t stream,
+     ffi::Buffer<ffi::BF16> w,
+     ffi::Buffer<ffi::BF16> q,
+     ffi::Buffer<ffi::BF16> k,
+     ffi::Buffer<ffi::BF16> v,
+     ffi::Buffer<ffi::BF16> a,
+     ffi::Buffer<ffi::BF16> b,
+     ffi::Buffer<ffi::F32> h0,
+     ffi::ResultBuffer<ffi::BF16> y,
+     ffi::ResultBuffer<ffi::F32> s)
+ {
+     auto dims = w.dimensions();
+     int B = dims[0], H = dims[1];
+     constexpr int C = _C_;  // taken from the compile flags
+     dim3 block(C);
+     dim3 grid(H, B);
+
+     // launch with the template parameter fixed to the compile-time head size
+     forward_kernel_single_step<_C_><<<grid, block, 0, stream>>>(
+         B, H,
+         reinterpret_cast<bf *>(w.typed_data()),
+         reinterpret_cast<bf *>(q.typed_data()),
+         reinterpret_cast<bf *>(k.typed_data()),
+         reinterpret_cast<bf *>(v.typed_data()),
+         reinterpret_cast<bf *>(a.typed_data()),
+         reinterpret_cast<bf *>(b.typed_data()),
+         reinterpret_cast<bf *>(y->typed_data()),
+         s->typed_data(),
+         h0.typed_data());
+
+     cudaError_t err = cudaGetLastError();
+     if (err != cudaSuccess)
+         return ffi::Error::Internal(
+             std::string("CUDA forward_kernel_single_step error: ") + cudaGetErrorString(err));
+     return ffi::Error::Success();
+ }
+
+ /* -------------------- FFI symbol registration -------------------- */
+ XLA_FFI_DEFINE_HANDLER_SYMBOL(
+     Wkv7SingleStepFwd, WKV7SingleStepFwdHost,
+     ffi::Ffi::Bind()
+         .Ctx<ffi::PlatformStream<cudaStream_t>>()
+         .Arg<ffi::Buffer<ffi::BF16>>()  // w
+         .Arg<ffi::Buffer<ffi::BF16>>()  // q
+         .Arg<ffi::Buffer<ffi::BF16>>()  // k
+         .Arg<ffi::Buffer<ffi::BF16>>()  // v
+         .Arg<ffi::Buffer<ffi::BF16>>()  // a
+         .Arg<ffi::Buffer<ffi::BF16>>()  // b
+         .Arg<ffi::Buffer<ffi::F32>>()   // h0
+         .Ret<ffi::Buffer<ffi::BF16>>()  // y
+         .Ret<ffi::Buffer<ffi::F32>>()   // s
+     , {ffi::Traits::kCmdBufferCompatible});
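
Per (batch, head), the forward kernel computes a diagonal-decay state update with a rank-one correction: S_new = S_old * diag(w) + outer(S_old @ a, b) + outer(v, k), then y = S_new @ q, with w already mapped through exp(-exp(.)). A pure-JAX reference of that single step, a sketch assuming the row-major (B, H, C, C) state layout used above, can serve as a numerical cross-check:

    import jax.numpy as jnp

    def wkv7_single_step_reference(w, q, k, v, a, b, h0):
        # w, q, k, v, a, b: (B, H, C); h0: (B, H, C, C); float32 throughout.
        w = jnp.exp(-jnp.exp(w))                    # decay, matching __expf(-__expf(w)) in the kernel
        sa = jnp.einsum("bhij,bhj->bhi", h0, a)     # sa_i = sum_j a_j * S_old[i, j]
        s = (h0 * w[:, :, None, :]                  # S_old[i, j] * w_j
             + sa[..., None] * b[:, :, None, :]     # + sa_i * b_j
             + v[..., None] * k[:, :, None, :])     # + v_i * k_j
        y = jnp.einsum("bhij,bhj->bhi", s, q)       # y_i = sum_j S_new[i, j] * q_j
        return y, s

Binding and invoking the compiled Wkv7SingleStepFwd symbol from JAX would follow the same jax.ffi.register_ffi_target / jax.ffi.ffi_call pattern used in wkv7_jax.py above.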