rwkv-ops 0.6.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rwkv_ops/__init__.py +45 -0
- rwkv_ops/mhc_kernel/__init__.py +50 -0
- rwkv_ops/mhc_kernel/common_kernel/include/mhc_types.h +66 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_post_op.cuh +197 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_pre_op.cuh +212 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/rmsnorm.cuh +152 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/sinkhorn_knopp.cuh +158 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/stream_aggregate.cuh +141 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/stream_distribute.cuh +111 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/stream_mix.cuh +164 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/type_conversions.cuh +52 -0
- rwkv_ops/mhc_kernel/jax_kernel/CMakeLists.txt +47 -0
- rwkv_ops/mhc_kernel/jax_kernel/mhu_ffi.cu +652 -0
- rwkv_ops/mhc_kernel/jax_kernel/mhu_jax.py +939 -0
- rwkv_ops/mhc_kernel/native_keras_op.py +193 -0
- rwkv_ops/mhc_kernel/torch_kernel/mhc_cuda.cu +207 -0
- rwkv_ops/mhc_kernel/torch_kernel/mhc_op.cpp +296 -0
- rwkv_ops/mhc_kernel/torch_kernel/mhc_torch.py +306 -0
- rwkv_ops/rwkv6_kernel/__init__.py +120 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/gpu_ops.cpp +44 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernel_helpers.h +64 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernels.h +56 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/pybind11_kernel_helpers.h +41 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/rwkv_kernels.cu +512 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/gpu_ops.cpp +44 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernel_helpers.h +64 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernels.h +56 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/pybind11_kernel_helpers.h +41 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/rwkv_kernels.hip +514 -0
- rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py +722 -0
- rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py +90 -0
- rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_cuda.cu +397 -0
- rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_op.cpp +93 -0
- rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py +305 -0
- rwkv_ops/rwkv7_kernel/__init__.py +113 -0
- rwkv_ops/rwkv7_kernel/get_jax_devices_info.py +220 -0
- rwkv_ops/rwkv7_kernel/get_torch_devices_info.py +250 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel/CMakeLists.txt +42 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_ffi.cu +399 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_jax.py +311 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/CMakeLists.txt +42 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_ffi.cu +172 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_jax.py +190 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/__init__.py +9 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_bwd.py +95 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_fwd.py +60 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_bwd.py +78 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_fwd.py +80 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_bwd.py +150 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_fwd.py +45 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/cumsum.py +34 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_bwd.py +61 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_fwd.py +86 -0
- rwkv_ops/rwkv7_kernel/jax_op.py +382 -0
- rwkv_ops/rwkv7_kernel/mlx_op.py +118 -0
- rwkv_ops/rwkv7_kernel/native_keras_op.py +108 -0
- rwkv_ops/rwkv7_kernel/tf_eager_kernel.py +155 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_cuda.cu +235 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_op.cpp +63 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_torch.py +233 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_cuda.cu +101 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_op.cpp +56 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_torch.py +112 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/__init__.py +13 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_bwd.py +96 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_fwd.py +64 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_bwd.py +74 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_fwd.py +75 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_bwd.py +148 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_fwd.py +44 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/cumsum.py +31 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_bwd.py +63 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_fwd.py +79 -0
- rwkv_ops/rwkv7_kernel/torch_op.py +504 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/__init__.py +34 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_bwd.py +328 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_fwd.py +186 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_bwd.py +157 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_fwd.py +160 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_bwd.py +382 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_fwd.py +137 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/cumsum.py +86 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/utils.py +20 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_bwd.py +193 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_fwd.py +326 -0
- rwkv_ops-0.6.1.dist-info/METADATA +495 -0
- rwkv_ops-0.6.1.dist-info/RECORD +89 -0
- rwkv_ops-0.6.1.dist-info/WHEEL +4 -0
- rwkv_ops-0.6.1.dist-info/licenses/LICENSE.txt +201 -0
|
@@ -0,0 +1,233 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import torch
|
|
3
|
+
from torch.utils.cpp_extension import load
|
|
4
|
+
from keras.src.backend.torch.core import cast
|
|
5
|
+
from keras.src.backend.torch.numpy import transpose, zeros
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def transpose_head(x, head_first):
    """Swap axes 1 and 2 of ``x`` when ``head_first`` is truthy.

    With a falsy ``head_first`` the input is returned untouched.
    """
    if not head_first:
        return x
    return transpose(x, (0, 2, 1, 3))
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def get_torch_generalized_delta_rule(HEAD_SIZE=64):
    """Build the CUDA-backed RWKV7 generalized delta rule operators.

    JIT-compiles ``wkv7_cuda.cu`` and ``wkv7_op.cpp`` (expected next to this
    file) into the ``wind_backstepping`` torch op library and defines the
    autograd wrappers that call into it.

    Args:
        HEAD_SIZE: Forwarded to the CUDA compiler as ``-D_C_``. Presumably
            the per-head channel size the kernels are specialised for --
            confirm against the ``.cu`` source.

    Returns:
        ``[generalized_delta_rule, generalized_delta_rule_inference]``.
    """
    CHUNK_LEN = 16
    flags = [
        "-res-usage",
        f"-D_C_={HEAD_SIZE}",
        f"-D_CHUNK_LEN_={CHUNK_LEN}",
        "--use_fast_math",
        "-O3",
        "-Xptxas -O3",
        "--extra-device-vectorization",
    ]
    # Absolute path of this file.
    current_file_path = os.path.abspath(__file__)

    # Directory that contains this file (and the kernel sources).
    current_dir_path = os.path.dirname(current_file_path)
    load(
        name="wind_backstepping",
        sources=[
            os.path.join(current_dir_path, "wkv7_cuda.cu"),
            os.path.join(current_dir_path, "wkv7_op.cpp"),
        ],
        is_python_module=False,
        verbose=True,
        extra_cuda_cflags=flags,
    )

    class WindBackstepping(torch.autograd.Function):
        """Autograd wrapper around the ``wind_backstepping`` forward/backward ops."""

        @staticmethod
        def forward(ctx, w, q, k, v, z, b, h0):
            B, T, H, N = w.shape
            # Remember the caller's dtype so the output can be cast back.
            DTYPE = q.dtype
            q = cast(q, "bfloat16")
            k = cast(k, "bfloat16")
            v = cast(v, "bfloat16")
            z = cast(z, "bfloat16")
            b = cast(b, "bfloat16")
            w = cast(w, "bfloat16")
            if T % CHUNK_LEN != 0:
                raise ValueError(
                    "RWKV输入的序列长度必须可以被16整除"
                    "Please make sure the sequence length is divisible by 16"
                )
            assert all(i.is_contiguous() for i in [w, q, k, v, z, b])
            y = torch.empty_like(v)
            # Buffers filled by the forward op and reused by backward.
            s = torch.empty(
                B, H, T // CHUNK_LEN, N, N, dtype=torch.float32, device=w.device
            )
            sa = torch.empty(B, T, H, N, dtype=torch.float32, device=w.device)
            torch.ops.wind_backstepping.forward(w, q, k, v, z, b, y, s, sa, h0)
            ctx.save_for_backward(w, q, k, v, z, b, s, sa)
            # The returned state is the last slice of ``s`` along the chunk
            # axis with its two trailing axes swapped.
            last_state = torch.empty_like(h0)
            last_state.copy_(transpose(s[:, :, -1], [0, 1, 3, 2]))

            return cast(y, DTYPE), last_state

        @staticmethod
        def backward(ctx, dy, dht):
            DTYPE = dy.dtype
            dy = cast(dy, torch.bfloat16)
            dy = dy.contiguous()

            w, q, k, v, z, b, s, sa = ctx.saved_tensors
            dht = cast(dht, "float32")
            dht = dht.contiguous()
            assert dy.dtype == torch.bfloat16
            assert all(i.is_contiguous() for i in [dy, dht])
            dh0 = torch.empty(dht.shape, dtype=dht.dtype, device=dht.device)
            dw, dq, dk, dv, dz, db = [
                torch.empty_like(x) for x in [w, q, k, v, z, b]
            ]

            torch.ops.wind_backstepping.backward(
                w, q, k, v, z, b, dy, s, sa, dht, dh0, dw, dq, dk, dv, dz, db
            )
            # Gradients are returned in the order of forward's inputs
            # (w, q, k, v, z, b, h0), cast back to the incoming grad dtype.
            return (
                cast(dw, DTYPE),
                cast(dq, DTYPE),
                cast(dk, DTYPE),
                cast(dv, DTYPE),
                cast(dz, DTYPE),
                cast(db, DTYPE),
                dh0,
            )

    def RUN_CUDA_RWKV7g(q, w, k, v, a, b, h0):
        """Make the inputs contiguous and run the trainable kernel."""
        # Unpacking doubles as a check that q is 4-dimensional.
        B, T, H, C = q.shape
        q = q.contiguous()
        w = w.contiguous()
        k = k.contiguous()
        v = v.contiguous()
        a = a.contiguous()
        b = b.contiguous()
        out, state = WindBackstepping.apply(w, q, k, v, a, b, h0)
        return out, state

    def generalized_delta_rule(
        r: torch.Tensor,
        w: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        a: torch.Tensor,
        b: torch.Tensor,
        initial_state: torch.Tensor = None,
        output_final_state: bool = True,
        head_first: bool = False,
        use_chunk: bool = True,
    ):
        """Run the generalized delta rule with autograd support.

        Tensors that do not live on a CUDA device are delegated to the
        implementation in ``native_keras_op``. ``use_chunk`` is accepted
        but not read by this implementation.

        Returns:
            ``(out, state)`` when ``output_final_state`` is true, otherwise
            ``out`` alone.
        """
        if w.device.type != "cuda":
            from ..native_keras_op import generalized_delta_rule

            return generalized_delta_rule(
                r=r,
                k=k,
                v=v,
                a=a,
                b=b,
                w=w,
                initial_state=initial_state,
                output_final_state=output_final_state,
            )
        r = transpose_head(r, head_first)
        k = transpose_head(k, head_first)
        v = transpose_head(v, head_first)
        a = transpose_head(a, head_first)
        b = transpose_head(b, head_first)
        w = transpose_head(w, head_first)
        B, T, H, N = w.shape
        if initial_state is None:
            initial_state = zeros((B, H, N, N), "float32")
        else:
            initial_state = cast(initial_state, "float32")
        out, state = RUN_CUDA_RWKV7g(r, w, k, v, a, b, initial_state)
        if output_final_state:
            return out, state
        return out

    class Wkv7Inference(torch.autograd.Function):
        """Forward-only wrapper around the ``forward_inference`` op."""

        @staticmethod
        def forward(ctx, w, q, k, v, a, b, h0):
            B, T, H, N = w.shape
            DTYPE = q.dtype

            # Dtype conversion.
            q = cast(q, "bfloat16")
            k = cast(k, "bfloat16")
            v = cast(v, "bfloat16")
            a = cast(a, "bfloat16")
            b = cast(b, "bfloat16")
            w = cast(w, "bfloat16")

            assert all(i.is_contiguous() for i in [w, q, k, v, a, b])

            # Key difference from the training path: ``s`` has shape
            # (B, H, N, N) instead of (B, H, chunk_num, N, N).
            y = torch.empty_like(v)
            s = torch.empty(B, H, N, N, dtype=torch.float32, device=w.device)

            # Call the inference op (no ``sa`` buffer).
            torch.ops.wind_backstepping.forward_inference(
                w, q, k, v, a, b, y, s, h0
            )

            return cast(y, DTYPE), s

        @staticmethod
        def backward(ctx, dy, dht):
            raise NotImplementedError("Inference kernel does not support backward")

    def RUN_CUDA_RWKV7g_inference(q, w, k, v, a, b, h0):
        """Make the inputs contiguous and run the inference-only kernel."""
        # Unpacking doubles as a check that q is 4-dimensional.
        B, T, H, C = q.shape
        q = q.contiguous()
        w = w.contiguous()
        k = k.contiguous()
        v = v.contiguous()
        a = a.contiguous()
        b = b.contiguous()
        out, state = Wkv7Inference.apply(w, q, k, v, a, b, h0)
        return out, state

    # -------------------- Public inference API --------------------
    def generalized_delta_rule_inference(
        r: torch.Tensor,
        w: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        a: torch.Tensor,
        b: torch.Tensor,
        initial_state: torch.Tensor = None,
        head_first: bool = False,
        output_final_state: bool = True,
    ):
        """Inference-only variant (no backward pass).

        The original author notes a 90%+ reduction in GPU memory use
        compared with the trainable path.

        Args:
            r, w, k, v, a, b: Input tensors of shape (B, T, H, K), or
                (B, H, T, K) when ``head_first`` is true.
            initial_state: (B, H, K, K) initial state; zero-initialised
                when ``None``.
            head_first: Whether the head axis precedes the time axis.
            output_final_state: Whether to also return the final state.

        Returns:
            ``(out, final_state)`` when ``output_final_state`` is true,
            otherwise ``out`` alone. ``out`` is (B, T, H, K) and
            ``final_state`` is (B, H, K, K).

        Raises:
            NotImplementedError: If ``w`` is not on a CUDA device.
        """
        if w.device.type != "cuda":
            raise NotImplementedError("Inference kernel only supports CUDA")

        r = transpose_head(r, head_first)
        k = transpose_head(k, head_first)
        v = transpose_head(v, head_first)
        a = transpose_head(a, head_first)
        b = transpose_head(b, head_first)
        w = transpose_head(w, head_first)

        B, T, H, N = w.shape
        if initial_state is None:
            initial_state = zeros((B, H, N, N), "float32")
        else:
            initial_state = cast(initial_state, "float32")

        out, final_state = RUN_CUDA_RWKV7g_inference(
            r, w, k, v, a, b, initial_state
        )
        # An explicit branch is required here: ``return out, final_state if
        # flag else out`` binds the conditional to the second tuple element
        # only, which yielded ``(out, out)`` when the flag was false.
        if output_final_state:
            return out, final_state
        return out

    # Return both functions; callers pick the one they need.
    return [generalized_delta_rule, generalized_delta_rule_inference]
