rwkv-ops 0.2.2__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (31)
  1. rwkv_ops/__init__.py +5 -6
  2. rwkv_ops/rwkv6_kernel/__init__.py +0 -6
  3. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/gpu_ops.cpp +44 -0
  4. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernel_helpers.h +64 -0
  5. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernels.h +56 -0
  6. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/pybind11_kernel_helpers.h +41 -0
  7. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/rwkv_kernels.cu +512 -0
  8. rwkv_ops/rwkv6_kernel/jax_kernel_hip/gpu_ops.cpp +44 -0
  9. rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernel_helpers.h +64 -0
  10. rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernels.h +56 -0
  11. rwkv_ops/rwkv6_kernel/jax_kernel_hip/pybind11_kernel_helpers.h +41 -0
  12. rwkv_ops/rwkv6_kernel/jax_kernel_hip/rwkv_kernels.hip +514 -0
  13. rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py +21 -23
  14. rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py +14 -10
  15. rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_cuda.cu +397 -0
  16. rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_op.cpp +93 -0
  17. rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py +4 -4
  18. rwkv_ops/rwkv7_kernel/__init__.py +77 -29
  19. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/CMakeLists.txt +42 -0
  20. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_ffi.cu +279 -0
  21. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_jax.py +237 -0
  22. rwkv_ops/rwkv7_kernel/jax_op.py +6 -5
  23. rwkv_ops/rwkv7_kernel/native_keras_op.py +5 -6
  24. rwkv_ops/rwkv7_kernel/tf_eager_kernel.py +123 -0
  25. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_cuda.cu +165 -0
  26. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_op.cpp +35 -0
  27. {rwkv_ops-0.2.2.dist-info → rwkv_ops-0.3.0.dist-info}/METADATA +28 -27
  28. {rwkv_ops-0.2.2.dist-info → rwkv_ops-0.3.0.dist-info}/RECORD +30 -13
  29. {rwkv_ops-0.2.2.dist-info → rwkv_ops-0.3.0.dist-info}/WHEEL +1 -2
  30. rwkv_ops-0.2.2.dist-info/top_level.txt +0 -1
  31. {rwkv_ops-0.2.2.dist-info → rwkv_ops-0.3.0.dist-info/licenses}/LICENSE.txt +0 -0
rwkv_ops/rwkv7_kernel/__init__.py

@@ -12,18 +12,17 @@ def transpose_head(x, head_first):
 
 
 def get_generalized_delta_rule(HEAD_SIZE=64, KERNEL_TYPE="native"):
-    USE_KERNEL = False
+    USE_TRITON_KERNEL = False
     if keras.config.backend() == "torch":
         import torch
 
         if KERNEL_TYPE.lower() == "triton":
             from .torch_op import generalized_delta_rule
 
-            USE_KERNEL = True
+            USE_TRITON_KERNEL = True
 
         elif KERNEL_TYPE.lower() == "cuda":
             CHUNK_LEN = 16
-            USE_KERNEL = True
             from torch.utils.cpp_extension import load
             import os
 
@@ -44,8 +43,8 @@ def get_generalized_delta_rule(HEAD_SIZE=64, KERNEL_TYPE="native"):
             load(
                 name="wind_backstepping",
                 sources=[
-                    os.path.join(current_dir_path, "cuda_kernel/wkv7_cuda.cu"),
-                    os.path.join(current_dir_path, "cuda_kernel/wkv7_op.cpp"),
+                    os.path.join(current_dir_path, "torch_cuda_kernel/wkv7_cuda.cu"),
+                    os.path.join(current_dir_path, "torch_cuda_kernel/wkv7_op.cpp"),
                 ],
                 is_python_module=False,
                 verbose=True,
@@ -54,8 +53,8 @@ def get_generalized_delta_rule(HEAD_SIZE=64, KERNEL_TYPE="native"):
 
             class WindBackstepping(torch.autograd.Function):
                 @staticmethod
-                def forward(ctx, w, q, k, v, z, b):
-                    B, T, H, C = w.shape
+                def forward(ctx, w, q, k, v, z, b, h0):
+                    B, T, H, N = w.shape
                     DTYPE = q.dtype
                     q = ops.cast(q, "bfloat16")
                     k = ops.cast(k, "bfloat16")
@@ -63,30 +62,42 @@ def get_generalized_delta_rule(HEAD_SIZE=64, KERNEL_TYPE="native"):
                     z = ops.cast(z, "bfloat16")
                     b = ops.cast(b, "bfloat16")
                     w = ops.cast(w, "bfloat16")
-                    assert T % CHUNK_LEN == 0
+                    if T % CHUNK_LEN != 0:
+                        raise ValueError(
+                            "RWKV输入的序列长度必须可以被16整除"
+                            "Please make sure the sequence length is divisible by 16"
+                        )
                     assert all(i.is_contiguous() for i in [w, q, k, v, z, b])
                     y = torch.empty_like(v)
                     s = torch.empty(
-                        B, H, T // CHUNK_LEN, C, C, dtype=torch.float32, device=w.device
+                        B, H, T // CHUNK_LEN, N, N, dtype=torch.float32, device=w.device
                     )
-                    sa = torch.empty(B, T, H, C, dtype=torch.float32, device=w.device)
-                    torch.ops.wind_backstepping.forward(w, q, k, v, z, b, y, s, sa)
+                    sa = torch.empty(B, T, H, N, dtype=torch.float32, device=w.device)
+                    torch.ops.wind_backstepping.forward(w, q, k, v, z, b, y, s, sa, h0)
                     ctx.save_for_backward(w, q, k, v, z, b, s, sa)
-                    return ops.cast(y, DTYPE)
+                    last_state = torch.empty_like(h0)
+                    last_state.copy_(ops.transpose(s[:, :, -1], [0, 1, 3, 2]))
+
+                    return ops.cast(y, DTYPE), last_state
 
                 @staticmethod
-                def backward(ctx, dy):
+                def backward(ctx, dy, dht):
                     DTYPE = dy.dtype
                     dy = ops.cast(dy, torch.bfloat16)
                     dy = dy.contiguous()
-                    assert all(i.dtype == torch.bfloat16 for i in [dy])
-                    assert all(i.is_contiguous() for i in [dy])
+
                     w, q, k, v, z, b, s, sa = ctx.saved_tensors
+                    dht = ops.cast(dht, "float32")
+                    dht = dht.contiguous()
+                    assert all(i.dtype == torch.bfloat16 for i in [dy])
+                    assert all(i.is_contiguous() for i in [dy, dht])
+                    dh0 = torch.empty(dht.shape, dtype=dht.dtype, device=dht.device)
                     dw, dq, dk, dv, dz, db = [
                         torch.empty_like(x) for x in [w, q, k, v, z, b]
                     ]
+
                     torch.ops.wind_backstepping.backward(
-                        w, q, k, v, z, b, dy, s, sa, dw, dq, dk, dv, dz, db
+                        w, q, k, v, z, b, dy, s, sa, dht, dh0, dw, dq, dk, dv, dz, db
                     )
                     return (
                         ops.cast(dw, DTYPE),
@@ -95,9 +106,10 @@ def get_generalized_delta_rule(HEAD_SIZE=64, KERNEL_TYPE="native"):
                         ops.cast(dv, DTYPE),
                         ops.cast(dz, DTYPE),
                         ops.cast(db, DTYPE),
+                        dh0,
                     )
 
-            def RUN_CUDA_RWKV7g(q, w, k, v, a, b):
+            def RUN_CUDA_RWKV7g(q, w, k, v, a, b, h0):
                 B, T, H, C = q.shape
                 q = q.contiguous()
                 w = w.contiguous()
@@ -105,7 +117,8 @@ def get_generalized_delta_rule(HEAD_SIZE=64, KERNEL_TYPE="native"):
                 v = v.contiguous()
                 a = a.contiguous()
                 b = b.contiguous()
-                return WindBackstepping.apply(w, q, k, v, a, b).view(B, T, H * C)
+                out, state = WindBackstepping.apply(w, q, k, v, a, b, h0)
+                return out, state
 
             def generalized_delta_rule(
                 r: torch.Tensor,
@@ -125,26 +138,61 @@ def get_generalized_delta_rule(HEAD_SIZE=64, KERNEL_TYPE="native"):
                 a = transpose_head(a, head_first)
                 b = transpose_head(b, head_first)
                 w = transpose_head(w, head_first)
-                return RUN_CUDA_RWKV7g(r, w, k, v, a, b), None
+                B, T, H, N = w.shape
+                if initial_state is None:
+                    initial_state = ops.zeros((B, H, N, N), "float32")
+                else:
+                    initial_state = ops.cast(initial_state, "float32")
+                out, state = RUN_CUDA_RWKV7g(r, w, k, v, a, b, initial_state)
+                if output_final_state:
+                    return out, state
+                return out
         else:
             from .native_keras_op import generalized_delta_rule
 
-            USE_KERNEL = False
+            USE_TRITON_KERNEL = False
     elif keras.config.backend() == "jax":
         from jax.lib import xla_bridge
         import os
 
-        if (
-            xla_bridge.get_backend().platform == "gpu"
-            and KERNEL_TYPE.lower() == "triton"
-        ):
-            os.environ["JAX_LOG_COMPUTATION"] = "0"
-            from .jax_op import generalized_delta_rule
+        if xla_bridge.get_backend().platform == "gpu":
+            if KERNEL_TYPE.lower() == "triton":
+                os.environ["JAX_LOG_COMPUTATION"] = "0"
+                from .jax_op import generalized_delta_rule
 
-            USE_KERNEL = True
+                USE_TRITON_KERNEL = True
+            elif KERNEL_TYPE.lower() == "cuda":
+                from .jax_cuda_kernel.wkv7_jax import get_jax_generalized_delta_rule
+
+                generalized_delta_rule = get_jax_generalized_delta_rule(HEAD_SIZE)[0]
+            else:
+                from .native_keras_op import generalized_delta_rule
+        else:
+            from .native_keras_op import generalized_delta_rule
+    elif keras.config.backend() == "tensorflow":
+        import tensorflow as tf
+
+        if len(tf.config.list_physical_devices("GPU")) > 0:
+            if KERNEL_TYPE.lower() == "cuda" and HEAD_SIZE == 64:
+                try:
+                    from jax.lib import xla_bridge
+
+                    assert xla_bridge.get_backend().platform == "gpu"
+                except:
+                    raise (
+                        "The operation of the TensorFlow kernel depends on the JAX kernel."
+                        "Therefore, it is necessary to ensure that it can be used in JAX, so that TensorFlow can be used."
+                    )
+                print("🎉" * 10)
+                print("Tensorflow CUDA kernel onlt support Forward,not get graident")
+                print("🎉" * 10)
+                from .tf_eager_kernel import get_tf_generalized_delta_rule
+
+                generalized_delta_rule = get_tf_generalized_delta_rule(HEAD_SIZE)[0]
+            else:
+                from .native_keras_op import generalized_delta_rule
         else:
             from .native_keras_op import generalized_delta_rule
-
     else:
         from .native_keras_op import generalized_delta_rule
-    return generalized_delta_rule, USE_KERNEL
+    return generalized_delta_rule, USE_TRITON_KERNEL
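
For orientation, the following is a minimal usage sketch of the refactored factory on the torch/CUDA path. It is not taken from the package: the import path, the keyword names (r, w, k, v, a, b, initial_state, output_final_state, head_first) and the argument order are assumptions read off the local variables visible in the diff, and a CUDA-capable GPU is required.

# Hypothetical usage sketch of get_generalized_delta_rule (torch backend, KERNEL_TYPE="cuda").
# Keyword names and the import path are assumptions; the actual signature may differ.
import keras
from keras import ops

from rwkv_ops.rwkv7_kernel import get_generalized_delta_rule  # import path assumed

B, T, H, C = 2, 32, 4, 64  # T must be divisible by CHUNK_LEN = 16 on the CUDA path
rule, use_triton = get_generalized_delta_rule(HEAD_SIZE=C, KERNEL_TYPE="cuda")

r, w, k, v, a, b = [keras.random.normal((B, T, H, C), dtype="bfloat16") for _ in range(6)]

# New in 0.3.0: an initial state can be passed in and the final state returned.
out, final_state = rule(
    r=r, w=w, k=k, v=v, a=a, b=b,
    initial_state=ops.zeros((B, H, C, C), "float32"),
    output_final_state=True,
    head_first=False,
)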
rwkv_ops/rwkv7_kernel/jax_cuda_kernel/CMakeLists.txt (new file)

@@ -0,0 +1,42 @@
+cmake_minimum_required(VERSION 3.18)
+project(wkv7 LANGUAGES CXX CUDA)
+
+find_package(CUDAToolkit REQUIRED)
+
+# ---------- 1. Find Python ----------
+find_package(Python3 REQUIRED COMPONENTS Interpreter)
+
+# ---------- 2. Get the XLA header path ----------
+execute_process(
+    COMMAND "${Python3_EXECUTABLE}" -c "from jax import ffi; print(ffi.include_dir())"
+    OUTPUT_VARIABLE XLA_INCLUDE_DIR
+    OUTPUT_STRIP_TRAILING_WHITESPACE
+)
+if(NOT XLA_INCLUDE_DIR)
+    message(FATAL_ERROR "Cannot get XLA include dir from jax.ffi")
+endif()
+message(STATUS "XLA include directory: ${XLA_INCLUDE_DIR}")
+
+# ---------- 3. Build the shared library ----------
+add_library(wkv7 SHARED wkv7_ffi.cu)
+
+# 3-1. Header search path
+target_include_directories(wkv7 PRIVATE ${XLA_INCLUDE_DIR})
+
+# 3-2. Link the CUDA runtime
+target_link_libraries(wkv7 PRIVATE CUDA::cudart)
+
+# 3-3. Key settings: C++17 / CUDA 17 standard
+target_compile_features(wkv7 PUBLIC cxx_std_17)
+set_target_properties(wkv7 PROPERTIES
+    CUDA_STANDARD 17
+    CUDA_SEPARABLE_COMPILATION ON
+    POSITION_INDEPENDENT_CODE ON
+    PREFIX ""  # drop the default "lib" prefix
+)
+
+# ---------- 4. Install ----------
+# Install the .so directly into the source directory (next to wkv7_jax.py) so ctypes.CDLL can load it
+install(TARGETS wkv7
+    LIBRARY DESTINATION "${CMAKE_SOURCE_DIR}"
+    RUNTIME DESTINATION "${CMAKE_SOURCE_DIR}")  # RUNTIME is used on Windows
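
The install step above drops wkv7.so next to wkv7_jax.py so it can be opened with ctypes. A minimal sketch of how the library could then be loaded and its handlers registered with JAX's FFI follows, assuming a recent JAX where jax.ffi exposes register_ffi_target and pycapsule. Only the exported symbols Wkv7Fwd/Wkv7Bwd come from this diff; the target names "wkv7_fwd"/"wkv7_bwd" are placeholders, and the package's actual wiring lives in wkv7_jax.py (not shown here).

# Sketch: load the built wkv7.so and register its XLA FFI handlers with JAX.
# The registration names are placeholders, not taken from the package.
import ctypes
import os

import jax

lib_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "wkv7.so")  # PREFIX "" => no "lib" prefix
lib = ctypes.CDLL(lib_path)

jax.ffi.register_ffi_target("wkv7_fwd", jax.ffi.pycapsule(lib.Wkv7Fwd), platform="CUDA")
jax.ffi.register_ffi_target("wkv7_bwd", jax.ffi.pycapsule(lib.Wkv7Bwd), platform="CUDA")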
rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_ffi.cu (new file)

@@ -0,0 +1,279 @@
+/*
+ * wkv7_ffi_bf16.cu
+ * BF16 version: the external interface is BF16, the internal kernels are kept as-is
+ */
+#include <cuda_bf16.h>
+#include <cuda_runtime.h>
+#include <xla/ffi/api/ffi.h>
+#include <vector>
+
+namespace ffi = xla::ffi;
+
+/* -------------------- type aliases -------------------- */
+using bf = __nv_bfloat16;
+
+/* -------------------- device-side helpers (used inside the kernels) -------------------- */
+__device__ inline float to_float(const bf &u) {
+    return __bfloat162float(u);
+}
+__device__ inline bf to_bf(const float &u) {
+    return __float2bfloat16_rn(u);
+}
+
+typedef bf *__restrict__ F_;
+
+/* -------------------- the original kernels (do not modify) -------------------- */
+__global__ void forward_kernel(int T, int H,
+                               F_ w_, F_ q_, F_ k_, F_ v_, F_ a_, F_ b_,
+                               bf *y_, float *s_, float *sa_, float *h0_) {
+    constexpr int C = _C_;
+    int bb = blockIdx.y, hh = blockIdx.x, i = threadIdx.x;
+    float state[C] = {0};
+    __shared__ float q[C], k[C], w[C], a[C], b[C];
+    int h0_base = ((bb * H + hh) * C + i) * C;
+    #pragma unroll
+    for (int j = 0; j < C; ++j) state[j] = h0_[h0_base + j];
+
+    for (int t = 0; t < T; ++t) {
+        int ind = bb * T * H * C + t * H * C + hh * C + i;
+        __syncthreads();
+        q[i] = to_float(q_[ind]);
+        w[i] = __expf(-__expf(to_float(w_[ind])));
+        k[i] = to_float(k_[ind]);
+        a[i] = to_float(a_[ind]);
+        b[i] = to_float(b_[ind]);
+        __syncthreads();
+
+        float sa = 0.f;
+        #pragma unroll
+        for (int j = 0; j < C; ++j) sa += a[j] * state[j];
+        sa_[ind] = sa;
+
+        float v = to_float(v_[ind]);
+        float y = 0.f;
+        #pragma unroll
+        for (int j = 0; j < C; ++j) {
+            float &s = state[j];
+            s = s * w[j] + sa * b[j] + k[j] * v;
+            y += s * q[j];
+        }
+        y_[ind] = to_bf(y);
+
+        if ((t + 1) % _CHUNK_LEN_ == 0) {
+            int base = (bb * H + hh) * (T / _CHUNK_LEN_) * C * C +
+                       (t / _CHUNK_LEN_) * C * C + i;
+            #pragma unroll
+            for (int j = 0; j < C; ++j) s_[base + j * C] = state[j];
+        }
+    }
+}
+
+__global__ void backward_kernel(int T, int H,
+                                F_ w_, F_ q_, F_ k_, F_ v_, F_ a_, F_ b_, F_ dy_,
+                                float *s_, float *sa_, float *dht_, float *dh0_,
+                                bf *dw_, bf *dq_, bf *dk_, bf *dv_, bf *da_, bf *db_) {
+    constexpr int C = _C_;
+    int bb = blockIdx.y, hh = blockIdx.x, i = threadIdx.x;
+    float stateT[C] = {0}, dstate[C] = {0}, dstateT[C] = {0};
+    int dht_base = ((bb * H + hh) * C + i) * C;
+    #pragma unroll
+    for (int j = 0; j < C; ++j) {
+        dstate[j] = dht_[dht_base + j];
+        dstateT[j] = dht_[dht_base + j];
+    }
+    __shared__ float w[C], q[C], k[C], v[C], a[C], b[C], dy[C], sa[C], dSb_shared[C];
+    float qi, wi, ki, ai, bi, dyi;
+
+    for (int t = T - 1; t >= 0; --t) {
+        int ind = bb * T * H * C + t * H * C + hh * C + i;
+        __syncthreads();
+        q[i] = qi = to_float(q_[ind]);
+        float wi_fac = -__expf(to_float(w_[ind]));
+        w[i] = wi = __expf(wi_fac);
+        k[i] = ki = to_float(k_[ind]);
+        a[i] = ai = to_float(a_[ind]);
+        b[i] = bi = to_float(b_[ind]);
+        v[i] = to_float(v_[ind]);
+        dy[i] = dyi = to_float(dy_[ind]);
+        sa[i] = sa_[ind];
+        __syncthreads();
+
+        if ((t + 1) % _CHUNK_LEN_ == 0) {
+            int base = (bb * H + hh) * (T / _CHUNK_LEN_) * C * C +
+                       (t / _CHUNK_LEN_) * C * C + i * C;
+            #pragma unroll
+            for (int j = 0; j < C; ++j) stateT[j] = s_[base + j];
+        }
+        float dq = 0.f;
+        #pragma unroll
+        for (int j = 0; j < C; ++j) dq += stateT[j] * dy[j];
+        dq_[ind] = to_bf(dq);
+
+        float iwi = 1.f / (wi + 1e-6f);
+        #pragma unroll
+        for (int j = 0; j < C; ++j) {
+            stateT[j] = (stateT[j] - ki * v[j] - bi * sa[j]) * iwi;
+            dstate[j] += dyi * q[j];
+            dstateT[j] += qi * dy[j];
+        }
+        float dw = 0.f, dk = 0.f, dv = 0.f, db = 0.f, dSb = 0.f;
+        #pragma unroll
+        for (int j = 0; j < C; ++j) {
+            dw += dstateT[j] * stateT[j];
+            dk += dstateT[j] * v[j];
+            dv += dstate[j] * k[j];
+            dSb += dstate[j] * b[j];
+            db += dstateT[j] * sa[j];
+        }
+        dw_[ind] = to_bf(dw * wi * wi_fac);
+        dk_[ind] = to_bf(dk);
+        dv_[ind] = to_bf(dv);
+        db_[ind] = to_bf(db);
+        __syncthreads();
+        dSb_shared[i] = dSb;
+        __syncthreads();
+        float da = 0.f;
+        #pragma unroll
+        for (int j = 0; j < C; ++j) da += stateT[j] * dSb_shared[j];
+        da_[ind] = to_bf(da);
+        #pragma unroll
+        for (int j = 0; j < C; ++j) {
+            dstate[j] = dstate[j] * w[j] + dSb * a[j];
+            dstateT[j] = dstateT[j] * wi + ai * dSb_shared[j];
+            if (t == 0) dh0_[dht_base + j] = dstate[j];
+        }
+    }
+}
+
+/* -------------------- host functions -------------------- */
+static ffi::Error WKV7FwdHost(
+    cudaStream_t stream,
+    ffi::Buffer<ffi::BF16> w,
+    ffi::Buffer<ffi::BF16> q,
+    ffi::Buffer<ffi::BF16> k,
+    ffi::Buffer<ffi::BF16> v,
+    ffi::Buffer<ffi::BF16> z,
+    ffi::Buffer<ffi::BF16> a,
+    ffi::Buffer<ffi::F32> h0,  // kept as float
+    ffi::ResultBuffer<ffi::BF16> y,
+    ffi::ResultBuffer<ffi::F32> s,
+    ffi::ResultBuffer<ffi::F32> sa)
+{
+    constexpr int C = _C_;
+    auto dims = w.dimensions();
+    int B = dims[0], T = dims[1], H = dims[2];
+    dim3 block(C);
+    dim3 grid(H, B);
+
+    forward_kernel<<<grid, block, 0, stream>>>(
+        T, H,
+        reinterpret_cast<bf *>(w.typed_data()),
+        reinterpret_cast<bf *>(q.typed_data()),
+        reinterpret_cast<bf *>(k.typed_data()),
+        reinterpret_cast<bf *>(v.typed_data()),
+        reinterpret_cast<bf *>(z.typed_data()),
+        reinterpret_cast<bf *>(a.typed_data()),
+        reinterpret_cast<bf *>(y->typed_data()),
+        s->typed_data(),
+        sa->typed_data(),
+        h0.typed_data());
+
+    cudaError_t err = cudaGetLastError();
+    if (err != cudaSuccess)
+        return ffi::Error::Internal(
+            std::string("CUDA forward_kernel error: ") + cudaGetErrorString(err));
+    return ffi::Error::Success();
+}
+
+static ffi::Error WKV7BwdHost(
+    cudaStream_t stream,
+    ffi::Buffer<ffi::BF16> w,
+    ffi::Buffer<ffi::BF16> q,
+    ffi::Buffer<ffi::BF16> k,
+    ffi::Buffer<ffi::BF16> v,
+    ffi::Buffer<ffi::BF16> z,
+    ffi::Buffer<ffi::BF16> a,
+    ffi::Buffer<ffi::BF16> dy,
+    ffi::Buffer<ffi::F32> s,
+    ffi::Buffer<ffi::F32> sa,
+    ffi::Buffer<ffi::F32> dht,
+    ffi::ResultBuffer<ffi::F32> dh0,
+    ffi::ResultBuffer<ffi::BF16> dw,
+    ffi::ResultBuffer<ffi::BF16> dq,
+    ffi::ResultBuffer<ffi::BF16> dk,
+    ffi::ResultBuffer<ffi::BF16> dv,
+    ffi::ResultBuffer<ffi::BF16> da,
+    ffi::ResultBuffer<ffi::BF16> db)
+{
+    auto dims = w.dimensions();
+    int B = dims[0], T = dims[1], H = dims[2];
+    constexpr int C = _C_;
+    dim3 block(C);
+    dim3 grid(H, B);
+
+    backward_kernel<<<grid, block, 0, stream>>>(
+        T, H,
+        reinterpret_cast<bf *>(w.typed_data()),
+        reinterpret_cast<bf *>(q.typed_data()),
+        reinterpret_cast<bf *>(k.typed_data()),
+        reinterpret_cast<bf *>(v.typed_data()),
+        reinterpret_cast<bf *>(z.typed_data()),
+        reinterpret_cast<bf *>(a.typed_data()),
+        reinterpret_cast<bf *>(dy.typed_data()),
+        s.typed_data(),
+        sa.typed_data(),
+        dht.typed_data(),
+        dh0->typed_data(),
+        reinterpret_cast<bf *>(dw->typed_data()),
+        reinterpret_cast<bf *>(dq->typed_data()),
+        reinterpret_cast<bf *>(dk->typed_data()),
+        reinterpret_cast<bf *>(dv->typed_data()),
+        reinterpret_cast<bf *>(da->typed_data()),
+        reinterpret_cast<bf *>(db->typed_data()));
+
+    cudaError_t err = cudaGetLastError();
+    if (err != cudaSuccess)
+        return ffi::Error::Internal(
+            std::string("CUDA backward_kernel error: ") + cudaGetErrorString(err));
+    return ffi::Error::Success();
+}
+
+/* -------------------- register handler symbols -------------------- */
+XLA_FFI_DEFINE_HANDLER_SYMBOL(
+    Wkv7Fwd, WKV7FwdHost,
+    ffi::Ffi::Bind()
+        .Ctx<ffi::PlatformStream<cudaStream_t>>()
+        .Arg<ffi::Buffer<ffi::BF16>>()  // w
+        .Arg<ffi::Buffer<ffi::BF16>>()  // q
+        .Arg<ffi::Buffer<ffi::BF16>>()  // k
+        .Arg<ffi::Buffer<ffi::BF16>>()  // v
+        .Arg<ffi::Buffer<ffi::BF16>>()  // z
+        .Arg<ffi::Buffer<ffi::BF16>>()  // a
+        .Arg<ffi::Buffer<ffi::F32>>()   // h0 (float)
+        .Ret<ffi::Buffer<ffi::BF16>>()  // y
+        .Ret<ffi::Buffer<ffi::F32>>()   // s
+        .Ret<ffi::Buffer<ffi::F32>>()   // sa
+    , {ffi::Traits::kCmdBufferCompatible});
+
+XLA_FFI_DEFINE_HANDLER_SYMBOL(
+    Wkv7Bwd, WKV7BwdHost,
+    ffi::Ffi::Bind()
+        .Ctx<ffi::PlatformStream<cudaStream_t>>()
+        .Arg<ffi::Buffer<ffi::BF16>>()  // w
+        .Arg<ffi::Buffer<ffi::BF16>>()  // q
+        .Arg<ffi::Buffer<ffi::BF16>>()  // k
+        .Arg<ffi::Buffer<ffi::BF16>>()  // v
+        .Arg<ffi::Buffer<ffi::BF16>>()  // z
+        .Arg<ffi::Buffer<ffi::BF16>>()  // a
+        .Arg<ffi::Buffer<ffi::BF16>>()  // dy
+        .Arg<ffi::Buffer<ffi::F32>>()   // s
+        .Arg<ffi::Buffer<ffi::F32>>()   // sa
+        .Arg<ffi::Buffer<ffi::F32>>()   // dht
+        .Ret<ffi::Buffer<ffi::F32>>()   // dh0
+        .Ret<ffi::Buffer<ffi::BF16>>()  // dw
+        .Ret<ffi::Buffer<ffi::BF16>>()  // dq
+        .Ret<ffi::Buffer<ffi::BF16>>()  // dk
+        .Ret<ffi::Buffer<ffi::BF16>>()  // dv
+        .Ret<ffi::Buffer<ffi::BF16>>()  // da
+        .Ret<ffi::Buffer<ffi::BF16>>()  // db
+    , {ffi::Traits::kCmdBufferCompatible});
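
For reference, the forward handler registered above could be invoked from JAX roughly as in the sketch below. This is not the package's wkv7_jax.py: the target name is the placeholder from the earlier registration sketch, the chunk length of 16 is assumed to match the _CHUNK_LEN_ compile-time macro (it equals CHUNK_LEN = 16 on the torch path), and the result shapes mirror the buffers WKV7FwdHost writes (y in bf16, the per-chunk states s and the sa scratch in float32).

# Sketch of calling the forward FFI handler; shapes/dtypes mirror WKV7FwdHost above.
# "wkv7_fwd" is a placeholder target name and chunk_len=16 is an assumption.
import jax
import jax.numpy as jnp

def wkv7_forward(w, q, k, v, z, a, h0, chunk_len=16):
    B, T, H, C = w.shape  # inputs are (B, T, H, C) bfloat16; h0 is (B, H, C, C) float32
    y, s, sa = jax.ffi.ffi_call(
        "wkv7_fwd",
        (
            jax.ShapeDtypeStruct((B, T, H, C), jnp.bfloat16),                 # y
            jax.ShapeDtypeStruct((B, H, T // chunk_len, C, C), jnp.float32),  # s
            jax.ShapeDtypeStruct((B, T, H, C), jnp.float32),                  # sa
        ),
    )(w, q, k, v, z, a, h0)
    return y, s, sa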