rwkv-ops 0.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. rwkv_ops/__init__.py +45 -0
  2. rwkv_ops/mhc_kernel/__init__.py +50 -0
  3. rwkv_ops/mhc_kernel/common_kernel/include/mhc_types.h +66 -0
  4. rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_post_op.cuh +197 -0
  5. rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_pre_op.cuh +212 -0
  6. rwkv_ops/mhc_kernel/common_kernel/kernels/rmsnorm.cuh +152 -0
  7. rwkv_ops/mhc_kernel/common_kernel/kernels/sinkhorn_knopp.cuh +158 -0
  8. rwkv_ops/mhc_kernel/common_kernel/kernels/stream_aggregate.cuh +141 -0
  9. rwkv_ops/mhc_kernel/common_kernel/kernels/stream_distribute.cuh +111 -0
  10. rwkv_ops/mhc_kernel/common_kernel/kernels/stream_mix.cuh +164 -0
  11. rwkv_ops/mhc_kernel/common_kernel/kernels/type_conversions.cuh +52 -0
  12. rwkv_ops/mhc_kernel/jax_kernel/CMakeLists.txt +47 -0
  13. rwkv_ops/mhc_kernel/jax_kernel/mhu_ffi.cu +652 -0
  14. rwkv_ops/mhc_kernel/jax_kernel/mhu_jax.py +939 -0
  15. rwkv_ops/mhc_kernel/native_keras_op.py +193 -0
  16. rwkv_ops/mhc_kernel/torch_kernel/mhc_cuda.cu +207 -0
  17. rwkv_ops/mhc_kernel/torch_kernel/mhc_op.cpp +296 -0
  18. rwkv_ops/mhc_kernel/torch_kernel/mhc_torch.py +306 -0
  19. rwkv_ops/rwkv6_kernel/__init__.py +120 -0
  20. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/gpu_ops.cpp +44 -0
  21. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernel_helpers.h +64 -0
  22. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernels.h +56 -0
  23. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/pybind11_kernel_helpers.h +41 -0
  24. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/rwkv_kernels.cu +512 -0
  25. rwkv_ops/rwkv6_kernel/jax_kernel_hip/gpu_ops.cpp +44 -0
  26. rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernel_helpers.h +64 -0
  27. rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernels.h +56 -0
  28. rwkv_ops/rwkv6_kernel/jax_kernel_hip/pybind11_kernel_helpers.h +41 -0
  29. rwkv_ops/rwkv6_kernel/jax_kernel_hip/rwkv_kernels.hip +514 -0
  30. rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py +722 -0
  31. rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py +90 -0
  32. rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_cuda.cu +397 -0
  33. rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_op.cpp +93 -0
  34. rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py +305 -0
  35. rwkv_ops/rwkv7_kernel/__init__.py +113 -0
  36. rwkv_ops/rwkv7_kernel/get_jax_devices_info.py +220 -0
  37. rwkv_ops/rwkv7_kernel/get_torch_devices_info.py +250 -0
  38. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/CMakeLists.txt +42 -0
  39. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_ffi.cu +399 -0
  40. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_jax.py +311 -0
  41. rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/CMakeLists.txt +42 -0
  42. rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_ffi.cu +172 -0
  43. rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_jax.py +190 -0
  44. rwkv_ops/rwkv7_kernel/jax_kernel/__init__.py +9 -0
  45. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_bwd.py +95 -0
  46. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_fwd.py +60 -0
  47. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_bwd.py +78 -0
  48. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_fwd.py +80 -0
  49. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_bwd.py +150 -0
  50. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_fwd.py +45 -0
  51. rwkv_ops/rwkv7_kernel/jax_kernel/cumsum.py +34 -0
  52. rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_bwd.py +61 -0
  53. rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_fwd.py +86 -0
  54. rwkv_ops/rwkv7_kernel/jax_op.py +382 -0
  55. rwkv_ops/rwkv7_kernel/mlx_op.py +118 -0
  56. rwkv_ops/rwkv7_kernel/native_keras_op.py +108 -0
  57. rwkv_ops/rwkv7_kernel/tf_eager_kernel.py +155 -0
  58. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_cuda.cu +235 -0
  59. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_op.cpp +63 -0
  60. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_torch.py +233 -0
  61. rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_cuda.cu +101 -0
  62. rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_op.cpp +56 -0
  63. rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_torch.py +112 -0
  64. rwkv_ops/rwkv7_kernel/torch_kernel/__init__.py +13 -0
  65. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_bwd.py +96 -0
  66. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_fwd.py +64 -0
  67. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_bwd.py +74 -0
  68. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_fwd.py +75 -0
  69. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_bwd.py +148 -0
  70. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_fwd.py +44 -0
  71. rwkv_ops/rwkv7_kernel/torch_kernel/cumsum.py +31 -0
  72. rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_bwd.py +63 -0
  73. rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_fwd.py +79 -0
  74. rwkv_ops/rwkv7_kernel/torch_op.py +504 -0
  75. rwkv_ops/rwkv7_kernel/triton_kernel/__init__.py +34 -0
  76. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_bwd.py +328 -0
  77. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_fwd.py +186 -0
  78. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_bwd.py +157 -0
  79. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_fwd.py +160 -0
  80. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_bwd.py +382 -0
  81. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_fwd.py +137 -0
  82. rwkv_ops/rwkv7_kernel/triton_kernel/cumsum.py +86 -0
  83. rwkv_ops/rwkv7_kernel/triton_kernel/utils.py +20 -0
  84. rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_bwd.py +193 -0
  85. rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_fwd.py +326 -0
  86. rwkv_ops-0.6.1.dist-info/METADATA +495 -0
  87. rwkv_ops-0.6.1.dist-info/RECORD +89 -0
  88. rwkv_ops-0.6.1.dist-info/WHEEL +4 -0
  89. rwkv_ops-0.6.1.dist-info/licenses/LICENSE.txt +201 -0
@@ -0,0 +1,250 @@
+ import functools
+ import os
+ from functools import lru_cache
+ from typing import Literal
+
+ import triton
+ from packaging import version
+ import torch
+ from enum import Enum
+ import contextlib
+
+
+ @lru_cache(maxsize=None)
+ def get_multiprocessor_count(tensor_idx: int = 0) -> int:
+     return triton.runtime.driver.active.utils.get_device_properties(tensor_idx)[
+         "multiprocessor_count"
+     ]
+
+
+ @lru_cache(maxsize=None)
+ def get_available_device() -> str:
+     try:
+         return triton.runtime.driver.active.get_current_target().backend
+     except BaseException:
+         import warnings
+
+         warnings.warn(
+             ("Triton is not supported on current platform, roll back to CPU."),
+             stacklevel=1,
+         )
+         return "cpu"
+
+
+ @lru_cache(maxsize=None)
+ def _check_platform() -> Literal["nvidia", "amd", "intel", "musa"]:
+     device = get_available_device()
+     if device == "cuda":
+         return "nvidia"
+     elif device == "hip":
+         return "amd"
+     elif device == "xpu":
+         return "intel"
+     else:
+         return device
+
+
+ # For AMD GPUs, the triton backend is 'hip', while for Nvidia GPUs, the triton backend is 'cuda'.
+ # However, the torch backend is 'cuda' for both Nvidia and AMD GPUs.
+ # Therefore, we need to check the triton backend to determine the actual GPU vendor.
+ device = get_available_device() if get_available_device() != "hip" else "cuda"
+
+ device_platform = _check_platform()
+
+ is_intel = device_platform == "intel"
+ is_nvidia = device_platform == "nvidia"
+ is_amd = device_platform == "amd"
+
+ use_cuda_graph = is_nvidia and os.environ.get("FLA_USE_CUDA_GRAPH", "0") == "1"
+
+
+ @lru_cache(maxsize=None)
+ def check_pytorch_version(version_s: str = "2.4") -> bool:
+     return version.parse(torch.__version__) >= version.parse(version_s)
+
+
+ is_intel_a770 = is_intel and "Intel(R) Arc(TM) A" in torch.xpu.get_device_name(0)
+ device = get_available_device() if get_available_device() != "hip" else "cuda"
+ device_torch_lib = getattr(torch, device)
+ if check_pytorch_version("2.4"):
+     device = "cuda" if device == "cpu" else device
+     autocast_custom_fwd = functools.partial(torch.amp.custom_fwd, device_type=device)
+     autocast_custom_bwd = functools.partial(torch.amp.custom_bwd, device_type=device)
+
+     def custom_device_ctx(index: int):
+         return device_torch_lib.device(index)
+ else:
+     assert device == "cuda", (
+         "Only cuda device is supported for PyTorch version < 2.4.0."
+     )
+     autocast_custom_fwd = device_torch_lib.amp.custom_fwd
+     autocast_custom_bwd = device_torch_lib.amp.custom_bwd
+
+     def custom_device_ctx(index: int):
+         return torch.cuda.device(index)
+
+
+ # Nvidia Ampere or newer; AMD and Intel have not been checked yet.
+ is_tf32_supported = is_nvidia and torch.cuda.get_device_capability(0)[0] >= 8
+
+
+ def get_all_max_shared_memory():
+     return [
+         triton.runtime.driver.active.utils.get_device_properties(i)["max_shared_mem"]
+         for i in range(device_torch_lib.device_count())
+     ]
+
+
+ device_shared_mem_list = get_all_max_shared_memory()
+
+
+ @lru_cache(maxsize=None)
+ def is_triton_shared_mem_enough(
+     max_shared_mem: int = 102400, tensor_idx: int = 0
+ ) -> bool:
+     max_shared_memory = device_shared_mem_list[tensor_idx]
+     return max_shared_memory >= max_shared_mem
+
+
+ device_capacity = is_triton_shared_mem_enough()
+
+
+ def _cpu_device_warning():
+     import warnings
+
+     warnings.warn(
+         ("Triton is not supported on current platform, roll back to CPU."), stacklevel=1
+     )
+
+
+ def get_all_max_shared_mem():
+     try:
+         return [
+             triton.runtime.driver.active.utils.get_device_properties(i)[
+                 "max_shared_mem"
+             ]
+             for i in range(device_torch_lib.device_count())
+         ]
+     except BaseException:
+         _cpu_device_warning()
+         return [-1]
+
+
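+ # Maximum opt-in shared memory per thread block, in bytes, for common GPU architectures
+ # (99 KB on Ada/RTX 4090, 163 KB on Ampere/A100, 227 KB on Hopper/H100).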
+ class Backend(Enum):
+     ADA = 101376  # RTX 4090
+     AMPERE = 166912  # A100
+     HOPPER = 232448  # H100
+     DEFAULT = 102400  # Default
+
+     @classmethod
+     def get_shared_memory(cls, arch: str) -> int:
+         try:
+             return cls[arch.upper()].value
+         except KeyError:
+             return cls.DEFAULT.value
+
+
+ @lru_cache(maxsize=None)
+ def check_shared_mem(arch: str = "none", tensor_idx: int = 0) -> bool:
+     try:
+         device_shared_mem_list = get_all_max_shared_mem()
+         max_shared_memory = device_shared_mem_list[tensor_idx]
+         return max_shared_memory >= Backend.get_shared_memory(arch)
+     except Exception:
+         return False
+
+
+ def tensor_cache(fn):
+     """
+     A decorator that caches the most recent result of a function with tensor inputs.
+
+     This decorator will store the output of the decorated function for the most recent set of input tensors.
+     If the function is called again with the same input tensors, it will return the cached result.
+
+     Args:
+         fn (Callable[..., torch.Tensor]):
+             The function to be decorated. It should take tensor inputs and return tensor outputs.
+
+     Returns:
+         Callable[..., torch.Tensor]:
+             A wrapped version of the input function with single-entry caching.
+     """
+     last_args = None
+     last_kwargs = None
+     last_result = None
+
+     @functools.wraps(fn)
+     def wrapper(*args, **kwargs):
+         nonlocal last_args, last_kwargs, last_result
+
+         if last_args is not None and last_kwargs is not None:
+             if len(args) == len(last_args) and len(kwargs) == len(last_kwargs):
+                 if all(a is b for a, b in zip(args, last_args)) and all(
+                     k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()
+                 ):
+                     return last_result
+
+         result = fn(*args, **kwargs)
+         last_args, last_kwargs, last_result = args, kwargs, result
+         return result
+
+     return wrapper
+
+
+ @tensor_cache
+ def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
+     return cu_seqlens[1:] - cu_seqlens[:-1]
+
+
+ @tensor_cache
+ def prepare_chunk_indices(
+     cu_seqlens: torch.LongTensor, chunk_size: int
+ ) -> torch.LongTensor:
+     indices = torch.cat(
+         [
+             torch.arange(n)
+             for n in triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()
+         ]
+     )
+     return torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(cu_seqlens)
+
+
+ def input_guard(fn):
+     """
+     A decorator to make sure all input tensors are contiguous and set the device based on input tensors.
+     """
+
+     @functools.wraps(fn)
+     def wrapper(*args, **kwargs):
+         contiguous_args = (
+             i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args
+         )
+         contiguous_kwargs = {
+             k: (v if not isinstance(v, torch.Tensor) else v.contiguous())
+             for k, v in kwargs.items()
+         }
+
+         tensor = None
+         for arg in args:
+             if isinstance(arg, torch.Tensor):
+                 tensor = arg
+                 break
+         if tensor is None:
+             for value in kwargs.values():
+                 if isinstance(value, torch.Tensor):
+                     tensor = value
+                     break
+
+         if tensor is not None:
+             ctx = custom_device_ctx(tensor.device.index)
+         else:
+             ctx = contextlib.nullcontext()
+
+         with ctx:
+             return fn(*contiguous_args, **contiguous_kwargs)
+
+     return wrapper
+
+
+ is_intel_alchemist = is_intel and "Intel(R) Arc(TM) A" in torch.xpu.get_device_name(0)
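
A quick usage sketch for the helpers defined above (the tensor values and the decorated function are illustrative, and a CUDA device is assumed):

    import torch

    # cumulative sequence lengths for a packed batch of two sequences (lengths 5 and 7)
    cu_seqlens = torch.tensor([0, 5, 12], dtype=torch.long, device="cuda")
    lens = prepare_lens(cu_seqlens)                   # tensor([5, 7])
    chunk_idx = prepare_chunk_indices(cu_seqlens, 4)  # rows of (sequence id, chunk id), cached by @tensor_cache

    @input_guard
    def scale(x: torch.Tensor) -> torch.Tensor:
        # x arrives contiguous, with the device context matching x already active
        return x * 2.0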
@@ -0,0 +1,42 @@
+ cmake_minimum_required(VERSION 3.18)
+ project(wkv7 LANGUAGES CXX CUDA)
+
+ find_package(CUDAToolkit REQUIRED)
+
+ # ---------- 1. Find Python ----------
+ find_package(Python3 REQUIRED COMPONENTS Interpreter)
+
+ # ---------- 2. Get the XLA header include path ----------
+ execute_process(
+     COMMAND "${Python3_EXECUTABLE}" -c "from jax import ffi; print(ffi.include_dir())"
+     OUTPUT_VARIABLE XLA_INCLUDE_DIR
+     OUTPUT_STRIP_TRAILING_WHITESPACE
+ )
+ if(NOT XLA_INCLUDE_DIR)
+     message(FATAL_ERROR "Cannot get XLA include dir from jax.ffi")
+ endif()
+ message(STATUS "XLA include directory: ${XLA_INCLUDE_DIR}")
+
+ # ---------- 3. Build the shared library ----------
+ add_library(wkv7 SHARED wkv7_ffi.cu)
+
+ # 3-1. Header search paths
+ target_include_directories(wkv7 PRIVATE ${XLA_INCLUDE_DIR})
+
+ # 3-2. Link against the CUDA runtime
+ target_link_libraries(wkv7 PRIVATE CUDA::cudart)
+
+ # 3-3. Important: C++17 / CUDA 17 standards
+ target_compile_features(wkv7 PUBLIC cxx_std_17)
+ set_target_properties(wkv7 PROPERTIES
+     CUDA_STANDARD 17
+     CUDA_SEPARABLE_COMPILATION ON
+     POSITION_INDEPENDENT_CODE ON
+     PREFIX ""  # drop the default "lib" prefix
+ )
+
+ # ---------- 4. Install ----------
+ # Install the .so directly into the source directory (next to wkv7_jax.py) so ctypes.CDLL can load it
+ install(TARGETS wkv7
+     LIBRARY DESTINATION "${CMAKE_SOURCE_DIR}"
+     RUNTIME DESTINATION "${CMAKE_SOURCE_DIR}")  # RUNTIME is used on Windows
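
The install rule above drops the built library next to wkv7_jax.py so the Python wrapper can load it with ctypes. A minimal sketch of that loading step, assuming the default target name (wkv7.so) and the usual out-of-source workflow (cmake -S . -B build, cmake --build build, cmake --install build):

    import ctypes
    import os

    # load the shared library from the directory of the current file
    # (the file name "wkv7.so" is assumed from the CMake target above)
    _here = os.path.dirname(os.path.abspath(__file__))
    _lib = ctypes.CDLL(os.path.join(_here, "wkv7.so"))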
@@ -0,0 +1,399 @@
+ #include <cuda_bf16.h>
+ #include <cuda_runtime.h>
+ #include <xla/ffi/api/ffi.h>
+ #include <vector>
+ #include <cstdint>
+ // ref link: https://github.com/BlinkDL/RWKV-CUDA/tree/main/rwkv7_fast_fused
+ namespace ffi = xla::ffi;
+
+ /* -------------------- Type aliases -------------------- */
+ using bf = __nv_bfloat16;
+
+ /* -------------------- Device-side helpers -------------------- */
+ __device__ inline float to_float(const bf &u) {
+     return __bfloat162float(u);
+ }
+ __device__ inline bf to_bf(const float &u) {
+     return __float2bfloat16_rn(u);
+ }
+ typedef bf *__restrict__ F_;
+
+ /* -------------------- Kernel -------------------- */
+ // [Optimization 1] Templated + launch_bounds to improve occupancy
+ template<int C> __launch_bounds__(C, 2)
+ __global__ void forward_kernel(int T, int H,
+                                F_ w_, F_ q_, F_ k_, F_ v_, F_ a_, F_ b_,
+                                bf *y_, float *s_, float *sa_, float *h0_) {
+     int bb = blockIdx.y, hh = blockIdx.x, i = threadIdx.x;
+     float state[C] = {0};
+     __shared__ float q[C], k[C], w[C], a[C], b[C];
+
+     int64_t h0_base = ((int64_t)bb * H + hh) * C * C + i * C;
+
+     #pragma unroll
+     for (int j = 0; j < C; ++j) state[j] = h0_[h0_base + j];
+
+     for (int t = 0; t < T; ++t) {
+         // [Optimization 2] Force int64_t indexing to prevent overflow
+         int64_t ind = (int64_t)bb * T * H * C + (int64_t)t * H * C + hh * C + i;
+
+         __syncthreads();
+         q[i] = to_float(q_[ind]);
+         w[i] = __expf(-__expf(to_float(w_[ind])));
+         k[i] = to_float(k_[ind]);
+         a[i] = to_float(a_[ind]);
+         b[i] = to_float(b_[ind]);
+         __syncthreads();
+
+         float sa = 0.f;
+         #pragma unroll
+         for (int j = 0; j < C; ++j) sa += a[j] * state[j];
+         sa_[ind] = sa;
+
+         float v_val = to_float(v_[ind]);
+         float y = 0.f;
+         #pragma unroll
+         for (int j = 0; j < C; ++j) {
+             float &s = state[j];
+             s = s * w[j] + sa * b[j] + k[j] * v_val;
+             y += s * q[j];
+         }
+         y_[ind] = to_bf(y);
+
+         if ((t + 1) % _CHUNK_LEN_ == 0) {
+             int64_t base = ((int64_t)bb * H + hh) * (T / _CHUNK_LEN_) * C * C +
+                            ((int64_t)t / _CHUNK_LEN_) * C * C + i;
+             #pragma unroll
+             for (int j = 0; j < C; ++j) s_[base + j * C] = state[j];
+         }
+     }
+ }
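+ // forward_kernel above computes, per head, the RWKV-7 recurrence
+ //   S_t = S_{t-1} * diag(w_t) + (S_{t-1} a_t) b_t^T + v_t k_t^T,   y_t = S_t q_t,
+ // with w_t = exp(-exp(w_raw_t)); each thread owns one row of the C x C state S, and every
+ // _CHUNK_LEN_ steps the state is checkpointed (transposed) into s_ for reuse in the backward pass.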
+
+ // [Optimization 3] Backward kernel: templated + launch_bounds + float4 vectorized loads
+ template<int C> __launch_bounds__(C, 2)
+ __global__ void backward_kernel(int T, int H,
+                                 F_ w_, F_ q_, F_ k_, F_ v_, F_ a_, F_ b_, F_ dy_,
+                                 float *s_, float *sa_, float *dht_, float *dh0_,
+                                 bf *dw_, bf *dq_, bf *dk_, bf *dv_, bf *da_, bf *db_) {
+     int bb = blockIdx.y, hh = blockIdx.x, i = threadIdx.x;
+     float stateT[C] = {0}, dstate[C] = {0}, dstateT[C] = {0};
+
+     int64_t dht_base = ((int64_t)bb * H + hh) * C * C + i * C;
+
+     #pragma unroll
+     for (int j = 0; j < C; ++j) {
+         dstate[j] = dht_[dht_base + j];
+         dstateT[j] = dht_[dht_base + j];
+     }
+     __shared__ float w[C], q[C], k[C], v[C], a[C], b[C], dy[C], sa[C], dSb_shared[C];
+     float qi, wi, ki, ai, bi, dyi;
+
+     for (int t = T - 1; t >= 0; --t) {
+         int64_t ind = (int64_t)bb * T * H * C + (int64_t)t * H * C + hh * C + i;
+
+         __syncthreads();
+         q[i] = qi = to_float(q_[ind]);
+         float wi_fac = -__expf(to_float(w_[ind]));
+         w[i] = wi = __expf(wi_fac);
+         k[i] = ki = to_float(k_[ind]);
+         a[i] = ai = to_float(a_[ind]);
+         b[i] = bi = to_float(b_[ind]);
+         v[i] = to_float(v_[ind]);
+         dy[i] = dyi = to_float(dy_[ind]);
+         sa[i] = sa_[ind];
+         __syncthreads();
+
+         if ((t + 1) % _CHUNK_LEN_ == 0) {
+             int64_t base = ((int64_t)bb * H + hh) * (T / _CHUNK_LEN_) * C * C +
+                            ((int64_t)t / _CHUNK_LEN_) * C * C + i * C;
+
+             // [Optimization 4] float4 vectorized loads, 4x better bandwidth utilization
+             const float4* s4 = (const float4*)(s_ + base);
+             #pragma unroll
+             for (int j4 = 0; j4 < C / 4; ++j4) {
+                 float4 q_vec = s4[j4];
+                 const int j = j4 * 4;
+                 stateT[j + 0] = q_vec.x;
+                 stateT[j + 1] = q_vec.y;
+                 stateT[j + 2] = q_vec.z;
+                 stateT[j + 3] = q_vec.w;
+             }
+         }
+
+         float dq_val = 0.f;
+         #pragma unroll
+         for (int j = 0; j < C; ++j) dq_val += stateT[j] * dy[j];
+         dq_[ind] = to_bf(dq_val);
+
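+         // Roll the state back one step: undo the forward update (subtract the k*v and sa*b
+         // rank-1 contributions for this column, then divide by w, with a small epsilon),
+         // and fold the incoming gradient dy_t into the running state gradients.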
+         float iwi = 1.f / (wi + 1e-6f);
+         #pragma unroll
+         for (int j = 0; j < C; ++j) {
+             stateT[j] = (stateT[j] - ki * v[j] - bi * sa[j]) * iwi;
+             dstate[j] += dyi * q[j];
+             dstateT[j] += qi * dy[j];
+         }
+
+         float dw = 0.f, dk = 0.f, dv = 0.f, db = 0.f, dSb = 0.f;
+         #pragma unroll
+         for (int j = 0; j < C; ++j) {
+             dw += dstateT[j] * stateT[j];
+             dk += dstateT[j] * v[j];
+             dv += dstate[j] * k[j];
+             dSb += dstate[j] * b[j];
+             db += dstateT[j] * sa[j];
+         }
+         dw_[ind] = to_bf(dw * wi * wi_fac);
+         dk_[ind] = to_bf(dk);
+         dv_[ind] = to_bf(dv);
+         db_[ind] = to_bf(db);
+
+         __syncthreads();
+         dSb_shared[i] = dSb;
+         __syncthreads();
+
+         float da = 0.f;
+         #pragma unroll
+         for (int j = 0; j < C; ++j) da += stateT[j] * dSb_shared[j];
+         da_[ind] = to_bf(da);
+
+         #pragma unroll
+         for (int j = 0; j < C; ++j) {
+             dstate[j] = dstate[j] * w[j] + dSb * a[j];
+             dstateT[j] = dstateT[j] * wi + ai * dSb_shared[j];
+             if (t == 0) dh0_[dht_base + j] = dstate[j];
+         }
+     }
+ }
+
+ /* -------------------- Inference-only kernel -------------------- */
+ template<int C> __launch_bounds__(C, 2)
+ __global__ void forward_inference_kernel(int T, int H,
+                                          F_ w_, F_ q_, F_ k_, F_ v_, F_ a_, F_ b_,
+                                          bf *y_, float *s_, float *h0_) {
+     int bb = blockIdx.y, hh = blockIdx.x, i = threadIdx.x;
+     float state[C] = {0};
+     __shared__ float q[C], k[C], w[C], a[C], b[C];
+
+     int64_t h0_base = ((int64_t)bb * H + hh) * C * C + i * C;
+
+     #pragma unroll
+     for (int j = 0; j < C; ++j) state[j] = h0_[h0_base + j];
+
+     for (int t = 0; t < T; ++t) {
+         int64_t ind = (int64_t)bb * T * H * C + (int64_t)t * H * C + hh * C + i;
+
+         __syncthreads();
+         q[i] = to_float(q_[ind]);
+         w[i] = __expf(-__expf(to_float(w_[ind])));
+         k[i] = to_float(k_[ind]);
+         a[i] = to_float(a_[ind]);
+         b[i] = to_float(b_[ind]);
+         __syncthreads();
+
+         float sa = 0.f;
+         #pragma unroll
+         for (int j = 0; j < C; ++j) sa += a[j] * state[j];
+
+         float v_val = to_float(v_[ind]);
+         float y = 0.f;
+         #pragma unroll
+         for (int j = 0; j < C; ++j) {
+             float &s = state[j];
+             s = s * w[j] + sa * b[j] + k[j] * v_val;
+             y += s * q[j];
+         }
+         y_[ind] = to_bf(y);
+     }
+
+     int64_t base = ((int64_t)bb * H + hh) * C * C + i * C;
+     #pragma unroll
+     for (int j = 0; j < C; ++j) s_[base + j] = state[j];
+ }
+
+ /* -------------------- Host functions (parameter names unified) -------------------- */
+ static ffi::Error WKV7FwdHost(
+     cudaStream_t stream,
+     ffi::Buffer<ffi::BF16> w,
+     ffi::Buffer<ffi::BF16> q,
+     ffi::Buffer<ffi::BF16> k,
+     ffi::Buffer<ffi::BF16> v,
+     ffi::Buffer<ffi::BF16> a,   // formerly 'z'; maps directly to the kernel's a_
+     ffi::Buffer<ffi::BF16> b,   // formerly 'a'; maps directly to the kernel's b_
+     ffi::Buffer<ffi::F32> h0,
+     ffi::ResultBuffer<ffi::BF16> y,
+     ffi::ResultBuffer<ffi::F32> s,
+     ffi::ResultBuffer<ffi::F32> sa)
+ {
+     constexpr int C = _C_;
+     auto dims = w.dimensions();
+     int B = dims[0], T = dims[1], H = dims[2];
+     dim3 block(C);
+     dim3 grid(H, B);
+
+     // [Key] Template-instantiated launch; arguments map directly onto the kernel parameters
+     forward_kernel<_C_><<<grid, block, 0, stream>>>(
+         T, H,
+         reinterpret_cast<bf *>(w.typed_data()),
+         reinterpret_cast<bf *>(q.typed_data()),
+         reinterpret_cast<bf *>(k.typed_data()),
+         reinterpret_cast<bf *>(v.typed_data()),
+         reinterpret_cast<bf *>(a.typed_data()),   // maps to a_
+         reinterpret_cast<bf *>(b.typed_data()),   // maps to b_
+         reinterpret_cast<bf *>(y->typed_data()),
+         s->typed_data(),
+         sa->typed_data(),
+         h0.typed_data());
+
+     cudaError_t err = cudaGetLastError();
+     if (err != cudaSuccess)
+         return ffi::Error::Internal(
+             std::string("CUDA forward_kernel error: ") + cudaGetErrorString(err));
+     return ffi::Error::Success();
+ }
+
+ static ffi::Error WKV7BwdHost(
+     cudaStream_t stream,
+     ffi::Buffer<ffi::BF16> w,
+     ffi::Buffer<ffi::BF16> q,
+     ffi::Buffer<ffi::BF16> k,
+     ffi::Buffer<ffi::BF16> v,
+     ffi::Buffer<ffi::BF16> a,   // formerly 'z'; maps directly to the kernel's a_
+     ffi::Buffer<ffi::BF16> b,   // formerly 'a'; maps directly to the kernel's b_
+     ffi::Buffer<ffi::BF16> dy,
+     ffi::Buffer<ffi::F32> s,
+     ffi::Buffer<ffi::F32> sa,
+     ffi::Buffer<ffi::F32> dht,
+     ffi::ResultBuffer<ffi::F32> dh0,
+     ffi::ResultBuffer<ffi::BF16> dw,
+     ffi::ResultBuffer<ffi::BF16> dq,
+     ffi::ResultBuffer<ffi::BF16> dk,
+     ffi::ResultBuffer<ffi::BF16> dv,
+     ffi::ResultBuffer<ffi::BF16> da,
+     ffi::ResultBuffer<ffi::BF16> db)
+ {
+     auto dims = w.dimensions();
+     int B = dims[0], T = dims[1], H = dims[2];
+     constexpr int C = _C_;
+     dim3 block(C);
+     dim3 grid(H, B);
+
+     // [Key] Template-instantiated launch; arguments map directly onto the kernel parameters
+     backward_kernel<_C_><<<grid, block, 0, stream>>>(
+         T, H,
+         reinterpret_cast<bf *>(w.typed_data()),
+         reinterpret_cast<bf *>(q.typed_data()),
+         reinterpret_cast<bf *>(k.typed_data()),
+         reinterpret_cast<bf *>(v.typed_data()),
+         reinterpret_cast<bf *>(a.typed_data()),   // maps to a_
+         reinterpret_cast<bf *>(b.typed_data()),   // maps to b_
+         reinterpret_cast<bf *>(dy.typed_data()),
+         s.typed_data(),
+         sa.typed_data(),
+         dht.typed_data(),
+         dh0->typed_data(),
+         reinterpret_cast<bf *>(dw->typed_data()),
+         reinterpret_cast<bf *>(dq->typed_data()),
+         reinterpret_cast<bf *>(dk->typed_data()),
+         reinterpret_cast<bf *>(dv->typed_data()),
+         reinterpret_cast<bf *>(da->typed_data()),
+         reinterpret_cast<bf *>(db->typed_data()));
+
+     cudaError_t err = cudaGetLastError();
+     if (err != cudaSuccess)
+         return ffi::Error::Internal(
+             std::string("CUDA backward_kernel error: ") + cudaGetErrorString(err));
+     return ffi::Error::Success();
+ }
+
+ static ffi::Error WKV7InferenceHost(
+     cudaStream_t stream,
+     ffi::Buffer<ffi::BF16> w,
+     ffi::Buffer<ffi::BF16> q,
+     ffi::Buffer<ffi::BF16> k,
+     ffi::Buffer<ffi::BF16> v,
+     ffi::Buffer<ffi::BF16> a,   // maps directly to the kernel's a_
+     ffi::Buffer<ffi::BF16> b,   // maps directly to the kernel's b_
+     ffi::Buffer<ffi::F32> h0,
+     ffi::ResultBuffer<ffi::BF16> y,
+     ffi::ResultBuffer<ffi::F32> s)
+ {
+     constexpr int C = _C_;
+     auto dims = w.dimensions();
+     int B = dims[0], T = dims[1], H = dims[2];
+     dim3 block(C);
+     dim3 grid(H, B);
+
+     // [Key] Template-instantiated launch; arguments map directly onto the kernel parameters
+     forward_inference_kernel<_C_><<<grid, block, 0, stream>>>(
+         T, H,
+         reinterpret_cast<bf *>(w.typed_data()),
+         reinterpret_cast<bf *>(q.typed_data()),
+         reinterpret_cast<bf *>(k.typed_data()),
+         reinterpret_cast<bf *>(v.typed_data()),
+         reinterpret_cast<bf *>(a.typed_data()),   // maps to a_
+         reinterpret_cast<bf *>(b.typed_data()),   // maps to b_
+         reinterpret_cast<bf *>(y->typed_data()),
+         s->typed_data(),
+         h0.typed_data());
+
+     cudaError_t err = cudaGetLastError();
+     if (err != cudaSuccess)
+         return ffi::Error::Internal(
+             std::string("CUDA forward_inference_kernel error: ") + cudaGetErrorString(err));
+     return ffi::Error::Success();
+ }
+
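+ // Exported handler symbols: Wkv7Fwd (training forward: returns y plus the chunked state
+ // checkpoints s and the per-step a·state products sa), Wkv7Bwd (backward pass), and
+ // Wkv7Inference (forward only: returns y and the final state s).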
+ /* -------------------- FFI registration (parameter names aligned) -------------------- */
+ XLA_FFI_DEFINE_HANDLER_SYMBOL(
+     Wkv7Fwd, WKV7FwdHost,
+     ffi::Ffi::Bind()
+         .Ctx<ffi::PlatformStream<cudaStream_t>>()
+         .Arg<ffi::Buffer<ffi::BF16>>()   // w
+         .Arg<ffi::Buffer<ffi::BF16>>()   // q
+         .Arg<ffi::Buffer<ffi::BF16>>()   // k
+         .Arg<ffi::Buffer<ffi::BF16>>()   // v
+         .Arg<ffi::Buffer<ffi::BF16>>()   // a (formerly z)
+         .Arg<ffi::Buffer<ffi::BF16>>()   // b (formerly a)
+         .Arg<ffi::Buffer<ffi::F32>>()    // h0
+         .Ret<ffi::Buffer<ffi::BF16>>()   // y
+         .Ret<ffi::Buffer<ffi::F32>>()    // s
+         .Ret<ffi::Buffer<ffi::F32>>()    // sa
+     , {ffi::Traits::kCmdBufferCompatible});
+
+ XLA_FFI_DEFINE_HANDLER_SYMBOL(
+     Wkv7Bwd, WKV7BwdHost,
+     ffi::Ffi::Bind()
+         .Ctx<ffi::PlatformStream<cudaStream_t>>()
+         .Arg<ffi::Buffer<ffi::BF16>>()   // w
+         .Arg<ffi::Buffer<ffi::BF16>>()   // q
+         .Arg<ffi::Buffer<ffi::BF16>>()   // k
+         .Arg<ffi::Buffer<ffi::BF16>>()   // v
+         .Arg<ffi::Buffer<ffi::BF16>>()   // a (formerly z)
+         .Arg<ffi::Buffer<ffi::BF16>>()   // b (formerly a)
+         .Arg<ffi::Buffer<ffi::BF16>>()   // dy
+         .Arg<ffi::Buffer<ffi::F32>>()    // s
+         .Arg<ffi::Buffer<ffi::F32>>()    // sa
+         .Arg<ffi::Buffer<ffi::F32>>()    // dht
+         .Ret<ffi::Buffer<ffi::F32>>()    // dh0
+         .Ret<ffi::Buffer<ffi::BF16>>()   // dw
+         .Ret<ffi::Buffer<ffi::BF16>>()   // dq
+         .Ret<ffi::Buffer<ffi::BF16>>()   // dk
+         .Ret<ffi::Buffer<ffi::BF16>>()   // dv
+         .Ret<ffi::Buffer<ffi::BF16>>()   // da
+         .Ret<ffi::Buffer<ffi::BF16>>()   // db
+     , {ffi::Traits::kCmdBufferCompatible});
+
+ XLA_FFI_DEFINE_HANDLER_SYMBOL(
+     Wkv7Inference, WKV7InferenceHost,
+     ffi::Ffi::Bind()
+         .Ctx<ffi::PlatformStream<cudaStream_t>>()
+         .Arg<ffi::Buffer<ffi::BF16>>()   // w
+         .Arg<ffi::Buffer<ffi::BF16>>()   // q
+         .Arg<ffi::Buffer<ffi::BF16>>()   // k
+         .Arg<ffi::Buffer<ffi::BF16>>()   // v
+         .Arg<ffi::Buffer<ffi::BF16>>()   // a
+         .Arg<ffi::Buffer<ffi::BF16>>()   // b
+         .Arg<ffi::Buffer<ffi::F32>>()    // h0
+         .Ret<ffi::Buffer<ffi::BF16>>()   // y
+         .Ret<ffi::Buffer<ffi::F32>>()    // s (final state)
+     , {ffi::Traits::kCmdBufferCompatible});
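
For orientation, a minimal sketch of how a handler such as Wkv7Fwd can be exposed to JAX once the library has been loaded (as sketched after the CMakeLists above); the target name, the (B, T, H, C) layout, and the output shapes below are illustrative assumptions, not the package's actual wkv7_jax.py wrapper:

    import jax
    import jax.numpy as jnp

    # register the XLA FFI handler symbol under an illustrative target name
    jax.ffi.register_ffi_target("wkv7_fwd", jax.ffi.pycapsule(_lib.Wkv7Fwd), platform="CUDA")

    def wkv7_fwd(w, q, k, v, a, b, h0, chunk_len=16):
        B, T, H, C = w.shape  # assumed layout, matching the kernel's (bb, t, hh, i) indexing
        out_types = (
            jax.ShapeDtypeStruct((B, T, H, C), jnp.bfloat16),                 # y
            jax.ShapeDtypeStruct((B, H, T // chunk_len, C, C), jnp.float32),  # s (chunk checkpoints)
            jax.ShapeDtypeStruct((B, T, H, C), jnp.float32),                  # sa
        )
        return jax.ffi.ffi_call("wkv7_fwd", out_types)(w, q, k, v, a, b, h0)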