|
|
@@ -0,0 +1,101 @@
|
|
|
1
|
+
#include <cuda_bf16.h>
|
|
2
|
+
#include <assert.h>
|
|
3
|
+
#include <cstdint>
|
|
4
|
+
|
|
5
|
+
using bf = __nv_bfloat16;
|
|
6
|
+
|
|
7
|
+
// Widen a bfloat16 value to float.
__device__ inline float to_float(const bf &u) {
    const float widened = __bfloat162float(u);
    return widened;
}
|
|
10
|
+
|
|
11
|
+
// Narrow a float to bfloat16 using the round-to-nearest intrinsic.
__device__ inline bf to_bf(const float &u) {
    const bf narrowed = __float2bfloat16_rn(u);
    return narrowed;
}
|
|
14
|
+
|
|
15
|
+
typedef bf *__restrict__ F_;
|
|
16
|
+
|
|
17
|
+
// Single-step forward kernel for T=1
|
|
18
|
+
template<int C>
__launch_bounds__(C, 2)
__global__ void forward_single_step_kernel(
    int H,            // Number of heads
    F_ w_, F_ q_, F_ k_, F_ v_, F_ a_, F_ b_,
    float *h0_,       // (B, H, C, C) - input state
    bf *y_,           // (B, H, C)    - output
    float *h1_        // (B, H, C, C) - output state
) {
    // One block per (head, batch) pair; one thread per state row.
    const int batch = blockIdx.y;
    const int head = blockIdx.x;
    const int row = threadIdx.x;   // 0..C-1

    // Offset of element `row` in the (B, H, C) vectors; also used for y_.
    const int64_t vec_idx = (int64_t)batch * H * C + head * C + row;
    // Offset of row `row` in the (B, H, C, C) state; same for h0_ and h1_.
    const int64_t row_base = (int64_t)batch * H * C * C + head * C * C + row * C;

    // v is only needed by the thread that owns this row.
    const float v_val = to_float(v_[vec_idx]);

    // Copy this thread's state row from h0_ into a per-thread array.
    float state_row[C];
#pragma unroll
    for (int j = 0; j < C; j++) {
        state_row[j] = h0_[row_base + j];
    }

    // Each thread publishes one element of every vector to shared memory so
    // that all threads of the block can read the full vectors.
    __shared__ float shared_a[C], shared_b[C], shared_w[C], shared_k[C], shared_q[C];

    shared_a[row] = to_float(a_[vec_idx]);
    shared_b[row] = to_float(b_[vec_idx]);
    shared_w[row] = __expf(-__expf(to_float(w_[vec_idx])));   // Decay factor
    shared_k[row] = to_float(k_[vec_idx]);
    shared_q[row] = to_float(q_[vec_idx]);
    __syncthreads();

    // sa = sum_j(a[j] * state[row][j])
    float sa = 0.0f;
#pragma unroll
    for (int j = 0; j < C; j++) {
        sa += shared_a[j] * state_row[j];
    }

    // Update the state row and accumulate output element `row`.
    float y = 0.0f;
#pragma unroll
    for (int j = 0; j < C; j++) {
        state_row[j] = state_row[j] * shared_w[j] + sa * shared_b[j] + shared_k[j] * v_val;
        y += state_row[j] * shared_q[j];
    }

    // Store y[row].
    y_[vec_idx] = to_bf(y);

    // Store the updated state row into h1_.
#pragma unroll
    for (int j = 0; j < C; j++) {
        h1_[row_base + j] = state_row[j];
    }
}
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
// Host-side launcher: a (H, B) grid of blocks with _C_ threads each.
void cuda_forward_single_step(
    int B, int H,
    bf *w, bf *q, bf *k, bf *v, bf *a, bf *b,
    float *h0, bf *y, float *h1
) {
    const dim3 grid_dim(H, B);   // (num_heads, batch_size)
    const dim3 block_dim(_C_);   // HEAD_SIZE

    forward_single_step_kernel<_C_><<<grid_dim, block_dim>>>(
        H, w, q, k, v, a, b, h0, y, h1);
}
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
#include <torch/extension.h>
|
|
2
|
+
#include <cuda_bf16.h>
|
|
3
|
+
|
|
4
|
+
using bf = __nv_bfloat16;
|
|
5
|
+
|
|
6
|
+
/* 前向声明:与 CUDA 侧一致 */
|
|
7
|
+
void cuda_forward_single_step(
|
|
8
|
+
int B, int H,
|
|
9
|
+
bf* w, bf* q, bf* k, bf* v, bf* a, bf* b,
|
|
10
|
+
float* h0, bf* y, float* h1);
|
|
11
|
+
|
|
12
|
+
/* PyTorch 入口:只负责张量解包与类型转换 */
|
|
13
|
+
void forward_single_step(
    torch::Tensor w,   // (B, H, K)    bfloat16
    torch::Tensor q,   // (B, H, K)    bfloat16
    torch::Tensor k,   // (B, H, K)    bfloat16
    torch::Tensor v,   // (B, H, K)    bfloat16
    torch::Tensor a,   // (B, H, K)    bfloat16
    torch::Tensor b,   // (B, H, K)    bfloat16
    torch::Tensor h0,  // (B, H, K, K) float32
    torch::Tensor y,   // (B, H, K)    bfloat16, output
    torch::Tensor h1)  // (B, H, K, K) float32, output
{
    // Validate every tensor, not only w/h0: the launcher below receives raw
    // pointers, so a tensor with the wrong device, dtype, layout or shape
    // would be read or written with the wrong layout or out of bounds.
    TORCH_CHECK(w.device().is_cuda(), "All tensors must be CUDA");
    TORCH_CHECK(w.dim() == 3, "w/q/k/v/a/b/y must have shape (B, H, K)");

    const torch::Tensor* bf16_tensors[] = {&w, &q, &k, &v, &a, &b, &y};
    for (const torch::Tensor* t : bf16_tensors) {
        TORCH_CHECK(t->device() == w.device(), "All tensors must be CUDA");
        TORCH_CHECK(t->dtype() == torch::kBFloat16, "w/q/k/v/a/b must be bfloat16");
        TORCH_CHECK(t->is_contiguous(), "All tensors must be contiguous");
        TORCH_CHECK(t->sizes().equals(w.sizes()),
                    "w/q/k/v/a/b/y must share the shape (B, H, K)");
    }

    const int64_t K = w.size(2);
    const int64_t state_shape[] = {w.size(0), w.size(1), K, K};
    const torch::Tensor* state_tensors[] = {&h0, &h1};
    for (const torch::Tensor* t : state_tensors) {
        TORCH_CHECK(t->device() == w.device(), "All tensors must be CUDA");
        TORCH_CHECK(t->dtype() == torch::kFloat32, "h0/h1 must be float32");
        TORCH_CHECK(t->is_contiguous(), "All tensors must be contiguous");
        TORCH_CHECK(t->sizes().equals(at::IntArrayRef(state_shape)),
                    "h0/h1 must have shape (B, H, K, K)");
    }

    // NOTE(review): K is not compared against the kernel's compile-time head
    // size here; that macro does not appear to be defined for this
    // translation unit -- confirm against the build flags.
    const int B = static_cast<int>(w.size(0));
    const int H = static_cast<int>(w.size(1));

    cuda_forward_single_step(
        B, H,
        reinterpret_cast<bf*>(w.data_ptr()),
        reinterpret_cast<bf*>(q.data_ptr()),
        reinterpret_cast<bf*>(k.data_ptr()),
        reinterpret_cast<bf*>(v.data_ptr()),
        reinterpret_cast<bf*>(a.data_ptr()),
        reinterpret_cast<bf*>(b.data_ptr()),
        h0.data_ptr<float>(),
        reinterpret_cast<bf*>(y.data_ptr()),
        h1.data_ptr<float>());
}
|
|
46
|
+
|
|
47
|
+
/* 注册算子 */
|
|
48
|
+
TORCH_LIBRARY(wind_backstepping_single_step, m) {
    // Declare the op schema: y and h1 are annotated as mutated in place and
    // the op itself returns nothing.
    m.def(
        "forward_single_step("
        "Tensor w, Tensor q, Tensor k, Tensor v, Tensor a, Tensor b, "
        "Tensor h0, Tensor(a!) y, Tensor(b!) h1) -> ()"
    );
}
|
|
53
|
+
|
|
54
|
+
TORCH_LIBRARY_IMPL(wind_backstepping_single_step, CUDA, m) {
    // Bind the CUDA dispatch key of the declared op to the C++ entry point.
    m.impl("forward_single_step", &forward_single_step);
}
|
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import torch
|
|
3
|
+
from torch.utils.cpp_extension import load
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def get_torch_generalized_delta_rule_single_step(HEAD_SIZE=64):
    """Build the CUDA-backed single-step (T == 1) RWKV7 operator.

    JIT-compiles ``wkv7_single_step_cuda.cu`` and ``wkv7_single_step_op.cpp``
    (expected next to this file) into the ``wind_backstepping_single_step``
    torch op library.

    Args:
        HEAD_SIZE: Forwarded to the CUDA compiler as ``-D_C_``.

    Returns:
        The ``generalized_delta_rule`` function defined below.
    """
    flags = [
        "-res-usage",
        f"-D_C_={HEAD_SIZE}",
        "-D_CHUNK_LEN_=1",
        "--use_fast_math",
        "-O3",
        "-Xptxas -O3",
        "--extra-device-vectorization",
    ]
    current_dir = os.path.dirname(os.path.abspath(__file__))
    load(
        name="wind_backstepping_single_step",
        sources=[
            os.path.join(current_dir, "wkv7_single_step_cuda.cu"),
            os.path.join(current_dir, "wkv7_single_step_op.cpp"),
        ],
        is_python_module=False,
        verbose=False,
        extra_cuda_cflags=flags,
    )

    class WindBacksteppingSingleStep(torch.autograd.Function):
        """Forward-only wrapper around the ``forward_single_step`` op."""

        @staticmethod
        def forward(ctx, w, q, k, v, a, b, h0):
            # Remember the caller's dtype so the output can be cast back.
            DTYPE = q.dtype
            w = w.contiguous().bfloat16()
            q = q.contiguous().bfloat16()
            k = k.contiguous().bfloat16()
            v = v.contiguous().bfloat16()
            a = a.contiguous().bfloat16()
            b = b.contiguous().bfloat16()
            h0 = h0.contiguous().float()
            # Output buffers filled in place by the op.
            y = torch.empty_like(v)
            h1 = torch.empty_like(h0)
            torch.ops.wind_backstepping_single_step.forward_single_step(
                w, q, k, v, a, b, h0, y, h1
            )
            return y.to(DTYPE), h1

        @staticmethod
        def backward(ctx, *grads):
            raise NotImplementedError("single-step kernel does not support backward")

    def run_single_step(w, q, k, v, a, b, h0):
        return WindBacksteppingSingleStep.apply(w, q, k, v, a, b, h0)

    def generalized_delta_rule(
        r: torch.Tensor,
        w: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        a: torch.Tensor,
        b: torch.Tensor,
        *,
        initial_state: torch.Tensor = None,
        output_final_state: bool = True,
        head_first: bool = False,
    ):
        """Single-step RWKV7 forward pass.

        Input shapes:
            head_first=False -> (B, 1, H, K)  (default)
            head_first=True  -> (B, H, 1, K)
        The output has the same layout as the input.

        Tensors that do not live on a CUDA device are delegated to the
        implementation in ``native_keras_op``.

        Returns:
            ``(y, h1)`` when ``output_final_state`` is true, otherwise ``y``.
        """
        if w.device.type != "cuda":
            from ..native_keras_op import generalized_delta_rule

            # NOTE(review): head_first is not forwarded to the fallback --
            # confirm whether the native op needs it.
            return generalized_delta_rule(
                r=r,
                k=k,
                v=v,
                a=a,
                b=b,
                w=w,
                initial_state=initial_state,
                output_final_state=output_final_state,
            )
        # 1. Drop the length-1 time axis so every input becomes (B, H, K).
        #    (B, H, 1, K) -> axis 2, (B, 1, H, K) -> axis 1.
        time_axis = 2 if head_first else 1
        r = r.squeeze(time_axis)
        w = w.squeeze(time_axis)
        k = k.squeeze(time_axis)
        v = v.squeeze(time_axis)
        a = a.squeeze(time_axis)
        b = b.squeeze(time_axis)

        B, H, K = r.shape
        if initial_state is None:
            initial_state = torch.zeros(
                B, H, K, K, dtype=torch.float32, device=r.device
            )

        # 2. Compute.
        y, h1 = run_single_step(w, r, k, v, a, b, initial_state)  # y: (B, H, K)
        # Re-insert the time axis where it was removed so the output layout
        # matches the input; inserting it at axis 1 unconditionally returned
        # (B, 1, H, K) even for head_first inputs.
        y = y.unsqueeze(time_axis)

        return (y, h1) if output_final_state else y

    return generalized_delta_rule
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
from ..torch_kernel.chunk_A_fwd import *
|
|
2
|
+
from ..torch_kernel.chunk_A_bwd import *
|
|
3
|
+
|
|
4
|
+
# ---------- chunk_h ----------
|
|
5
|
+
from ..torch_kernel.chunk_h_fwd import *
|
|
6
|
+
from ..torch_kernel.chunk_h_bwd import *
|
|
7
|
+
|
|
8
|
+
# ---------- chunk_o ----------
|
|
9
|
+
from ..torch_kernel.chunk_o_fwd import *
|
|
10
|
+
from ..torch_kernel.chunk_o_bwd import *
|
|
11
|
+
from ..torch_kernel.cumsum import *
|
|
12
|
+
from ..torch_kernel.wy_fast_fwd import *
|
|
13
|
+
from ..torch_kernel.wy_fast_bwd import *
|
|
@@ -0,0 +1,96 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
import triton
|
|
7
|
+
from ..triton_kernel.chunk_A_bwd import *
|
|
8
|
+
from ..triton_kernel.utils import is_gather_supported
|
|
9
|
+
from ..get_torch_devices_info import check_shared_mem
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def chunk_dplr_bwd_dqk_intra(
    q: torch.Tensor,
    k: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    gi: torch.Tensor,
    ge: torch.Tensor,
    dAqk: torch.Tensor,
    dAqb: torch.Tensor,
    dAak: torch.Tensor,
    dAab: torch.Tensor,
    dqg: torch.Tensor,
    dkg: torch.Tensor,
    dag: torch.Tensor,
    dbg: torch.Tensor,
    dgk_last: torch.Tensor,
    scale: float = 1.0,
    chunk_size: int = 16,
):
    """Launch the two Triton kernels of the intra-chunk backward pass.

    ``chunk_dplr_bwd_kernel_intra`` fills ``dq``, ``dk``, ``da``, ``db`` plus
    two float scratch buffers; ``chunk_dplr_bwd_dgk_kernel`` then combines
    those buffers with ``dgk_last`` into ``dgk_output``.

    Returns:
        ``(dq, dk, da, db, dgk_output)``.
    """
    B, T, H, K = q.shape
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
    # The K-block is capped at 64 or 32 depending on check_shared_mem().
    bk_limit = 64 if check_shared_mem() else 32
    BK = min(bk_limit, triton.next_power_of_2(K))

    NT = triton.cdiv(T, BT)
    NK = triton.cdiv(K, BK)

    # Outputs of the first kernel.
    dq, dk = torch.empty_like(q), torch.empty_like(k)
    da, db = torch.empty_like(a), torch.empty_like(b)
    dgk = torch.empty_like(gi, dtype=torch.float)
    dgk_offset = torch.empty_like(gi, dtype=torch.float)

    intra_grid = (NK, NT, B * H)
    chunk_dplr_bwd_kernel_intra[intra_grid](
        q=q, k=k, a=a, b=b,
        gi=gi, ge=ge,
        dAqk=dAqk, dAqb=dAqb, dAak=dAak, dAab=dAab,
        dq=dq, dk=dk,
        dgk=dgk, dgk_offset=dgk_offset,
        dqg=dqg, dkg=dkg, dag=dag, dbg=dbg,
        da=da, db=db,
        scale=scale,
        T=T, H=H, K=K,
        BT=BT, BC=BT, BK=BK,
        GATHER_SUPPORTED=is_gather_supported,
    )

    dgk_output = torch.empty_like(dgk)

    # The second launch reads BK from the kernel's meta-parameters, so its
    # grid is computed lazily.
    dgk_grid = lambda meta: (NT, triton.cdiv(K, meta["BK"]), B * H)
    chunk_dplr_bwd_dgk_kernel[dgk_grid](
        dgk=dgk,
        dgk_offset=dgk_offset,
        dgk_last=dgk_last,
        dgk_output=dgk_output,
        T=T, H=H, K=K,
        BT=BT,
    )
    return dq, dk, da, db, dgk_output
