rwkv_ops-0.6.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. rwkv_ops/__init__.py +45 -0
  2. rwkv_ops/mhc_kernel/__init__.py +50 -0
  3. rwkv_ops/mhc_kernel/common_kernel/include/mhc_types.h +66 -0
  4. rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_post_op.cuh +197 -0
  5. rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_pre_op.cuh +212 -0
  6. rwkv_ops/mhc_kernel/common_kernel/kernels/rmsnorm.cuh +152 -0
  7. rwkv_ops/mhc_kernel/common_kernel/kernels/sinkhorn_knopp.cuh +158 -0
  8. rwkv_ops/mhc_kernel/common_kernel/kernels/stream_aggregate.cuh +141 -0
  9. rwkv_ops/mhc_kernel/common_kernel/kernels/stream_distribute.cuh +111 -0
  10. rwkv_ops/mhc_kernel/common_kernel/kernels/stream_mix.cuh +164 -0
  11. rwkv_ops/mhc_kernel/common_kernel/kernels/type_conversions.cuh +52 -0
  12. rwkv_ops/mhc_kernel/jax_kernel/CMakeLists.txt +47 -0
  13. rwkv_ops/mhc_kernel/jax_kernel/mhu_ffi.cu +652 -0
  14. rwkv_ops/mhc_kernel/jax_kernel/mhu_jax.py +939 -0
  15. rwkv_ops/mhc_kernel/native_keras_op.py +193 -0
  16. rwkv_ops/mhc_kernel/torch_kernel/mhc_cuda.cu +207 -0
  17. rwkv_ops/mhc_kernel/torch_kernel/mhc_op.cpp +296 -0
  18. rwkv_ops/mhc_kernel/torch_kernel/mhc_torch.py +306 -0
  19. rwkv_ops/rwkv6_kernel/__init__.py +120 -0
  20. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/gpu_ops.cpp +44 -0
  21. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernel_helpers.h +64 -0
  22. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernels.h +56 -0
  23. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/pybind11_kernel_helpers.h +41 -0
  24. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/rwkv_kernels.cu +512 -0
  25. rwkv_ops/rwkv6_kernel/jax_kernel_hip/gpu_ops.cpp +44 -0
  26. rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernel_helpers.h +64 -0
  27. rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernels.h +56 -0
  28. rwkv_ops/rwkv6_kernel/jax_kernel_hip/pybind11_kernel_helpers.h +41 -0
  29. rwkv_ops/rwkv6_kernel/jax_kernel_hip/rwkv_kernels.hip +514 -0
  30. rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py +722 -0
  31. rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py +90 -0
  32. rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_cuda.cu +397 -0
  33. rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_op.cpp +93 -0
  34. rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py +305 -0
  35. rwkv_ops/rwkv7_kernel/__init__.py +113 -0
  36. rwkv_ops/rwkv7_kernel/get_jax_devices_info.py +220 -0
  37. rwkv_ops/rwkv7_kernel/get_torch_devices_info.py +250 -0
  38. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/CMakeLists.txt +42 -0
  39. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_ffi.cu +399 -0
  40. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_jax.py +311 -0
  41. rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/CMakeLists.txt +42 -0
  42. rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_ffi.cu +172 -0
  43. rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_jax.py +190 -0
  44. rwkv_ops/rwkv7_kernel/jax_kernel/__init__.py +9 -0
  45. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_bwd.py +95 -0
  46. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_fwd.py +60 -0
  47. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_bwd.py +78 -0
  48. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_fwd.py +80 -0
  49. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_bwd.py +150 -0
  50. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_fwd.py +45 -0
  51. rwkv_ops/rwkv7_kernel/jax_kernel/cumsum.py +34 -0
  52. rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_bwd.py +61 -0
  53. rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_fwd.py +86 -0
  54. rwkv_ops/rwkv7_kernel/jax_op.py +382 -0
  55. rwkv_ops/rwkv7_kernel/mlx_op.py +118 -0
  56. rwkv_ops/rwkv7_kernel/native_keras_op.py +108 -0
  57. rwkv_ops/rwkv7_kernel/tf_eager_kernel.py +155 -0
  58. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_cuda.cu +235 -0
  59. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_op.cpp +63 -0
  60. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_torch.py +233 -0
  61. rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_cuda.cu +101 -0
  62. rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_op.cpp +56 -0
  63. rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_torch.py +112 -0
  64. rwkv_ops/rwkv7_kernel/torch_kernel/__init__.py +13 -0
  65. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_bwd.py +96 -0
  66. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_fwd.py +64 -0
  67. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_bwd.py +74 -0
  68. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_fwd.py +75 -0
  69. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_bwd.py +148 -0
  70. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_fwd.py +44 -0
  71. rwkv_ops/rwkv7_kernel/torch_kernel/cumsum.py +31 -0
  72. rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_bwd.py +63 -0
  73. rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_fwd.py +79 -0
  74. rwkv_ops/rwkv7_kernel/torch_op.py +504 -0
  75. rwkv_ops/rwkv7_kernel/triton_kernel/__init__.py +34 -0
  76. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_bwd.py +328 -0
  77. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_fwd.py +186 -0
  78. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_bwd.py +157 -0
  79. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_fwd.py +160 -0
  80. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_bwd.py +382 -0
  81. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_fwd.py +137 -0
  82. rwkv_ops/rwkv7_kernel/triton_kernel/cumsum.py +86 -0
  83. rwkv_ops/rwkv7_kernel/triton_kernel/utils.py +20 -0
  84. rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_bwd.py +193 -0
  85. rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_fwd.py +326 -0
  86. rwkv_ops-0.6.1.dist-info/METADATA +495 -0
  87. rwkv_ops-0.6.1.dist-info/RECORD +89 -0
  88. rwkv_ops-0.6.1.dist-info/WHEEL +4 -0
  89. rwkv_ops-0.6.1.dist-info/licenses/LICENSE.txt +201 -0
rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_torch.py
@@ -0,0 +1,233 @@
+ import os
+ import torch
+ from torch.utils.cpp_extension import load
+ from keras.src.backend.torch.core import cast
+ from keras.src.backend.torch.numpy import transpose, zeros
+
+
+ def transpose_head(x, head_first):
+     if head_first:
+         return transpose(x, (0, 2, 1, 3))
+     else:
+         return x
+
+
+ def get_torch_generalized_delta_rule(HEAD_SIZE=64):
+     CHUNK_LEN = 16
+     flags = [
+         "-res-usage",
+         f"-D_C_={HEAD_SIZE}",
+         f"-D_CHUNK_LEN_={CHUNK_LEN}",
+         "--use_fast_math",
+         "-O3",
+         "-Xptxas -O3",
+         "--extra-device-vectorization",
+     ]
+     # Absolute path of the current file
+     current_file_path = os.path.abspath(__file__)
+
+     # Directory containing the current file
+     current_dir_path = os.path.dirname(current_file_path)
+     load(
+         name="wind_backstepping",
+         sources=[
+             os.path.join(current_dir_path, "wkv7_cuda.cu"),
+             os.path.join(current_dir_path, "wkv7_op.cpp"),
+         ],
+         is_python_module=False,
+         verbose=True,
+         extra_cuda_cflags=flags,
+     )
+
+     class WindBackstepping(torch.autograd.Function):
+         @staticmethod
+         def forward(ctx, w, q, k, v, z, b, h0):
+             B, T, H, N = w.shape
+             DTYPE = q.dtype
+             q = cast(q, "bfloat16")
+             k = cast(k, "bfloat16")
+             v = cast(v, "bfloat16")
+             z = cast(z, "bfloat16")
+             b = cast(b, "bfloat16")
+             w = cast(w, "bfloat16")
+             if T % CHUNK_LEN != 0:
+                 raise ValueError(
+                     "The RWKV input sequence length must be divisible by 16. "
+                     "Please make sure the sequence length is a multiple of CHUNK_LEN (16)."
+                 )
+             assert all(i.is_contiguous() for i in [w, q, k, v, z, b])
+             y = torch.empty_like(v)
+             s = torch.empty(
+                 B, H, T // CHUNK_LEN, N, N, dtype=torch.float32, device=w.device
+             )
+             sa = torch.empty(B, T, H, N, dtype=torch.float32, device=w.device)
+             torch.ops.wind_backstepping.forward(w, q, k, v, z, b, y, s, sa, h0)
+             ctx.save_for_backward(w, q, k, v, z, b, s, sa)
+             last_state = torch.empty_like(h0)
+             last_state.copy_(transpose(s[:, :, -1], [0, 1, 3, 2]))
+
+             return cast(y, DTYPE), last_state
+
+         @staticmethod
+         def backward(ctx, dy, dht):
+             DTYPE = dy.dtype
+             dy = cast(dy, "bfloat16")
+             dy = dy.contiguous()
+
+             w, q, k, v, z, b, s, sa = ctx.saved_tensors
+             dht = cast(dht, "float32")
+             dht = dht.contiguous()
+             assert all(i.dtype == torch.bfloat16 for i in [dy])
+             assert all(i.is_contiguous() for i in [dy, dht])
+             dh0 = torch.empty(dht.shape, dtype=dht.dtype, device=dht.device)
+             dw, dq, dk, dv, dz, db = [torch.empty_like(x) for x in [w, q, k, v, z, b]]
+
+             torch.ops.wind_backstepping.backward(
+                 w, q, k, v, z, b, dy, s, sa, dht, dh0, dw, dq, dk, dv, dz, db
+             )
+             return (
+                 cast(dw, DTYPE),
+                 cast(dq, DTYPE),
+                 cast(dk, DTYPE),
+                 cast(dv, DTYPE),
+                 cast(dz, DTYPE),
+                 cast(db, DTYPE),
+                 dh0,
+             )
+
+     def RUN_CUDA_RWKV7g(q, w, k, v, a, b, h0):
+         B, T, H, C = q.shape
+         q = q.contiguous()
+         w = w.contiguous()
+         k = k.contiguous()
+         v = v.contiguous()
+         a = a.contiguous()
+         b = b.contiguous()
+         out, state = WindBackstepping.apply(w, q, k, v, a, b, h0)
+         return out, state
+
+     def generalized_delta_rule(
+         r: torch.Tensor,
+         w: torch.Tensor,
+         k: torch.Tensor,
+         v: torch.Tensor,
+         a: torch.Tensor,
+         b: torch.Tensor,
+         initial_state: torch.Tensor = None,
+         output_final_state: bool = True,
+         head_first: bool = False,
+         use_chunk: bool = True,
+     ):
+         if w.device.type != "cuda":
+             from ..native_keras_op import generalized_delta_rule
+
+             return generalized_delta_rule(
+                 r=r,
+                 k=k,
+                 v=v,
+                 a=a,
+                 b=b,
+                 w=w,
+                 initial_state=initial_state,
+                 output_final_state=output_final_state,
+             )
+         r = transpose_head(r, head_first)
+         k = transpose_head(k, head_first)
+         v = transpose_head(v, head_first)
+         a = transpose_head(a, head_first)
+         b = transpose_head(b, head_first)
+         w = transpose_head(w, head_first)
+         B, T, H, N = w.shape
+         if initial_state is None:
+             initial_state = zeros((B, H, N, N), "float32")
+         else:
+             initial_state = cast(initial_state, "float32")
+         out, state = RUN_CUDA_RWKV7g(r, w, k, v, a, b, initial_state)
+         if output_final_state:
+             return out, state
+         return out
+
+     class Wkv7Inference(torch.autograd.Function):
+         @staticmethod
+         def forward(ctx, w, q, k, v, a, b, h0):
+             B, T, H, N = w.shape
+             DTYPE = q.dtype
+
+             # Cast inputs to bfloat16
+             q = cast(q, "bfloat16")
+             k = cast(k, "bfloat16")
+             v = cast(v, "bfloat16")
+             a = cast(a, "bfloat16")
+             b = cast(b, "bfloat16")
+             w = cast(w, "bfloat16")
+
+             assert all(i.is_contiguous() for i in [w, q, k, v, a, b])
+
+             # Key point: s has shape (B, H, N, N) instead of (B, H, chunk_num, N, N)
+             y = torch.empty_like(v)
+             s = torch.empty(B, H, N, N, dtype=torch.float32, device=w.device)
+
+             # Call the inference op (no sa buffer needed)
+             torch.ops.wind_backstepping.forward_inference(w, q, k, v, a, b, y, s, h0)
+
+             return cast(y, DTYPE), s
+
+         @staticmethod
+         def backward(ctx, dy, dht):
+             raise NotImplementedError("Inference kernel does not support backward")
+
+     def RUN_CUDA_RWKV7g_inference(q, w, k, v, a, b, h0):
+         B, T, H, C = q.shape
+         q = q.contiguous()
+         w = w.contiguous()
+         k = k.contiguous()
+         v = v.contiguous()
+         a = a.contiguous()
+         b = b.contiguous()
+         out, state = Wkv7Inference.apply(w, q, k, v, a, b, h0)
+         return out, state
+
+     # -------------------- Public inference API --------------------
+     def generalized_delta_rule_inference(
+         r: torch.Tensor,
+         w: torch.Tensor,
+         k: torch.Tensor,
+         v: torch.Tensor,
+         a: torch.Tensor,
+         b: torch.Tensor,
+         initial_state: torch.Tensor = None,
+         head_first: bool = False,
+         output_final_state: bool = True,
+     ):
+         """
+         Inference-only version; cuts peak GPU memory use by 90%+.
+
+         Args:
+             r, w, k, v, a, b: input tensors of shape (B, T, H, K) or (B, H, T, K)
+             initial_state: (B, H, K, K) initial state; zero-initialized when None
+             head_first: whether the head dimension precedes the time dimension
+         Returns:
+             out: (B, T, H, K) output
+             final_state: (B, H, K, K) final state only
+         """
+         if w.device.type != "cuda":
+             raise NotImplementedError("Inference kernel only supports CUDA")
+
+         r = transpose_head(r, head_first)
+         k = transpose_head(k, head_first)
+         v = transpose_head(v, head_first)
+         a = transpose_head(a, head_first)
+         b = transpose_head(b, head_first)
+         w = transpose_head(w, head_first)
+
+         B, T, H, N = w.shape
+         if initial_state is None:
+             initial_state = zeros((B, H, N, N), "float32")
+         else:
+             initial_state = cast(initial_state, "float32")
+
+         out, final_state = RUN_CUDA_RWKV7g_inference(r, w, k, v, a, b, initial_state)
+         return (out, final_state) if output_final_state else out
+
+     # Return both functions; callers pick the one they need
+     return [generalized_delta_rule, generalized_delta_rule_inference]
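Illustrative usage (editorial note, not part of the package contents): `get_torch_generalized_delta_rule` builds the CUDA extension once and returns the training kernel and the inference-only kernel as a pair. A minimal sketch, assuming a CUDA device, Keras on the torch backend, and a direct import from the module above (the package's `rwkv_ops/__init__.py` may expose a friendlier entry point); the random tensors are placeholders for real projections:

    import torch
    from rwkv_ops.rwkv7_kernel.torch_cuda_kernel.wkv7_torch import (
        get_torch_generalized_delta_rule,
    )

    # Build the extension for HEAD_SIZE=64 and get both kernels.
    generalized_delta_rule, generalized_delta_rule_inference = (
        get_torch_generalized_delta_rule(HEAD_SIZE=64)
    )

    B, T, H, K = 1, 32, 8, 64  # T must be a multiple of CHUNK_LEN = 16
    r, w, k, v, a, b = (
        torch.randn(B, T, H, K, device="cuda", dtype=torch.bfloat16) for _ in range(6)
    )

    # Training path: output plus the final state (B, H, K, K).
    out, state = generalized_delta_rule(r, w, k, v, a, b, output_final_state=True)

    # Inference-only path: same call signature, far lower memory use, no backward.
    out_inf, state_inf = generalized_delta_rule_inference(r, w, k, v, a, b)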
rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_cuda.cu
@@ -0,0 +1,101 @@
+ #include <cuda_bf16.h>
+ #include <assert.h>
+ #include <cstdint>
+
+ using bf = __nv_bfloat16;
+
+ __device__ inline float to_float(const bf &u) {
+     return __bfloat162float(u);
+ }
+
+ __device__ inline bf to_bf(const float &u) {
+     return __float2bfloat16_rn(u);
+ }
+
+ typedef bf *__restrict__ F_;
+
+ // Single-step forward kernel for T=1
+ template<int C>
+ __launch_bounds__(C, 2)
+ __global__ void forward_single_step_kernel(
+     int H,              // Number of heads
+     F_ w_, F_ q_, F_ k_, F_ v_, F_ a_, F_ b_,
+     float *h0_,         // (B, H, C, C) - input state
+     bf *y_,             // (B, H, C) - output
+     float *h1_          // (B, H, C, C) - output state
+ ) {
+
+     int bb = blockIdx.y;  // Batch index
+     int hh = blockIdx.x;  // Head index
+     int i = threadIdx.x;  // Row index (0..C-1)
+
+     // Load parameters for this (bb, hh, i)
+     // Shape: (B, H, C)
+     int64_t param_idx = (int64_t)bb * H * C + hh * C + i;
+
+     float w_val = to_float(w_[param_idx]);
+     w_val = __expf(-__expf(w_val));  // Decay factor
+     float q_val = to_float(q_[param_idx]);
+     float k_val = to_float(k_[param_idx]);
+     float v_val = to_float(v_[param_idx]);  // Load per-thread v
+     float a_val = to_float(a_[param_idx]);
+     float b_val = to_float(b_[param_idx]);
+
+     // Load state row i from h0_: (B, H, C, C)
+     int64_t h0_base = (int64_t)bb * H * C * C + hh * C * C + i * C;
+     float state_row[C];
+     #pragma unroll
+     for (int j = 0; j < C; j++) {
+         state_row[j] = h0_[h0_base + j];
+     }
+
+     // Share vectors across threads in block (each thread loads one element)
+     __shared__ float shared_a[C], shared_b[C], shared_w[C], shared_k[C], shared_q[C];
+
+     shared_a[i] = a_val;
+     shared_b[i] = b_val;
+     shared_w[i] = w_val;
+     shared_k[i] = k_val;
+     shared_q[i] = q_val;
+     __syncthreads();
+
+     // Compute sa = sum_j(a[j] * state[i][j])
+     float sa = 0.0f;
+     #pragma unroll
+     for (int j = 0; j < C; j++) {
+         sa += shared_a[j] * state_row[j];
+     }
+
+     // Update state row i and compute output element i
+     float y = 0.0f;
+     #pragma unroll
+     for (int j = 0; j < C; j++) {
+         state_row[j] = state_row[j] * shared_w[j] + sa * shared_b[j] + shared_k[j] * v_val;
+         y += state_row[j] * shared_q[j];
+     }
+
+     // Write output y[i]: (B, H, C)
+     int64_t y_idx = (int64_t)bb * H * C + hh * C + i;
+     y_[y_idx] = to_bf(y);
+
+     // Write new state row i to h1_: (B, H, C, C)
+     int64_t h1_base = (int64_t)bb * H * C * C + hh * C * C + i * C;
+     #pragma unroll
+     for (int j = 0; j < C; j++) {
+         h1_[h1_base + j] = state_row[j];
+     }
+ }
+
+
+ void cuda_forward_single_step(
+     int B, int H,
+     bf *w, bf *q, bf *k, bf *v, bf *a, bf *b,
+     float *h0, bf *y, float *h1
+ ) {
+     dim3 blocks(H, B);  // (num_heads, batch_size)
+     dim3 threads(_C_);  // HEAD_SIZE
+
+     forward_single_step_kernel<_C_><<<blocks, threads>>>(
+         H, w, q, k, v, a, b, h0, y, h1
+     );
+ }
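For reference, each thread of `forward_single_step_kernel` owns one state row i and applies the update shown in the loops above. The same per-head recurrence can be written in a few lines of NumPy; this is an editorial sketch for readability, not code shipped in the package:

    import numpy as np

    def single_step_reference(S, w, q, k, v, a, b):
        # S: (C, C) state; rows follow the v/output index, columns the k index.
        # w, q, k, v, a, b: (C,) per-head vectors; w is the raw decay parameter.
        decay = np.exp(-np.exp(w))  # same transform as w_val in the kernel
        sa = S @ a                  # sa_i = sum_j a_j * S[i, j]
        S = S * decay[None, :] + np.outer(sa, b) + np.outer(v, k)
        y = S @ q                   # y_i = sum_j S[i, j] * q_j
        return y, S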
rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_op.cpp
@@ -0,0 +1,56 @@
+ #include <torch/extension.h>
+ #include <cuda_bf16.h>
+
+ using bf = __nv_bfloat16;
+
+ /* Forward declaration: must match the CUDA side */
+ void cuda_forward_single_step(
+     int B, int H,
+     bf* w, bf* q, bf* k, bf* v, bf* a, bf* b,
+     float* h0, bf* y, float* h1);
+
+ /* PyTorch entry point: only unpacks tensors and converts types */
+ void forward_single_step(
+     torch::Tensor w,   // (B, H, K) bfloat16
+     torch::Tensor q,   // (B, H, K) bfloat16
+     torch::Tensor k,   // (B, H, K) bfloat16
+     torch::Tensor v,   // (B, H, K) bfloat16
+     torch::Tensor a,   // (B, H, K) bfloat16
+     torch::Tensor b,   // (B, H, K) bfloat16
+     torch::Tensor h0,  // (B, H, K, K) float32
+     torch::Tensor y,   // (B, H, K) bfloat16, output
+     torch::Tensor h1)  // (B, H, K, K) float32, output
+ {
+     /* Basic validation */
+     TORCH_CHECK(w.device().is_cuda(), "All tensors must be CUDA");
+     TORCH_CHECK(w.dtype() == torch::kBFloat16, "w/q/k/v/a/b must be bfloat16");
+     TORCH_CHECK(h0.dtype() == torch::kFloat32, "h0/h1 must be float32");
+     TORCH_CHECK(w.is_contiguous(), "All tensors must be contiguous");
+
+     const int B = w.size(0);
+     const int H = w.size(1);
+     const int K = w.size(2);
+
+     cuda_forward_single_step(
+         B, H,
+         reinterpret_cast<bf*>(w.data_ptr()),
+         reinterpret_cast<bf*>(q.data_ptr()),
+         reinterpret_cast<bf*>(k.data_ptr()),
+         reinterpret_cast<bf*>(v.data_ptr()),
+         reinterpret_cast<bf*>(a.data_ptr()),
+         reinterpret_cast<bf*>(b.data_ptr()),
+         h0.data_ptr<float>(),
+         reinterpret_cast<bf*>(y.data_ptr()),
+         h1.data_ptr<float>());
+ }
+
+ /* Register the operator */
+ TORCH_LIBRARY(wind_backstepping_single_step, m) {
+     m.def("forward_single_step("
+           "Tensor w, Tensor q, Tensor k, Tensor v, Tensor a, Tensor b, "
+           "Tensor h0, Tensor(a!) y, Tensor(b!) h1) -> ()");
+ }
+
+ TORCH_LIBRARY_IMPL(wind_backstepping_single_step, CUDA, m) {
+     m.impl("forward_single_step", forward_single_step);
+ }
rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_torch.py
@@ -0,0 +1,112 @@
+ import os
+ import torch
+ from torch.utils.cpp_extension import load
+
+
+ def get_torch_generalized_delta_rule_single_step(HEAD_SIZE=64):
+     flags = [
+         "-res-usage",
+         f"-D_C_={HEAD_SIZE}",
+         "-D_CHUNK_LEN_=1",
+         "--use_fast_math",
+         "-O3",
+         "-Xptxas -O3",
+         "--extra-device-vectorization",
+     ]
+     current_dir = os.path.dirname(os.path.abspath(__file__))
+     load(
+         name="wind_backstepping_single_step",
+         sources=[
+             os.path.join(current_dir, "wkv7_single_step_cuda.cu"),
+             os.path.join(current_dir, "wkv7_single_step_op.cpp"),
+         ],
+         is_python_module=False,
+         verbose=False,
+         extra_cuda_cflags=flags,
+     )
+
+     class WindBacksteppingSingleStep(torch.autograd.Function):
+         @staticmethod
+         def forward(ctx, w, q, k, v, a, b, h0):
+             DTYPE = q.dtype
+             w = w.contiguous().bfloat16()
+             q = q.contiguous().bfloat16()
+             k = k.contiguous().bfloat16()
+             v = v.contiguous().bfloat16()
+             a = a.contiguous().bfloat16()
+             b = b.contiguous().bfloat16()
+             h0 = h0.contiguous().float()
+             y = torch.empty_like(v)
+             h1 = torch.empty_like(h0)
+             torch.ops.wind_backstepping_single_step.forward_single_step(
+                 w, q, k, v, a, b, h0, y, h1
+             )
+             return y.to(DTYPE), h1
+
+         @staticmethod
+         def backward(ctx, *grads):
+             raise NotImplementedError("single-step kernel does not support backward")
+
+     def run_single_step(w, q, k, v, a, b, h0):
+         return WindBacksteppingSingleStep.apply(w, q, k, v, a, b, h0)
+
+     def generalized_delta_rule(
+         r: torch.Tensor,
+         w: torch.Tensor,
+         k: torch.Tensor,
+         v: torch.Tensor,
+         a: torch.Tensor,
+         b: torch.Tensor,
+         *,
+         initial_state: torch.Tensor = None,
+         output_final_state: bool = True,
+         head_first: bool = False,
+     ):
+         """
+         Single-step RWKV7 forward. Input shapes:
+             head_first=False -> (B, 1, H, K)  (default)
+             head_first=True  -> (B, H, 1, K)
+         The output has the same shape as the input.
+         """
+         if w.device.type != "cuda":
+             from ..native_keras_op import generalized_delta_rule
+
+             return generalized_delta_rule(
+                 r=r,
+                 k=k,
+                 v=v,
+                 a=a,
+                 b=b,
+                 w=w,
+                 initial_state=initial_state,
+                 output_final_state=output_final_state,
+             )
+         # 1. Squeeze out the time dimension so every input is (B, H, K)
+         if head_first:  # (B, H, 1, K) -> (B, H, K)
+             r = r.squeeze(2)
+             w = w.squeeze(2)
+             k = k.squeeze(2)
+             v = v.squeeze(2)
+             a = a.squeeze(2)
+             b = b.squeeze(2)
+         else:  # (B, 1, H, K) -> (B, H, K)
+             r = r.squeeze(1)
+             w = w.squeeze(1)
+             k = k.squeeze(1)
+             v = v.squeeze(1)
+             a = a.squeeze(1)
+             b = b.squeeze(1)
+
+         B, H, K = r.shape
+         if initial_state is None:
+             initial_state = torch.zeros(
+                 B, H, K, K, dtype=torch.float32, device=r.device
+             )
+
+         # 2. Run the single-step kernel
+         y, h1 = run_single_step(w, r, k, v, a, b, initial_state)  # y: (B, H, K)
+         y = y.unsqueeze(2) if head_first else y.unsqueeze(1)  # restore the time dim
+
+         return (y, h1) if output_final_state else y
+
+     return generalized_delta_rule
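Illustrative usage (editorial note, not part of the package contents): the factory above returns a single-token step function, so autoregressive decoding is a loop that calls it once per token and feeds the returned state back in. A minimal sketch, assuming a CUDA device and a direct import from the module above; the random tensors stand in for real per-token projections:

    import torch
    from rwkv_ops.rwkv7_kernel.torch_cuda_kernel_single.wkv7_single_step_torch import (
        get_torch_generalized_delta_rule_single_step,
    )

    step = get_torch_generalized_delta_rule_single_step(HEAD_SIZE=64)

    B, H, K = 1, 8, 64
    state = torch.zeros(B, H, K, K, dtype=torch.float32, device="cuda")

    for _ in range(16):  # decode 16 tokens, one at a time
        r, w, k, v, a, b = (
            torch.randn(B, 1, H, K, device="cuda") for _ in range(6)
        )
        y, state = step(r, w, k, v, a, b, initial_state=state)  # y: (B, 1, H, K)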
rwkv_ops/rwkv7_kernel/torch_kernel/__init__.py
@@ -0,0 +1,13 @@
+ from ..torch_kernel.chunk_A_fwd import *
+ from ..torch_kernel.chunk_A_bwd import *
+
+ # ---------- chunk_h ----------
+ from ..torch_kernel.chunk_h_fwd import *
+ from ..torch_kernel.chunk_h_bwd import *
+
+ # ---------- chunk_o ----------
+ from ..torch_kernel.chunk_o_fwd import *
+ from ..torch_kernel.chunk_o_bwd import *
+ from ..torch_kernel.cumsum import *
+ from ..torch_kernel.wy_fast_fwd import *
+ from ..torch_kernel.wy_fast_bwd import *
rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_bwd.py
@@ -0,0 +1,96 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+
+ import torch
+ import triton
+ from ..triton_kernel.chunk_A_bwd import *
+ from ..triton_kernel.utils import is_gather_supported
+ from ..get_torch_devices_info import check_shared_mem
+
+
+ def chunk_dplr_bwd_dqk_intra(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     a: torch.Tensor,
+     b: torch.Tensor,
+     gi: torch.Tensor,
+     ge: torch.Tensor,
+     dAqk: torch.Tensor,
+     dAqb: torch.Tensor,
+     dAak: torch.Tensor,
+     dAab: torch.Tensor,
+     dqg: torch.Tensor,
+     dkg: torch.Tensor,
+     dag: torch.Tensor,
+     dbg: torch.Tensor,
+     dgk_last: torch.Tensor,
+     scale: float = 1.0,
+     chunk_size: int = 16,
+ ):
+     B, T, H, K = q.shape
+     BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
+     BK = (
+         min(64, triton.next_power_of_2(K))
+         if check_shared_mem()
+         else min(32, triton.next_power_of_2(K))
+     )
+
+     NT = triton.cdiv(T, BT)
+     NK = triton.cdiv(K, BK)
+     grid = (NK, NT, B * H)
+
+     dq = torch.empty_like(q)
+     dk = torch.empty_like(k)
+     da = torch.empty_like(a)
+     db = torch.empty_like(b)
+     dgk = torch.empty_like(gi, dtype=torch.float)
+     dgk_offset = torch.empty_like(gi, dtype=torch.float)
+
+     chunk_dplr_bwd_kernel_intra[grid](
+         q=q,
+         k=k,
+         a=a,
+         b=b,
+         gi=gi,
+         ge=ge,
+         dAqk=dAqk,
+         dAqb=dAqb,
+         dAak=dAak,
+         dAab=dAab,
+         dq=dq,
+         dk=dk,
+         dgk=dgk,
+         dgk_offset=dgk_offset,
+         dqg=dqg,
+         dkg=dkg,
+         dag=dag,
+         dbg=dbg,
+         da=da,
+         db=db,
+         scale=scale,
+         T=T,
+         H=H,
+         K=K,
+         BT=BT,
+         BC=BT,
+         BK=BK,
+         GATHER_SUPPORTED=is_gather_supported,
+     )
+
+     dgk_output = torch.empty_like(dgk)
+
+     def grid(meta):
+         return (NT, triton.cdiv(K, meta["BK"]), B * H)
+
+     chunk_dplr_bwd_dgk_kernel[grid](
+         dgk=dgk,
+         dgk_offset=dgk_offset,
+         dgk_last=dgk_last,
+         dgk_output=dgk_output,
+         T=T,
+         H=H,
+         K=K,
+         BT=BT,
+     )
+     return dq, dk, da, db, dgk_output
rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_fwd.py
@@ -0,0 +1,64 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+
+ import torch
+ import triton
+
+ from ..triton_kernel.utils import is_gather_supported
+
+ from ..triton_kernel.chunk_A_fwd import *
+
+
+ def chunk_dplr_fwd_intra(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     a: torch.Tensor,
+     b: torch.Tensor,
+     gi: torch.Tensor,
+     ge: torch.Tensor,
+     scale: float,
+     chunk_size: int,
+ ):
+     B, T, H, K = k.shape
+     BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
+
+     NT = triton.cdiv(T, BT)
+
+     Aqk = q.new_empty(B, T, H, BT, dtype=q.dtype)
+     Aqb = q.new_empty(B, T, H, BT, dtype=q.dtype)
+     # These feed a matrix inverse, so it is better to keep them in float.
+     Aab = q.new_empty(B, T, H, BT, dtype=torch.float)
+     Aak = q.new_empty(B, T, H, BT, dtype=torch.float)
+
+     grid = (NT, B, H)
+     BK = triton.next_power_of_2(K)
+     qg = torch.empty_like(q)
+     kg = torch.empty_like(k, dtype=q.dtype)
+     ag = torch.empty_like(a, dtype=q.dtype)
+     bg = torch.empty_like(b, dtype=q.dtype)
+     chunk_dplr_fwd_A_kernel_intra_sub_intra[grid](
+         q=q,
+         k=k,
+         a=a,
+         b=b,
+         gi=gi,
+         ge=ge,
+         Aqk=Aqk,
+         Aqb=Aqb,
+         Aab=Aab,
+         Aak=Aak,
+         qg=qg,
+         kg=kg,
+         ag=ag,
+         bg=bg,
+         scale=scale,
+         T=T,
+         H=H,
+         K=K,
+         BT=BT,
+         BC=BT,
+         BK=BK,
+         GATHER_SUPPORTED=is_gather_supported,
+     )
+     return Aab, Aqk, Aak, Aqb, qg, kg, ag, bg