rwkv-ops 0.2.2__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (31)
  1. rwkv_ops/__init__.py +5 -6
  2. rwkv_ops/rwkv6_kernel/__init__.py +0 -6
  3. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/gpu_ops.cpp +44 -0
  4. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernel_helpers.h +64 -0
  5. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernels.h +56 -0
  6. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/pybind11_kernel_helpers.h +41 -0
  7. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/rwkv_kernels.cu +512 -0
  8. rwkv_ops/rwkv6_kernel/jax_kernel_hip/gpu_ops.cpp +44 -0
  9. rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernel_helpers.h +64 -0
  10. rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernels.h +56 -0
  11. rwkv_ops/rwkv6_kernel/jax_kernel_hip/pybind11_kernel_helpers.h +41 -0
  12. rwkv_ops/rwkv6_kernel/jax_kernel_hip/rwkv_kernels.hip +514 -0
  13. rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py +21 -23
  14. rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py +14 -10
  15. rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_cuda.cu +397 -0
  16. rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_op.cpp +93 -0
  17. rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py +4 -4
  18. rwkv_ops/rwkv7_kernel/__init__.py +80 -29
  19. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/CMakeLists.txt +42 -0
  20. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_ffi.cu +279 -0
  21. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_jax.py +237 -0
  22. rwkv_ops/rwkv7_kernel/jax_op.py +6 -5
  23. rwkv_ops/rwkv7_kernel/native_keras_op.py +5 -6
  24. rwkv_ops/rwkv7_kernel/tf_eager_kernel.py +123 -0
  25. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_cuda.cu +165 -0
  26. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_op.cpp +35 -0
  27. {rwkv_ops-0.2.2.dist-info → rwkv_ops-0.3.1.dist-info}/METADATA +28 -27
  28. {rwkv_ops-0.2.2.dist-info → rwkv_ops-0.3.1.dist-info}/RECORD +30 -13
  29. {rwkv_ops-0.2.2.dist-info → rwkv_ops-0.3.1.dist-info}/WHEEL +1 -2
  30. rwkv_ops-0.2.2.dist-info/top_level.txt +0 -1
  31. {rwkv_ops-0.2.2.dist-info → rwkv_ops-0.3.1.dist-info/licenses}/LICENSE.txt +0 -0

rwkv_ops/rwkv7_kernel/__init__.py
@@ -12,18 +12,20 @@ def transpose_head(x, head_first):
 
 
 def get_generalized_delta_rule(HEAD_SIZE=64, KERNEL_TYPE="native"):
-    USE_KERNEL = False
+    USE_TRITON_KERNEL = False
     if keras.config.backend() == "torch":
         import torch
+        if not torch.cuda.is_available():
+            from .native_keras_op import generalized_delta_rule
+            return generalized_delta_rule, False
 
         if KERNEL_TYPE.lower() == "triton":
             from .torch_op import generalized_delta_rule
 
-            USE_KERNEL = True
+            USE_TRITON_KERNEL = True
 
         elif KERNEL_TYPE.lower() == "cuda":
             CHUNK_LEN = 16
-            USE_KERNEL = True
             from torch.utils.cpp_extension import load
             import os
 
@@ -44,8 +46,8 @@ def get_generalized_delta_rule(HEAD_SIZE=64, KERNEL_TYPE="native"):
             load(
                 name="wind_backstepping",
                 sources=[
-                    os.path.join(current_dir_path, "cuda_kernel/wkv7_cuda.cu"),
-                    os.path.join(current_dir_path, "cuda_kernel/wkv7_op.cpp"),
+                    os.path.join(current_dir_path, "torch_cuda_kernel/wkv7_cuda.cu"),
+                    os.path.join(current_dir_path, "torch_cuda_kernel/wkv7_op.cpp"),
                 ],
                 is_python_module=False,
                 verbose=True,
@@ -54,8 +56,8 @@ def get_generalized_delta_rule(HEAD_SIZE=64, KERNEL_TYPE="native"):
 
             class WindBackstepping(torch.autograd.Function):
                 @staticmethod
-                def forward(ctx, w, q, k, v, z, b):
-                    B, T, H, C = w.shape
+                def forward(ctx, w, q, k, v, z, b, h0):
+                    B, T, H, N = w.shape
                     DTYPE = q.dtype
                     q = ops.cast(q, "bfloat16")
                     k = ops.cast(k, "bfloat16")
@@ -63,30 +65,42 @@ def get_generalized_delta_rule(HEAD_SIZE=64, KERNEL_TYPE="native"):
                     z = ops.cast(z, "bfloat16")
                     b = ops.cast(b, "bfloat16")
                     w = ops.cast(w, "bfloat16")
-                    assert T % CHUNK_LEN == 0
+                    if T % CHUNK_LEN != 0:
+                        raise ValueError(
+                            "The RWKV input sequence length must be divisible by 16. "
+                            "Please make sure the sequence length is divisible by 16."
+                        )
                     assert all(i.is_contiguous() for i in [w, q, k, v, z, b])
                     y = torch.empty_like(v)
                     s = torch.empty(
-                        B, H, T // CHUNK_LEN, C, C, dtype=torch.float32, device=w.device
+                        B, H, T // CHUNK_LEN, N, N, dtype=torch.float32, device=w.device
                     )
-                    sa = torch.empty(B, T, H, C, dtype=torch.float32, device=w.device)
-                    torch.ops.wind_backstepping.forward(w, q, k, v, z, b, y, s, sa)
+                    sa = torch.empty(B, T, H, N, dtype=torch.float32, device=w.device)
+                    torch.ops.wind_backstepping.forward(w, q, k, v, z, b, y, s, sa, h0)
                     ctx.save_for_backward(w, q, k, v, z, b, s, sa)
-                    return ops.cast(y, DTYPE)
+                    last_state = torch.empty_like(h0)
+                    last_state.copy_(ops.transpose(s[:, :, -1], [0, 1, 3, 2]))
+
+                    return ops.cast(y, DTYPE), last_state
 
                 @staticmethod
-                def backward(ctx, dy):
+                def backward(ctx, dy, dht):
                     DTYPE = dy.dtype
                     dy = ops.cast(dy, torch.bfloat16)
                     dy = dy.contiguous()
-                    assert all(i.dtype == torch.bfloat16 for i in [dy])
-                    assert all(i.is_contiguous() for i in [dy])
+
                     w, q, k, v, z, b, s, sa = ctx.saved_tensors
+                    dht = ops.cast(dht, "float32")
+                    dht = dht.contiguous()
+                    assert all(i.dtype == torch.bfloat16 for i in [dy])
+                    assert all(i.is_contiguous() for i in [dy, dht])
+                    dh0 = torch.empty(dht.shape, dtype=dht.dtype, device=dht.device)
                     dw, dq, dk, dv, dz, db = [
                         torch.empty_like(x) for x in [w, q, k, v, z, b]
                     ]
+
                     torch.ops.wind_backstepping.backward(
-                        w, q, k, v, z, b, dy, s, sa, dw, dq, dk, dv, dz, db
+                        w, q, k, v, z, b, dy, s, sa, dht, dh0, dw, dq, dk, dv, dz, db
                     )
                     return (
                         ops.cast(dw, DTYPE),
@@ -95,9 +109,10 @@ def get_generalized_delta_rule(HEAD_SIZE=64, KERNEL_TYPE="native"):
                         ops.cast(dv, DTYPE),
                         ops.cast(dz, DTYPE),
                         ops.cast(db, DTYPE),
+                        dh0,
                     )
 
-            def RUN_CUDA_RWKV7g(q, w, k, v, a, b):
+            def RUN_CUDA_RWKV7g(q, w, k, v, a, b, h0):
                 B, T, H, C = q.shape
                 q = q.contiguous()
                 w = w.contiguous()
@@ -105,7 +120,8 @@ def get_generalized_delta_rule(HEAD_SIZE=64, KERNEL_TYPE="native"):
                 v = v.contiguous()
                 a = a.contiguous()
                 b = b.contiguous()
-                return WindBackstepping.apply(w, q, k, v, a, b).view(B, T, H * C)
+                out, state = WindBackstepping.apply(w, q, k, v, a, b, h0)
+                return out, state
 
             def generalized_delta_rule(
                 r: torch.Tensor,
@@ -125,26 +141,61 @@ def get_generalized_delta_rule(HEAD_SIZE=64, KERNEL_TYPE="native"):
                 a = transpose_head(a, head_first)
                 b = transpose_head(b, head_first)
                 w = transpose_head(w, head_first)
-                return RUN_CUDA_RWKV7g(r, w, k, v, a, b), None
+                B, T, H, N = w.shape
+                if initial_state is None:
+                    initial_state = ops.zeros((B, H, N, N), "float32")
+                else:
+                    initial_state = ops.cast(initial_state, "float32")
+                out, state = RUN_CUDA_RWKV7g(r, w, k, v, a, b, initial_state)
+                if output_final_state:
+                    return out, state
+                return out
         else:
             from .native_keras_op import generalized_delta_rule
 
-            USE_KERNEL = False
+            USE_TRITON_KERNEL = False
     elif keras.config.backend() == "jax":
         from jax.lib import xla_bridge
         import os
 
-        if (
-            xla_bridge.get_backend().platform == "gpu"
-            and KERNEL_TYPE.lower() == "triton"
-        ):
-            os.environ["JAX_LOG_COMPUTATION"] = "0"
-            from .jax_op import generalized_delta_rule
+        if xla_bridge.get_backend().platform == "gpu":
+            if KERNEL_TYPE.lower() == "triton":
+                os.environ["JAX_LOG_COMPUTATION"] = "0"
+                from .jax_op import generalized_delta_rule
 
-            USE_KERNEL = True
+                USE_TRITON_KERNEL = True
+            elif KERNEL_TYPE.lower() == "cuda":
+                from .jax_cuda_kernel.wkv7_jax import get_jax_generalized_delta_rule
+
+                generalized_delta_rule = get_jax_generalized_delta_rule(HEAD_SIZE)[0]
+            else:
+                from .native_keras_op import generalized_delta_rule
+        else:
+            from .native_keras_op import generalized_delta_rule
+    elif keras.config.backend() == "tensorflow":
+        import tensorflow as tf
+
+        if len(tf.config.list_physical_devices("GPU")) > 0:
+            if KERNEL_TYPE.lower() == "cuda" and HEAD_SIZE == 64:
+                try:
+                    from jax.lib import xla_bridge
+
+                    assert xla_bridge.get_backend().platform == "gpu"
+                except:
+                    raise RuntimeError(
+                        "The TensorFlow CUDA kernel depends on the JAX FFI kernel. "
+                        "Make sure the kernel is usable from JAX before using it from TensorFlow."
+                    )
+                print("🎉" * 10)
+                print("The TensorFlow CUDA kernel only supports the forward pass; gradients are not available.")
+                print("🎉" * 10)
+                from .tf_eager_kernel import get_tf_generalized_delta_rule
+
+                generalized_delta_rule = get_tf_generalized_delta_rule(HEAD_SIZE)[0]
+            else:
+                from .native_keras_op import generalized_delta_rule
         else:
             from .native_keras_op import generalized_delta_rule
-
     else:
         from .native_keras_op import generalized_delta_rule
-    return generalized_delta_rule, USE_KERNEL
+    return generalized_delta_rule, USE_TRITON_KERNEL
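
For orientation, here is a minimal usage sketch of the reworked API. It is not taken from the package; the keyword names and shapes (r, w, k, v, a, b of shape (B, T, H, N), an optional (B, H, N, N) initial_state, head_first, output_final_state) are inferred from the hunks above, so treat it as an illustration rather than the package's documented interface.

    import keras
    from keras import ops
    from rwkv_ops.rwkv7_kernel import get_generalized_delta_rule

    # The second return value now reports specifically whether the Triton kernel
    # is active (USE_TRITON_KERNEL), not a generic USE_KERNEL flag.
    generalized_delta_rule, use_triton = get_generalized_delta_rule(
        HEAD_SIZE=64, KERNEL_TYPE="cuda"
    )

    B, T, H, N = 2, 32, 8, 64  # T must be divisible by CHUNK_LEN = 16
    r, w, k, v, a, b = [ops.ones((B, T, H, N), "bfloat16") for _ in range(6)]

    # With output_final_state=True the CUDA path returns (output, final_state);
    # initial_state defaults to a zero (B, H, N, N) float32 tensor and is passed
    # to the kernel as h0.
    out, state = generalized_delta_rule(
        r=r, w=w, k=k, v=v, a=a, b=b,
        head_first=False,
        initial_state=None,
        output_final_state=True,
    )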

rwkv_ops/rwkv7_kernel/jax_cuda_kernel/CMakeLists.txt
@@ -0,0 +1,42 @@
+cmake_minimum_required(VERSION 3.18)
+project(wkv7 LANGUAGES CXX CUDA)
+
+find_package(CUDAToolkit REQUIRED)
+
+# ---------- 1. Find Python ----------
+find_package(Python3 REQUIRED COMPONENTS Interpreter)
+
+# ---------- 2. Get the XLA include path ----------
+execute_process(
+    COMMAND "${Python3_EXECUTABLE}" -c "from jax import ffi; print(ffi.include_dir())"
+    OUTPUT_VARIABLE XLA_INCLUDE_DIR
+    OUTPUT_STRIP_TRAILING_WHITESPACE
+)
+if(NOT XLA_INCLUDE_DIR)
+    message(FATAL_ERROR "Cannot get XLA include dir from jax.ffi")
+endif()
+message(STATUS "XLA include directory: ${XLA_INCLUDE_DIR}")
+
+# ---------- 3. Build the shared library ----------
+add_library(wkv7 SHARED wkv7_ffi.cu)
+
+# 3-1. Header search paths
+target_include_directories(wkv7 PRIVATE ${XLA_INCLUDE_DIR})
+
+# 3-2. Link the CUDA runtime
+target_link_libraries(wkv7 PRIVATE CUDA::cudart)
+
+# 3-3. Important: C++17 / CUDA 17 standard
+target_compile_features(wkv7 PUBLIC cxx_std_17)
+set_target_properties(wkv7 PROPERTIES
+    CUDA_STANDARD 17
+    CUDA_SEPARABLE_COMPILATION ON
+    POSITION_INDEPENDENT_CODE ON
+    PREFIX ""  # drop the default "lib" prefix
+)
+
+# ---------- 4. Install ----------
+# Install the .so directly into the source directory (next to wkv7_jax.py) so it can be loaded with ctypes.CDLL
+install(TARGETS wkv7
+    LIBRARY DESTINATION "${CMAKE_SOURCE_DIR}"
+    RUNTIME DESTINATION "${CMAKE_SOURCE_DIR}")  # RUNTIME is used on Windows
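
The install step above drops wkv7.so next to wkv7_jax.py so it can be opened with ctypes.CDLL. The wkv7_jax.py loader itself is not part of this excerpt; as a hedged sketch under that assumption, registering the handlers exported by wkv7_ffi.cu (shown next) could look roughly like this, with the target names "wkv7_fwd"/"wkv7_bwd" chosen here only for illustration:

    import ctypes
    import os

    import jax

    _HERE = os.path.dirname(os.path.abspath(__file__))
    _LIB = ctypes.CDLL(os.path.join(_HERE, "wkv7.so"))  # installed by CMake next to wkv7_jax.py

    # Wrap the exported XLA FFI handler symbols in PyCapsules and register them
    # as CUDA custom-call targets.
    jax.ffi.register_ffi_target("wkv7_fwd", jax.ffi.pycapsule(_LIB.Wkv7Fwd), platform="CUDA")
    jax.ffi.register_ffi_target("wkv7_bwd", jax.ffi.pycapsule(_LIB.Wkv7Bwd), platform="CUDA")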

rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_ffi.cu
@@ -0,0 +1,279 @@
+/*
+ * wkv7_ffi_bf16.cu
+ * BF16 version: the external interface is BF16, the internal kernels are unchanged
+ */
+#include <cuda_bf16.h>
+#include <cuda_runtime.h>
+#include <xla/ffi/api/ffi.h>
+#include <vector>
+
+namespace ffi = xla::ffi;
+
+/* -------------------- Type aliases -------------------- */
+using bf = __nv_bfloat16;
+
+/* -------------------- Device-side helpers (used inside the kernels) -------------------- */
+__device__ inline float to_float(const bf &u) {
+    return __bfloat162float(u);
+}
+__device__ inline bf to_bf(const float &u) {
+    return __float2bfloat16_rn(u);
+}
+
+typedef bf *__restrict__ F_;
+
+/* -------------------- The original kernel (do not modify) -------------------- */
+__global__ void forward_kernel(int T, int H,
+                               F_ w_, F_ q_, F_ k_, F_ v_, F_ a_, F_ b_,
+                               bf *y_, float *s_, float *sa_, float *h0_) {
+    constexpr int C = _C_;
+    int bb = blockIdx.y, hh = blockIdx.x, i = threadIdx.x;
+    float state[C] = {0};
+    __shared__ float q[C], k[C], w[C], a[C], b[C];
+    int h0_base = ((bb * H + hh) * C + i) * C;
+#pragma unroll
+    for (int j = 0; j < C; ++j) state[j] = h0_[h0_base + j];
+
+    for (int t = 0; t < T; ++t) {
+        int ind = bb * T * H * C + t * H * C + hh * C + i;
+        __syncthreads();
+        q[i] = to_float(q_[ind]);
+        w[i] = __expf(-__expf(to_float(w_[ind])));
+        k[i] = to_float(k_[ind]);
+        a[i] = to_float(a_[ind]);
+        b[i] = to_float(b_[ind]);
+        __syncthreads();
+
+        float sa = 0.f;
+#pragma unroll
+        for (int j = 0; j < C; ++j) sa += a[j] * state[j];
+        sa_[ind] = sa;
+
+        float v = to_float(v_[ind]);
+        float y = 0.f;
+#pragma unroll
+        for (int j = 0; j < C; ++j) {
+            float &s = state[j];
+            s = s * w[j] + sa * b[j] + k[j] * v;
+            y += s * q[j];
+        }
+        y_[ind] = to_bf(y);
+
+        if ((t + 1) % _CHUNK_LEN_ == 0) {
+            int base = (bb * H + hh) * (T / _CHUNK_LEN_) * C * C +
+                       (t / _CHUNK_LEN_) * C * C + i;
+#pragma unroll
+            for (int j = 0; j < C; ++j) s_[base + j * C] = state[j];
+        }
+    }
+}
+
+__global__ void backward_kernel(int T, int H,
+                                F_ w_, F_ q_, F_ k_, F_ v_, F_ a_, F_ b_, F_ dy_,
+                                float *s_, float *sa_, float *dht_, float *dh0_,
+                                bf *dw_, bf *dq_, bf *dk_, bf *dv_, bf *da_, bf *db_) {
+    constexpr int C = _C_;
+    int bb = blockIdx.y, hh = blockIdx.x, i = threadIdx.x;
+    float stateT[C] = {0}, dstate[C] = {0}, dstateT[C] = {0};
+    int dht_base = ((bb * H + hh) * C + i) * C;
+#pragma unroll
+    for (int j = 0; j < C; ++j) {
+        dstate[j] = dht_[dht_base + j];
+        dstateT[j] = dht_[dht_base + j];
+    }
+    __shared__ float w[C], q[C], k[C], v[C], a[C], b[C], dy[C], sa[C], dSb_shared[C];
+    float qi, wi, ki, ai, bi, dyi;
+
+    for (int t = T - 1; t >= 0; --t) {
+        int ind = bb * T * H * C + t * H * C + hh * C + i;
+        __syncthreads();
+        q[i] = qi = to_float(q_[ind]);
+        float wi_fac = -__expf(to_float(w_[ind]));
+        w[i] = wi = __expf(wi_fac);
+        k[i] = ki = to_float(k_[ind]);
+        a[i] = ai = to_float(a_[ind]);
+        b[i] = bi = to_float(b_[ind]);
+        v[i] = to_float(v_[ind]);
+        dy[i] = dyi = to_float(dy_[ind]);
+        sa[i] = sa_[ind];
+        __syncthreads();
+
+        if ((t + 1) % _CHUNK_LEN_ == 0) {
+            int base = (bb * H + hh) * (T / _CHUNK_LEN_) * C * C +
+                       (t / _CHUNK_LEN_) * C * C + i * C;
+#pragma unroll
+            for (int j = 0; j < C; ++j) stateT[j] = s_[base + j];
+        }
+        float dq = 0.f;
+#pragma unroll
+        for (int j = 0; j < C; ++j) dq += stateT[j] * dy[j];
+        dq_[ind] = to_bf(dq);
+
+        float iwi = 1.f / (wi + 1e-6f);
+#pragma unroll
+        for (int j = 0; j < C; ++j) {
+            stateT[j] = (stateT[j] - ki * v[j] - bi * sa[j]) * iwi;
+            dstate[j] += dyi * q[j];
+            dstateT[j] += qi * dy[j];
+        }
+        float dw = 0.f, dk = 0.f, dv = 0.f, db = 0.f, dSb = 0.f;
+#pragma unroll
+        for (int j = 0; j < C; ++j) {
+            dw += dstateT[j] * stateT[j];
+            dk += dstateT[j] * v[j];
+            dv += dstate[j] * k[j];
+            dSb += dstate[j] * b[j];
+            db += dstateT[j] * sa[j];
+        }
+        dw_[ind] = to_bf(dw * wi * wi_fac);
+        dk_[ind] = to_bf(dk);
+        dv_[ind] = to_bf(dv);
+        db_[ind] = to_bf(db);
+        __syncthreads();
+        dSb_shared[i] = dSb;
+        __syncthreads();
+        float da = 0.f;
+#pragma unroll
+        for (int j = 0; j < C; ++j) da += stateT[j] * dSb_shared[j];
+        da_[ind] = to_bf(da);
+#pragma unroll
+        for (int j = 0; j < C; ++j) {
+            dstate[j] = dstate[j] * w[j] + dSb * a[j];
+            dstateT[j] = dstateT[j] * wi + ai * dSb_shared[j];
+            if (t == 0) dh0_[dht_base + j] = dstate[j];
+        }
+    }
+}
+
+/* -------------------- Host functions -------------------- */
+static ffi::Error WKV7FwdHost(
+    cudaStream_t stream,
+    ffi::Buffer<ffi::BF16> w,
+    ffi::Buffer<ffi::BF16> q,
+    ffi::Buffer<ffi::BF16> k,
+    ffi::Buffer<ffi::BF16> v,
+    ffi::Buffer<ffi::BF16> z,
+    ffi::Buffer<ffi::BF16> a,
+    ffi::Buffer<ffi::F32> h0,  // kept as float
+    ffi::ResultBuffer<ffi::BF16> y,
+    ffi::ResultBuffer<ffi::F32> s,
+    ffi::ResultBuffer<ffi::F32> sa)
+{
+    constexpr int C = _C_;
+    auto dims = w.dimensions();
+    int B = dims[0], T = dims[1], H = dims[2];
+    dim3 block(C);
+    dim3 grid(H, B);
+
+    forward_kernel<<<grid, block, 0, stream>>>(
+        T, H,
+        reinterpret_cast<bf *>(w.typed_data()),
+        reinterpret_cast<bf *>(q.typed_data()),
+        reinterpret_cast<bf *>(k.typed_data()),
+        reinterpret_cast<bf *>(v.typed_data()),
+        reinterpret_cast<bf *>(z.typed_data()),
+        reinterpret_cast<bf *>(a.typed_data()),
+        reinterpret_cast<bf *>(y->typed_data()),
+        s->typed_data(),
+        sa->typed_data(),
+        h0.typed_data());
+
+    cudaError_t err = cudaGetLastError();
+    if (err != cudaSuccess)
+        return ffi::Error::Internal(
+            std::string("CUDA forward_kernel error: ") + cudaGetErrorString(err));
+    return ffi::Error::Success();
+}
+
+static ffi::Error WKV7BwdHost(
+    cudaStream_t stream,
+    ffi::Buffer<ffi::BF16> w,
+    ffi::Buffer<ffi::BF16> q,
+    ffi::Buffer<ffi::BF16> k,
+    ffi::Buffer<ffi::BF16> v,
+    ffi::Buffer<ffi::BF16> z,
+    ffi::Buffer<ffi::BF16> a,
+    ffi::Buffer<ffi::BF16> dy,
+    ffi::Buffer<ffi::F32> s,
+    ffi::Buffer<ffi::F32> sa,
+    ffi::Buffer<ffi::F32> dht,
+    ffi::ResultBuffer<ffi::F32> dh0,
+    ffi::ResultBuffer<ffi::BF16> dw,
+    ffi::ResultBuffer<ffi::BF16> dq,
+    ffi::ResultBuffer<ffi::BF16> dk,
+    ffi::ResultBuffer<ffi::BF16> dv,
+    ffi::ResultBuffer<ffi::BF16> da,
+    ffi::ResultBuffer<ffi::BF16> db)
+{
+    auto dims = w.dimensions();
+    int B = dims[0], T = dims[1], H = dims[2];
+    constexpr int C = _C_;
+    dim3 block(C);
+    dim3 grid(H, B);
+
+    backward_kernel<<<grid, block, 0, stream>>>(
+        T, H,
+        reinterpret_cast<bf *>(w.typed_data()),
+        reinterpret_cast<bf *>(q.typed_data()),
+        reinterpret_cast<bf *>(k.typed_data()),
+        reinterpret_cast<bf *>(v.typed_data()),
+        reinterpret_cast<bf *>(z.typed_data()),
+        reinterpret_cast<bf *>(a.typed_data()),
+        reinterpret_cast<bf *>(dy.typed_data()),
+        s.typed_data(),
+        sa.typed_data(),
+        dht.typed_data(),
+        dh0->typed_data(),
+        reinterpret_cast<bf *>(dw->typed_data()),
+        reinterpret_cast<bf *>(dq->typed_data()),
+        reinterpret_cast<bf *>(dk->typed_data()),
+        reinterpret_cast<bf *>(dv->typed_data()),
+        reinterpret_cast<bf *>(da->typed_data()),
+        reinterpret_cast<bf *>(db->typed_data()));
+
+    cudaError_t err = cudaGetLastError();
+    if (err != cudaSuccess)
+        return ffi::Error::Internal(
+            std::string("CUDA backward_kernel error: ") + cudaGetErrorString(err));
+    return ffi::Error::Success();
+}
+
+/* -------------------- Registered handler symbols -------------------- */
+XLA_FFI_DEFINE_HANDLER_SYMBOL(
+    Wkv7Fwd, WKV7FwdHost,
+    ffi::Ffi::Bind()
+        .Ctx<ffi::PlatformStream<cudaStream_t>>()
+        .Arg<ffi::Buffer<ffi::BF16>>()  // w
+        .Arg<ffi::Buffer<ffi::BF16>>()  // q
+        .Arg<ffi::Buffer<ffi::BF16>>()  // k
+        .Arg<ffi::Buffer<ffi::BF16>>()  // v
+        .Arg<ffi::Buffer<ffi::BF16>>()  // z
+        .Arg<ffi::Buffer<ffi::BF16>>()  // a
+        .Arg<ffi::Buffer<ffi::F32>>()   // h0 (float)
+        .Ret<ffi::Buffer<ffi::BF16>>()  // y
+        .Ret<ffi::Buffer<ffi::F32>>()   // s
+        .Ret<ffi::Buffer<ffi::F32>>()   // sa
+    , {ffi::Traits::kCmdBufferCompatible});
+
+XLA_FFI_DEFINE_HANDLER_SYMBOL(
+    Wkv7Bwd, WKV7BwdHost,
+    ffi::Ffi::Bind()
+        .Ctx<ffi::PlatformStream<cudaStream_t>>()
+        .Arg<ffi::Buffer<ffi::BF16>>()  // w
+        .Arg<ffi::Buffer<ffi::BF16>>()  // q
+        .Arg<ffi::Buffer<ffi::BF16>>()  // k
+        .Arg<ffi::Buffer<ffi::BF16>>()  // v
+        .Arg<ffi::Buffer<ffi::BF16>>()  // z
+        .Arg<ffi::Buffer<ffi::BF16>>()  // a
+        .Arg<ffi::Buffer<ffi::BF16>>()  // dy
+        .Arg<ffi::Buffer<ffi::F32>>()   // s
+        .Arg<ffi::Buffer<ffi::F32>>()   // sa
+        .Arg<ffi::Buffer<ffi::F32>>()   // dht
+        .Ret<ffi::Buffer<ffi::F32>>()   // dh0
+        .Ret<ffi::Buffer<ffi::BF16>>()  // dw
+        .Ret<ffi::Buffer<ffi::BF16>>()  // dq
+        .Ret<ffi::Buffer<ffi::BF16>>()  // dk
+        .Ret<ffi::Buffer<ffi::BF16>>()  // dv
+        .Ret<ffi::Buffer<ffi::BF16>>()  // da
+        .Ret<ffi::Buffer<ffi::BF16>>()  // db
+    , {ffi::Traits::kCmdBufferCompatible});
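
Given the binding above (seven operands and three results for the forward handler), a hedged sketch of invoking it through jax.ffi follows. The target name matches the illustrative registration shown after CMakeLists.txt, and CHUNK_LEN = 16 is an assumption carried over from the Torch path; the packaged wkv7_jax.py presumably also wires Wkv7Bwd in via jax.custom_vjp, which is omitted here.

    import jax
    import jax.numpy as jnp

    CHUNK_LEN = 16  # must match the _CHUNK_LEN_ value the kernel was compiled with

    def wkv7_forward(w, q, k, v, z, a, h0):
        """Calls the Wkv7Fwd handler: bf16 inputs of shape (B, T, H, C), float32 h0 of shape (B, H, C, C)."""
        B, T, H, C = w.shape
        out_types = (
            jax.ShapeDtypeStruct((B, T, H, C), jnp.bfloat16),                 # y
            jax.ShapeDtypeStruct((B, H, T // CHUNK_LEN, C, C), jnp.float32),  # s (per-chunk states)
            jax.ShapeDtypeStruct((B, T, H, C), jnp.float32),                  # sa
        )
        y, s, sa = jax.ffi.ffi_call("wkv7_fwd", out_types)(w, q, k, v, z, a, h0)
        return y, s, sa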