|
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
import triton
|
|
7
|
+
|
|
8
|
+
from ..triton_kernel.utils import is_gather_supported
|
|
9
|
+
|
|
10
|
+
from ..triton_kernel.chunk_A_fwd import *
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def chunk_dplr_fwd_intra(
    q: torch.Tensor,
    k: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    gi: torch.Tensor,
    ge: torch.Tensor,
    scale: float,
    chunk_size: int,
):
    """Launch the Triton kernel of the intra-chunk forward pass.

    Allocates the four ``(B, T, H, BT)`` ``A*`` buffers and the ``qg``,
    ``kg``, ``ag``, ``bg`` buffers, then lets
    ``chunk_dplr_fwd_A_kernel_intra_sub_intra`` fill them.

    Returns:
        ``(Aab, Aqk, Aak, Aqb, qg, kg, ag, bg)``.
    """
    B, T, H, K = k.shape
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
    NT = triton.cdiv(T, BT)
    BK = triton.next_power_of_2(K)

    block_shape = (B, T, H, BT)
    Aqk = q.new_empty(block_shape, dtype=q.dtype)
    Aqb = q.new_empty(block_shape, dtype=q.dtype)
    # These two feed a matrix inverse, so they are kept in float.
    Aab = q.new_empty(block_shape, dtype=torch.float)
    Aak = q.new_empty(block_shape, dtype=torch.float)

    # The remaining outputs all use q's dtype.
    qg = torch.empty_like(q)
    kg = torch.empty_like(k, dtype=q.dtype)
    ag = torch.empty_like(a, dtype=q.dtype)
    bg = torch.empty_like(b, dtype=q.dtype)

    launch_grid = (NT, B, H)
    chunk_dplr_fwd_A_kernel_intra_sub_intra[launch_grid](
        q=q, k=k, a=a, b=b,
        gi=gi, ge=ge,
        Aqk=Aqk, Aqb=Aqb, Aab=Aab, Aak=Aak,
        qg=qg, kg=kg, ag=ag, bg=bg,
        scale=scale,
        T=T, H=H, K=K,
        BT=BT, BC=BT, BK=BK,
        GATHER_SUPPORTED=is_gather_supported,
    )
    return Aab, Aqk, Aak, Aqb, qg, kg, ag, bg
|