rwkv-ops 0.2.2__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of rwkv-ops might be problematic.
- rwkv_ops/__init__.py +5 -6
- rwkv_ops/rwkv6_kernel/__init__.py +0 -6
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/gpu_ops.cpp +44 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernel_helpers.h +64 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernels.h +56 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/pybind11_kernel_helpers.h +41 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/rwkv_kernels.cu +512 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/gpu_ops.cpp +44 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernel_helpers.h +64 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernels.h +56 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/pybind11_kernel_helpers.h +41 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/rwkv_kernels.hip +514 -0
- rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py +21 -23
- rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py +14 -10
- rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_cuda.cu +397 -0
- rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_op.cpp +93 -0
- rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py +4 -4
- rwkv_ops/rwkv7_kernel/__init__.py +77 -29
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel/CMakeLists.txt +42 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_ffi.cu +279 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_jax.py +237 -0
- rwkv_ops/rwkv7_kernel/jax_op.py +6 -5
- rwkv_ops/rwkv7_kernel/native_keras_op.py +5 -6
- rwkv_ops/rwkv7_kernel/tf_eager_kernel.py +123 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_cuda.cu +165 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_op.cpp +35 -0
- {rwkv_ops-0.2.2.dist-info → rwkv_ops-0.3.0.dist-info}/METADATA +28 -27
- {rwkv_ops-0.2.2.dist-info → rwkv_ops-0.3.0.dist-info}/RECORD +30 -13
- {rwkv_ops-0.2.2.dist-info → rwkv_ops-0.3.0.dist-info}/WHEEL +1 -2
- rwkv_ops-0.2.2.dist-info/top_level.txt +0 -1
- {rwkv_ops-0.2.2.dist-info → rwkv_ops-0.3.0.dist-info/licenses}/LICENSE.txt +0 -0
rwkv_ops/rwkv7_kernel/__init__.py
@@ -12,18 +12,17 @@ def transpose_head(x, head_first):
 
 
 def get_generalized_delta_rule(HEAD_SIZE=64, KERNEL_TYPE="native"):
-
+    USE_TRITON_KERNEL = False
     if keras.config.backend() == "torch":
         import torch
 
         if KERNEL_TYPE.lower() == "triton":
             from .torch_op import generalized_delta_rule
 
-
+            USE_TRITON_KERNEL = True
 
         elif KERNEL_TYPE.lower() == "cuda":
             CHUNK_LEN = 16
-            USE_KERNEL = True
             from torch.utils.cpp_extension import load
             import os
 
@@ -44,8 +43,8 @@ def get_generalized_delta_rule(HEAD_SIZE=64, KERNEL_TYPE="native"):
             load(
                 name="wind_backstepping",
                 sources=[
-                    os.path.join(current_dir_path, "
-                    os.path.join(current_dir_path, "
+                    os.path.join(current_dir_path, "torch_cuda_kernel/wkv7_cuda.cu"),
+                    os.path.join(current_dir_path, "torch_cuda_kernel/wkv7_op.cpp"),
                 ],
                 is_python_module=False,
                 verbose=True,
@@ -54,8 +53,8 @@ def get_generalized_delta_rule(HEAD_SIZE=64, KERNEL_TYPE="native"):
 
             class WindBackstepping(torch.autograd.Function):
                 @staticmethod
-                def forward(ctx, w, q, k, v, z, b):
-                    B, T, H,
+                def forward(ctx, w, q, k, v, z, b, h0):
+                    B, T, H, N = w.shape
                     DTYPE = q.dtype
                     q = ops.cast(q, "bfloat16")
                     k = ops.cast(k, "bfloat16")
@@ -63,30 +62,42 @@ def get_generalized_delta_rule(HEAD_SIZE=64, KERNEL_TYPE="native"):
                     z = ops.cast(z, "bfloat16")
                     b = ops.cast(b, "bfloat16")
                     w = ops.cast(w, "bfloat16")
-
+                    if T % CHUNK_LEN != 0:
+                        raise ValueError(
+                            "The RWKV input sequence length must be divisible by 16. "
+                            "Please make sure the sequence length is divisible by 16"
+                        )
                     assert all(i.is_contiguous() for i in [w, q, k, v, z, b])
                     y = torch.empty_like(v)
                     s = torch.empty(
-                        B, H, T // CHUNK_LEN,
+                        B, H, T // CHUNK_LEN, N, N, dtype=torch.float32, device=w.device
                     )
-                    sa = torch.empty(B, T, H,
-                    torch.ops.wind_backstepping.forward(w, q, k, v, z, b, y, s, sa)
+                    sa = torch.empty(B, T, H, N, dtype=torch.float32, device=w.device)
+                    torch.ops.wind_backstepping.forward(w, q, k, v, z, b, y, s, sa, h0)
                     ctx.save_for_backward(w, q, k, v, z, b, s, sa)
-
+                    last_state = torch.empty_like(h0)
+                    last_state.copy_(ops.transpose(s[:, :, -1], [0, 1, 3, 2]))
+
+                    return ops.cast(y, DTYPE), last_state
 
                 @staticmethod
-                def backward(ctx, dy):
+                def backward(ctx, dy, dht):
                     DTYPE = dy.dtype
                     dy = ops.cast(dy, torch.bfloat16)
                     dy = dy.contiguous()
-
-                    assert all(i.is_contiguous() for i in [dy])
+
                     w, q, k, v, z, b, s, sa = ctx.saved_tensors
+                    dht = ops.cast(dht, "float32")
+                    dht = dht.contiguous()
+                    assert all(i.dtype == torch.bfloat16 for i in [dy])
+                    assert all(i.is_contiguous() for i in [dy, dht])
+                    dh0 = torch.empty(dht.shape, dtype=dht.dtype, device=dht.device)
                     dw, dq, dk, dv, dz, db = [
                         torch.empty_like(x) for x in [w, q, k, v, z, b]
                     ]
+
                     torch.ops.wind_backstepping.backward(
-                        w, q, k, v, z, b, dy, s, sa, dw, dq, dk, dv, dz, db
+                        w, q, k, v, z, b, dy, s, sa, dht, dh0, dw, dq, dk, dv, dz, db
                     )
                     return (
                         ops.cast(dw, DTYPE),
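For orientation (an editorial note, not text from the package): the `h0` tensor that `forward` now receives and the `last_state` it returns are the per-(batch, head) recurrent state of the wind_backstepping kernel. Reading the CUDA `forward_kernel` later in this diff, one step of that recurrence behaves like the sketch below; the helper name `wkv7_step` and the use of NumPy are illustrative only, and `w` here is the decay after the kernel's `exp(-exp(w))` transform.

```python
import numpy as np

def wkv7_step(S, q, w, k, v, a, b):
    """One timestep for a single (batch, head): S is the (N, N) state,
    q, w, k, v, a, b are (N,) vectors. Mirrors forward_kernel's inner loop."""
    sa = S @ a                                     # per-row dot product with a
    S = S * w[None, :] + np.outer(sa, b) + np.outer(v, k)
    y = S @ q                                      # output for this timestep
    return S, y
```

`initial_state` seeds `S` before the first step and `last_state` is `S` after the final step, which is what lets one call continue from where a previous call stopped.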
@@ -95,9 +106,10 @@ def get_generalized_delta_rule(HEAD_SIZE=64, KERNEL_TYPE="native"):
                         ops.cast(dv, DTYPE),
                         ops.cast(dz, DTYPE),
                         ops.cast(db, DTYPE),
+                        dh0,
                     )
 
-        def RUN_CUDA_RWKV7g(q, w, k, v, a, b):
+        def RUN_CUDA_RWKV7g(q, w, k, v, a, b, h0):
                B, T, H, C = q.shape
                q = q.contiguous()
                w = w.contiguous()
@@ -105,7 +117,8 @@ def get_generalized_delta_rule(HEAD_SIZE=64, KERNEL_TYPE="native"):
                v = v.contiguous()
                a = a.contiguous()
                b = b.contiguous()
-
+                out, state = WindBackstepping.apply(w, q, k, v, a, b, h0)
+                return out, state
 
            def generalized_delta_rule(
                r: torch.Tensor,
@@ -125,26 +138,61 @@ def get_generalized_delta_rule(HEAD_SIZE=64, KERNEL_TYPE="native"):
                a = transpose_head(a, head_first)
                b = transpose_head(b, head_first)
                w = transpose_head(w, head_first)
-
+                B, T, H, N = w.shape
+                if initial_state is None:
+                    initial_state = ops.zeros((B, H, N, N), "float32")
+                else:
+                    initial_state = ops.cast(initial_state, "float32")
+                out, state = RUN_CUDA_RWKV7g(r, w, k, v, a, b, initial_state)
+                if output_final_state:
+                    return out, state
+                return out
        else:
            from .native_keras_op import generalized_delta_rule
 
-
+            USE_TRITON_KERNEL = False
    elif keras.config.backend() == "jax":
        from jax.lib import xla_bridge
        import os
 
-        if (
-
-
-
-            os.environ["JAX_LOG_COMPUTATION"] = "0"
-            from .jax_op import generalized_delta_rule
+        if xla_bridge.get_backend().platform == "gpu":
+            if KERNEL_TYPE.lower() == "triton":
+                os.environ["JAX_LOG_COMPUTATION"] = "0"
+                from .jax_op import generalized_delta_rule
 
-
+                USE_TRITON_KERNEL = True
+            elif KERNEL_TYPE.lower() == "cuda":
+                from .jax_cuda_kernel.wkv7_jax import get_jax_generalized_delta_rule
+
+                generalized_delta_rule = get_jax_generalized_delta_rule(HEAD_SIZE)[0]
+            else:
+                from .native_keras_op import generalized_delta_rule
+        else:
+            from .native_keras_op import generalized_delta_rule
+    elif keras.config.backend() == "tensorflow":
+        import tensorflow as tf
+
+        if len(tf.config.list_physical_devices("GPU")) > 0:
+            if KERNEL_TYPE.lower() == "cuda" and HEAD_SIZE == 64:
+                try:
+                    from jax.lib import xla_bridge
+
+                    assert xla_bridge.get_backend().platform == "gpu"
+                except:
+                    raise RuntimeError(
+                        "The operation of the TensorFlow kernel depends on the JAX kernel. "
+                        "Therefore, make sure it works under JAX so that it can be used from TensorFlow."
+                    )
+                print("🎉" * 10)
+                print("The TensorFlow CUDA kernel only supports the forward pass; no gradients are available")
+                print("🎉" * 10)
+                from .tf_eager_kernel import get_tf_generalized_delta_rule
+
+                generalized_delta_rule = get_tf_generalized_delta_rule(HEAD_SIZE)[0]
+            else:
+                from .native_keras_op import generalized_delta_rule
        else:
            from .native_keras_op import generalized_delta_rule
-
    else:
        from .native_keras_op import generalized_delta_rule
-    return generalized_delta_rule,
+    return generalized_delta_rule, USE_TRITON_KERNEL
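Taken together, these __init__.py changes make the factory return a (generalized_delta_rule, USE_TRITON_KERNEL) pair and thread an optional initial state through the CUDA path. A minimal usage sketch under the Keras torch backend; the import path and the random-input setup are assumptions, while the keyword names (initial_state, output_final_state) and the (B, H, N, N) state shape come from the diff above.

```python
import keras

# Assumed import path; the factory itself is defined in rwkv_ops/rwkv7_kernel/__init__.py.
from rwkv_ops.rwkv7_kernel import get_generalized_delta_rule

delta_rule, use_triton = get_generalized_delta_rule(HEAD_SIZE=64, KERNEL_TYPE="cuda")

B, T, H, N = 2, 32, 8, 64          # T must be divisible by CHUNK_LEN == 16
r, w, k, v, a, b = [keras.random.normal((B, T, H, N)) for _ in range(6)]

# First chunk: no initial state, ask for the final state back.
out, state = delta_rule(r=r, w=w, k=k, v=v, a=a, b=b,
                        initial_state=None, output_final_state=True)

# Next chunk: feed the (B, H, N, N) state back in to continue the recurrence.
out2 = delta_rule(r=r, w=w, k=k, v=v, a=a, b=b,
                  initial_state=state, output_final_state=False)
```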
rwkv_ops/rwkv7_kernel/jax_cuda_kernel/CMakeLists.txt
@@ -0,0 +1,42 @@
+cmake_minimum_required(VERSION 3.18)
+project(wkv7 LANGUAGES CXX CUDA)
+
+find_package(CUDAToolkit REQUIRED)
+
+# ---------- 1. Find Python ----------
+find_package(Python3 REQUIRED COMPONENTS Interpreter)
+
+# ---------- 2. Get the XLA include path ----------
+execute_process(
+    COMMAND "${Python3_EXECUTABLE}" -c "from jax import ffi; print(ffi.include_dir())"
+    OUTPUT_VARIABLE XLA_INCLUDE_DIR
+    OUTPUT_STRIP_TRAILING_WHITESPACE
+)
+if(NOT XLA_INCLUDE_DIR)
+    message(FATAL_ERROR "Cannot get XLA include dir from jax.ffi")
+endif()
+message(STATUS "XLA include directory: ${XLA_INCLUDE_DIR}")
+
+# ---------- 3. Build the shared library ----------
+add_library(wkv7 SHARED wkv7_ffi.cu)
+
+# 3-1. Header search path
+target_include_directories(wkv7 PRIVATE ${XLA_INCLUDE_DIR})
+
+# 3-2. Link the CUDA runtime
+target_link_libraries(wkv7 PRIVATE CUDA::cudart)
+
+# 3-3. Important: C++17 / CUDA C++17
+target_compile_features(wkv7 PUBLIC cxx_std_17)
+set_target_properties(wkv7 PROPERTIES
+    CUDA_STANDARD 17
+    CUDA_SEPARABLE_COMPILATION ON
+    POSITION_INDEPENDENT_CODE ON
+    PREFIX ""  # drop the default "lib" prefix
+)
+
+# ---------- 4. Install ----------
+# Install the .so straight into the source directory (next to wkv7_jax.py) so ctypes.CDLL can load it
+install(TARGETS wkv7
+        LIBRARY DESTINATION "${CMAKE_SOURCE_DIR}"
+        RUNTIME DESTINATION "${CMAKE_SOURCE_DIR}")  # RUNTIME is used on Windows
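The install step above drops the shared library (built as wkv7.so on Linux, since the "lib" prefix is stripped) next to wkv7_jax.py so it can be loaded with ctypes.CDLL. The actual loading code lives in wkv7_jax.py, which this diff does not show; the sketch below is only an illustration of how such a library can be hooked into JAX's FFI. The target names "wkv7_fwd" / "wkv7_bwd" are made up for the example, and it assumes a JAX version recent enough to provide the jax.ffi module (which the CMake script already relies on for ffi.include_dir()).

```python
import ctypes
import os

import jax

_here = os.path.dirname(os.path.abspath(__file__))
_lib = ctypes.CDLL(os.path.join(_here, "wkv7.so"))   # installed by `cmake --install`

# Wrap the symbols exported by wkv7_ffi.cu (shown next in this diff) and
# register them as custom-call targets on the CUDA platform.
jax.ffi.register_ffi_target("wkv7_fwd", jax.ffi.pycapsule(_lib.Wkv7Fwd), platform="CUDA")
jax.ffi.register_ffi_target("wkv7_bwd", jax.ffi.pycapsule(_lib.Wkv7Bwd), platform="CUDA")
```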
rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_ffi.cu
@@ -0,0 +1,279 @@
+/*
+ * wkv7_ffi_bf16.cu
+ * BF16 version: the external interface is BF16, the internal kernel is kept as-is
+ */
+#include <cuda_bf16.h>
+#include <cuda_runtime.h>
+#include <xla/ffi/api/ffi.h>
+#include <vector>
+
+namespace ffi = xla::ffi;
+
+/* -------------------- Type aliases -------------------- */
+using bf = __nv_bfloat16;
+
+/* -------------------- Device-side helpers (used inside the kernels) -------------------- */
+__device__ inline float to_float(const bf &u) {
+    return __bfloat162float(u);
+}
+__device__ inline bf to_bf(const float &u) {
+    return __float2bfloat16_rn(u);
+}
+
+typedef bf *__restrict__ F_;
+
+/* -------------------- The original kernels (do not modify) -------------------- */
+__global__ void forward_kernel(int T, int H,
+                               F_ w_, F_ q_, F_ k_, F_ v_, F_ a_, F_ b_,
+                               bf *y_, float *s_, float *sa_, float *h0_) {
+    constexpr int C = _C_;
+    int bb = blockIdx.y, hh = blockIdx.x, i = threadIdx.x;
+    float state[C] = {0};
+    __shared__ float q[C], k[C], w[C], a[C], b[C];
+    int h0_base = ((bb * H + hh) * C + i) * C;
+#pragma unroll
+    for (int j = 0; j < C; ++j) state[j] = h0_[h0_base + j];
+
+    for (int t = 0; t < T; ++t) {
+        int ind = bb * T * H * C + t * H * C + hh * C + i;
+        __syncthreads();
+        q[i] = to_float(q_[ind]);
+        w[i] = __expf(-__expf(to_float(w_[ind])));
+        k[i] = to_float(k_[ind]);
+        a[i] = to_float(a_[ind]);
+        b[i] = to_float(b_[ind]);
+        __syncthreads();
+
+        float sa = 0.f;
+#pragma unroll
+        for (int j = 0; j < C; ++j) sa += a[j] * state[j];
+        sa_[ind] = sa;
+
+        float v = to_float(v_[ind]);
+        float y = 0.f;
+#pragma unroll
+        for (int j = 0; j < C; ++j) {
+            float &s = state[j];
+            s = s * w[j] + sa * b[j] + k[j] * v;
+            y += s * q[j];
+        }
+        y_[ind] = to_bf(y);
+
+        if ((t + 1) % _CHUNK_LEN_ == 0) {
+            int base = (bb * H + hh) * (T / _CHUNK_LEN_) * C * C +
+                       (t / _CHUNK_LEN_) * C * C + i;
+#pragma unroll
+            for (int j = 0; j < C; ++j) s_[base + j * C] = state[j];
+        }
+    }
+}
+
+__global__ void backward_kernel(int T, int H,
+                                F_ w_, F_ q_, F_ k_, F_ v_, F_ a_, F_ b_, F_ dy_,
+                                float *s_, float *sa_, float *dht_, float *dh0_,
+                                bf *dw_, bf *dq_, bf *dk_, bf *dv_, bf *da_, bf *db_) {
+    constexpr int C = _C_;
+    int bb = blockIdx.y, hh = blockIdx.x, i = threadIdx.x;
+    float stateT[C] = {0}, dstate[C] = {0}, dstateT[C] = {0};
+    int dht_base = ((bb * H + hh) * C + i) * C;
+#pragma unroll
+    for (int j = 0; j < C; ++j) {
+        dstate[j] = dht_[dht_base + j];
+        dstateT[j] = dht_[dht_base + j];
+    }
+    __shared__ float w[C], q[C], k[C], v[C], a[C], b[C], dy[C], sa[C], dSb_shared[C];
+    float qi, wi, ki, ai, bi, dyi;
+
+    for (int t = T - 1; t >= 0; --t) {
+        int ind = bb * T * H * C + t * H * C + hh * C + i;
+        __syncthreads();
+        q[i] = qi = to_float(q_[ind]);
+        float wi_fac = -__expf(to_float(w_[ind]));
+        w[i] = wi = __expf(wi_fac);
+        k[i] = ki = to_float(k_[ind]);
+        a[i] = ai = to_float(a_[ind]);
+        b[i] = bi = to_float(b_[ind]);
+        v[i] = to_float(v_[ind]);
+        dy[i] = dyi = to_float(dy_[ind]);
+        sa[i] = sa_[ind];
+        __syncthreads();
+
+        if ((t + 1) % _CHUNK_LEN_ == 0) {
+            int base = (bb * H + hh) * (T / _CHUNK_LEN_) * C * C +
+                       (t / _CHUNK_LEN_) * C * C + i * C;
+#pragma unroll
+            for (int j = 0; j < C; ++j) stateT[j] = s_[base + j];
+        }
+        float dq = 0.f;
+#pragma unroll
+        for (int j = 0; j < C; ++j) dq += stateT[j] * dy[j];
+        dq_[ind] = to_bf(dq);
+
+        float iwi = 1.f / (wi + 1e-6f);
+#pragma unroll
+        for (int j = 0; j < C; ++j) {
+            stateT[j] = (stateT[j] - ki * v[j] - bi * sa[j]) * iwi;
+            dstate[j] += dyi * q[j];
+            dstateT[j] += qi * dy[j];
+        }
+        float dw = 0.f, dk = 0.f, dv = 0.f, db = 0.f, dSb = 0.f;
+#pragma unroll
+        for (int j = 0; j < C; ++j) {
+            dw += dstateT[j] * stateT[j];
+            dk += dstateT[j] * v[j];
+            dv += dstate[j] * k[j];
+            dSb += dstate[j] * b[j];
+            db += dstateT[j] * sa[j];
+        }
+        dw_[ind] = to_bf(dw * wi * wi_fac);
+        dk_[ind] = to_bf(dk);
+        dv_[ind] = to_bf(dv);
+        db_[ind] = to_bf(db);
+        __syncthreads();
+        dSb_shared[i] = dSb;
+        __syncthreads();
+        float da = 0.f;
+#pragma unroll
+        for (int j = 0; j < C; ++j) da += stateT[j] * dSb_shared[j];
+        da_[ind] = to_bf(da);
+#pragma unroll
+        for (int j = 0; j < C; ++j) {
+            dstate[j] = dstate[j] * w[j] + dSb * a[j];
+            dstateT[j] = dstateT[j] * wi + ai * dSb_shared[j];
+            if (t == 0) dh0_[dht_base + j] = dstate[j];
+        }
+    }
+}
+
+/* -------------------- Host functions -------------------- */
+static ffi::Error WKV7FwdHost(
+    cudaStream_t stream,
+    ffi::Buffer<ffi::BF16> w,
+    ffi::Buffer<ffi::BF16> q,
+    ffi::Buffer<ffi::BF16> k,
+    ffi::Buffer<ffi::BF16> v,
+    ffi::Buffer<ffi::BF16> z,
+    ffi::Buffer<ffi::BF16> a,
+    ffi::Buffer<ffi::F32> h0,  // kept as float
+    ffi::ResultBuffer<ffi::BF16> y,
+    ffi::ResultBuffer<ffi::F32> s,
+    ffi::ResultBuffer<ffi::F32> sa)
+{
+    constexpr int C = _C_;
+    auto dims = w.dimensions();
+    int B = dims[0], T = dims[1], H = dims[2];
+    dim3 block(C);
+    dim3 grid(H, B);
+
+    forward_kernel<<<grid, block, 0, stream>>>(
+        T, H,
+        reinterpret_cast<bf *>(w.typed_data()),
+        reinterpret_cast<bf *>(q.typed_data()),
+        reinterpret_cast<bf *>(k.typed_data()),
+        reinterpret_cast<bf *>(v.typed_data()),
+        reinterpret_cast<bf *>(z.typed_data()),
+        reinterpret_cast<bf *>(a.typed_data()),
+        reinterpret_cast<bf *>(y->typed_data()),
+        s->typed_data(),
+        sa->typed_data(),
+        h0.typed_data());
+
+    cudaError_t err = cudaGetLastError();
+    if (err != cudaSuccess)
+        return ffi::Error::Internal(
+            std::string("CUDA forward_kernel error: ") + cudaGetErrorString(err));
+    return ffi::Error::Success();
+}
+
+static ffi::Error WKV7BwdHost(
+    cudaStream_t stream,
+    ffi::Buffer<ffi::BF16> w,
+    ffi::Buffer<ffi::BF16> q,
+    ffi::Buffer<ffi::BF16> k,
+    ffi::Buffer<ffi::BF16> v,
+    ffi::Buffer<ffi::BF16> z,
+    ffi::Buffer<ffi::BF16> a,
+    ffi::Buffer<ffi::BF16> dy,
+    ffi::Buffer<ffi::F32> s,
+    ffi::Buffer<ffi::F32> sa,
+    ffi::Buffer<ffi::F32> dht,
+    ffi::ResultBuffer<ffi::F32> dh0,
+    ffi::ResultBuffer<ffi::BF16> dw,
+    ffi::ResultBuffer<ffi::BF16> dq,
+    ffi::ResultBuffer<ffi::BF16> dk,
+    ffi::ResultBuffer<ffi::BF16> dv,
+    ffi::ResultBuffer<ffi::BF16> da,
+    ffi::ResultBuffer<ffi::BF16> db)
+{
+    auto dims = w.dimensions();
+    int B = dims[0], T = dims[1], H = dims[2];
+    constexpr int C = _C_;
+    dim3 block(C);
+    dim3 grid(H, B);
+
+    backward_kernel<<<grid, block, 0, stream>>>(
+        T, H,
+        reinterpret_cast<bf *>(w.typed_data()),
+        reinterpret_cast<bf *>(q.typed_data()),
+        reinterpret_cast<bf *>(k.typed_data()),
+        reinterpret_cast<bf *>(v.typed_data()),
+        reinterpret_cast<bf *>(z.typed_data()),
+        reinterpret_cast<bf *>(a.typed_data()),
+        reinterpret_cast<bf *>(dy.typed_data()),
+        s.typed_data(),
+        sa.typed_data(),
+        dht.typed_data(),
+        dh0->typed_data(),
+        reinterpret_cast<bf *>(dw->typed_data()),
+        reinterpret_cast<bf *>(dq->typed_data()),
+        reinterpret_cast<bf *>(dk->typed_data()),
+        reinterpret_cast<bf *>(dv->typed_data()),
+        reinterpret_cast<bf *>(da->typed_data()),
+        reinterpret_cast<bf *>(db->typed_data()));
+
+    cudaError_t err = cudaGetLastError();
+    if (err != cudaSuccess)
+        return ffi::Error::Internal(
+            std::string("CUDA backward_kernel error: ") + cudaGetErrorString(err));
+    return ffi::Error::Success();
+}
+
+/* -------------------- Register the handler symbols -------------------- */
+XLA_FFI_DEFINE_HANDLER_SYMBOL(
+    Wkv7Fwd, WKV7FwdHost,
+    ffi::Ffi::Bind()
+        .Ctx<ffi::PlatformStream<cudaStream_t>>()
+        .Arg<ffi::Buffer<ffi::BF16>>()  // w
+        .Arg<ffi::Buffer<ffi::BF16>>()  // q
+        .Arg<ffi::Buffer<ffi::BF16>>()  // k
+        .Arg<ffi::Buffer<ffi::BF16>>()  // v
+        .Arg<ffi::Buffer<ffi::BF16>>()  // z
+        .Arg<ffi::Buffer<ffi::BF16>>()  // a
+        .Arg<ffi::Buffer<ffi::F32>>()   // h0 (float)
+        .Ret<ffi::Buffer<ffi::BF16>>()  // y
+        .Ret<ffi::Buffer<ffi::F32>>()   // s
+        .Ret<ffi::Buffer<ffi::F32>>()   // sa
+    , {ffi::Traits::kCmdBufferCompatible});
+
+XLA_FFI_DEFINE_HANDLER_SYMBOL(
+    Wkv7Bwd, WKV7BwdHost,
+    ffi::Ffi::Bind()
+        .Ctx<ffi::PlatformStream<cudaStream_t>>()
+        .Arg<ffi::Buffer<ffi::BF16>>()  // w
+        .Arg<ffi::Buffer<ffi::BF16>>()  // q
+        .Arg<ffi::Buffer<ffi::BF16>>()  // k
+        .Arg<ffi::Buffer<ffi::BF16>>()  // v
+        .Arg<ffi::Buffer<ffi::BF16>>()  // z
+        .Arg<ffi::Buffer<ffi::BF16>>()  // a
+        .Arg<ffi::Buffer<ffi::BF16>>()  // dy
+        .Arg<ffi::Buffer<ffi::F32>>()   // s
+        .Arg<ffi::Buffer<ffi::F32>>()   // sa
+        .Arg<ffi::Buffer<ffi::F32>>()   // dht
+        .Ret<ffi::Buffer<ffi::F32>>()   // dh0
+        .Ret<ffi::Buffer<ffi::BF16>>()  // dw
+        .Ret<ffi::Buffer<ffi::BF16>>()  // dq
+        .Ret<ffi::Buffer<ffi::BF16>>()  // dk
+        .Ret<ffi::Buffer<ffi::BF16>>()  // dv
+        .Ret<ffi::Buffer<ffi::BF16>>()  // da
+        .Ret<ffi::Buffer<ffi::BF16>>()  // db
+    , {ffi::Traits::kCmdBufferCompatible});
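With the handlers registered as in the earlier sketch, the forward target can be invoked through jax.ffi.ffi_call; the result shapes follow directly from WKV7FwdHost and forward_kernel (y keeps the input layout in bfloat16, s stores one (C, C) state per chunk of _CHUNK_LEN_ steps, sa is float32 per step). This is a sketch, not code from wkv7_jax.py: the target name "wkv7_fwd" comes from the previous example, and C / CHUNK_LEN must match the _C_ and _CHUNK_LEN_ macros the library was compiled with (likely HEAD_SIZE and 16, matching the torch path).

```python
import jax
import jax.numpy as jnp

B, T, H, C, CHUNK_LEN = 2, 32, 8, 64, 16
bf16 = jnp.bfloat16

w, q, k, v, z, a = [jnp.zeros((B, T, H, C), bf16) for _ in range(6)]
h0 = jnp.zeros((B, H, C, C), jnp.float32)      # initial state, kept in float32

y, s, sa = jax.ffi.ffi_call(
    "wkv7_fwd",
    (
        jax.ShapeDtypeStruct((B, T, H, C), bf16),                          # y
        jax.ShapeDtypeStruct((B, H, T // CHUNK_LEN, C, C), jnp.float32),   # s
        jax.ShapeDtypeStruct((B, T, H, C), jnp.float32),                   # sa
    ),
)(w, q, k, v, z, a, h0)
```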