rwkv-ops 0.2.2__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rwkv-ops might be problematic. Click here for more details.
- rwkv_ops/__init__.py +5 -6
- rwkv_ops/rwkv6_kernel/__init__.py +0 -6
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/gpu_ops.cpp +44 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernel_helpers.h +64 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernels.h +56 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/pybind11_kernel_helpers.h +41 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/rwkv_kernels.cu +512 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/gpu_ops.cpp +44 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernel_helpers.h +64 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernels.h +56 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/pybind11_kernel_helpers.h +41 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/rwkv_kernels.hip +514 -0
- rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py +21 -23
- rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py +14 -10
- rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_cuda.cu +397 -0
- rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_op.cpp +93 -0
- rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py +4 -4
- rwkv_ops/rwkv7_kernel/__init__.py +80 -29
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel/CMakeLists.txt +42 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_ffi.cu +279 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_jax.py +237 -0
- rwkv_ops/rwkv7_kernel/jax_op.py +6 -5
- rwkv_ops/rwkv7_kernel/native_keras_op.py +5 -6
- rwkv_ops/rwkv7_kernel/tf_eager_kernel.py +123 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_cuda.cu +165 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_op.cpp +35 -0
- {rwkv_ops-0.2.2.dist-info → rwkv_ops-0.3.1.dist-info}/METADATA +28 -27
- {rwkv_ops-0.2.2.dist-info → rwkv_ops-0.3.1.dist-info}/RECORD +30 -13
- {rwkv_ops-0.2.2.dist-info → rwkv_ops-0.3.1.dist-info}/WHEEL +1 -2
- rwkv_ops-0.2.2.dist-info/top_level.txt +0 -1
- {rwkv_ops-0.2.2.dist-info → rwkv_ops-0.3.1.dist-info/licenses}/LICENSE.txt +0 -0
|
@@ -0,0 +1,237 @@
|
|
|
1
|
+
"""
|
|
2
|
+
JAX 版 RWKV7 wkv kernel + generalized_delta_rule
|
|
3
|
+
延迟编译 CUDA 扩展,接口与 Torch 版本 1:1 对齐
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
import pathlib
|
|
8
|
+
import subprocess
|
|
9
|
+
import ctypes
|
|
10
|
+
import jax
|
|
11
|
+
import jax.numpy as jnp
|
|
12
|
+
from typing import Optional, Tuple, Union
|
|
13
|
+
from jax.ad_checkpoint import checkpoint_policies as cp
|
|
14
|
+
CHUNK_LEN = 16 # 这是一个常数
|
|
15
|
+
# ---------- 延迟编译(改到当前目录) ----------
|
|
16
|
+
_CURRENT_DIR = pathlib.Path(
|
|
17
|
+
__file__
|
|
18
|
+
).parent.absolute() # rwkv_ops/rwkv7_kernel/jax_cuda_kernel
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def get_jax_generalized_delta_rule(HEAD_SIZE=64):
    """Lazily compile the CUDA wkv7 kernel for ``HEAD_SIZE`` and return the
    JAX ops bound to it.

    Returns:
        ``(generalized_delta_rule, _wkv7_kernel, _wkv7_bwd_kernel)`` — the
        user-facing op plus the raw forward/backward FFI wrappers (the latter
        are also reused by the TensorFlow bridge).
    """
    _BUILD_DIR = _CURRENT_DIR / f"build_{HEAD_SIZE}"
    # The install step drops the shared object inside the per-head-size build dir.
    _SO_PATH = _BUILD_DIR / "wkv7.so"

    def _ensure_compiled() -> pathlib.Path:
        """Compile the CUDA extension on first use; artifacts live next to the sources."""
        if _SO_PATH.exists():
            return _SO_PATH

        print("[rwkv7_jax] First use – compiling CUDA kernel…")
        src_dir = _CURRENT_DIR
        build_dir = _BUILD_DIR
        # parents=True: don't fail if an intermediate directory is missing.
        build_dir.mkdir(parents=True, exist_ok=True)

        # ---------- key step: locate the XLA headers shipped with JAX ----------
        xla_include_dir = jax.ffi.include_dir()  # requires JAX >= 0.4.31
        if not xla_include_dir:
            raise RuntimeError("jax.ffi.include_dir() 返回空,请检查 JAX >= 0.4.31")

        # ---------- key step: pin the numerical-behaviour flags ----------
        # BUGFIX: the previous list also contained a hard-coded "-D_C_=64"
        # before the parameterized define, which silently conflicted with
        # HEAD_SIZE != 64; only the parameterized define is kept.
        cuda_flags = [
            "-ftz=true",  # flush sub-normals to zero
            "-prec-div=false",  # faster division, no special-case path
            "-prec-sqrt=false",  # faster sqrt
            "--use_fast_math",  # uniform fast math
            "-O3",
            "-Xptxas=-O3",
            "-res-usage",
            "--extra-device-vectorization",
            f"-D_C_={HEAD_SIZE}",  # head size baked into the kernel
            f"-D_CHUNK_LEN_={CHUNK_LEN}",
        ]

        # 1. configure
        cmake_args = [
            "cmake",
            "-S",
            str(src_dir),
            "-B",
            str(build_dir),
            "-DCMAKE_BUILD_TYPE=Release",
            f"-DCMAKE_INSTALL_PREFIX={_CURRENT_DIR}",
            f"-DXLA_INCLUDE_DIR={xla_include_dir}",  # forwarded to CMake
            f"-DCMAKE_CUDA_FLAGS={' '.join(cuda_flags)}",
        ]
        subprocess.check_call(cmake_args)

        # 2. build
        subprocess.check_call(["cmake", "--build", str(build_dir), "-j"])

        # 3. install (copies the .so into place)
        subprocess.check_call(["cmake", "--install", str(build_dir)])

        if not _SO_PATH.exists():
            raise RuntimeError("Compilation failed – wkv7.so not found.")

        print("[rwkv7_jax] Compilation finished – output at", _SO_PATH)
        return _SO_PATH

    # Register the FFI symbols exported by the shared object.
    # ctypes.CDLL expects a string path, so convert the Path explicitly.
    _lib = ctypes.CDLL(str(_ensure_compiled()))
    jax.ffi.register_ffi_target(
        "wkv7_fwd", jax.ffi.pycapsule(_lib.Wkv7Fwd), platform="CUDA"
    )
    jax.ffi.register_ffi_target(
        "wkv7_bwd", jax.ffi.pycapsule(_lib.Wkv7Bwd), platform="CUDA"
    )

    # ---------- helpers ----------
    def _transpose_head(x: jnp.ndarray, head_first: bool) -> jnp.ndarray:
        """Cast to bf16; (B, H, T, K) -> (B, T, H, K) when head_first."""
        x = jnp.asarray(x, dtype=jnp.bfloat16)
        if head_first:
            return jnp.transpose(x, (0, 2, 1, 3))
        return x

    # ---------- forward + backward kernels ----------

    def _wkv7_kernel(
        w: jnp.ndarray,
        q: jnp.ndarray,
        k: jnp.ndarray,
        v: jnp.ndarray,
        a: jnp.ndarray,
        b: jnp.ndarray,
        h0: jnp.ndarray,
    ):
        """Raw forward FFI call.

        Argument order matches the FFI target registration:
        w,q,k,v,a,b,h0 -> y, s (per-chunk state snapshots), sa.
        """
        B, T, H, K = q.shape
        dtype = q.dtype
        chunk_num = int(T // CHUNK_LEN)
        out_type = jax.ShapeDtypeStruct((B, T, H, K), dtype)
        s_type = jax.ShapeDtypeStruct((B, H, chunk_num, K, K), jnp.float32)
        sa_type = jax.ShapeDtypeStruct((B, T, H, K), jnp.float32)

        y, s, sa = jax.ffi.ffi_call(
            "wkv7_fwd", (out_type, s_type, sa_type), vmap_method="broadcast_all"
        )(w, q, k, v, a, b, h0)

        return y, s, sa

    @jax.custom_vjp
    def wk7_kernel(
        w: jnp.ndarray,
        q: jnp.ndarray,
        k: jnp.ndarray,
        v: jnp.ndarray,
        a: jnp.ndarray,
        b: jnp.ndarray,
        h0: jnp.ndarray,
    ):
        y, s, sa = _wkv7_kernel(w, q, k, v, a, b, h0)
        # Last chunk snapshot; transposed so callers see the same (B, H, K, K)
        # layout as h0 (the kernel stores snapshots transposed).
        final_state = s[:, :, -1]
        return (y, jnp.transpose(final_state, [0, 1, 3, 2]))

    # custom_vjp forward rule: same computation, plus residuals for backward.
    def _fwd(
        w: jnp.ndarray,
        q: jnp.ndarray,
        k: jnp.ndarray,
        v: jnp.ndarray,
        a: jnp.ndarray,
        b: jnp.ndarray,
        h0: jnp.ndarray,
    ):
        y, s, sa = _wkv7_kernel(w, q, k, v, a, b, h0)
        final_state = s[:, :, -1]
        return (y, jnp.transpose(final_state, [0, 1, 3, 2])), (w, q, k, v, a, b, s, sa)

    def _wkv7_bwd_kernel(w, q, k, v, a, b, dy, s, sa, dht):
        """Raw backward FFI call; returns grads in wk7_kernel argument order."""
        dh0_type = jax.ShapeDtypeStruct(dht.shape, dht.dtype)
        dw_type = jax.ShapeDtypeStruct(w.shape, w.dtype)
        dq_type = jax.ShapeDtypeStruct(q.shape, q.dtype)
        dk_type = jax.ShapeDtypeStruct(k.shape, k.dtype)
        dv_type = jax.ShapeDtypeStruct(v.shape, v.dtype)
        da_type = jax.ShapeDtypeStruct(a.shape, a.dtype)
        db_type = jax.ShapeDtypeStruct(b.shape, b.dtype)

        dh0, dw, dq, dk, dv, da, db = jax.ffi.ffi_call(
            "wkv7_bwd",
            (dh0_type, dw_type, dq_type, dk_type, dv_type, da_type, db_type),
            vmap_method="broadcast_all",
        )(w, q, k, v, a, b, dy, s, sa, dht)

        # Reorder so the tuple lines up with wk7_kernel's (w,q,k,v,a,b,h0).
        return dw, dq, dk, dv, da, db, dh0

    # custom_vjp backward rule
    def _bwd(res, grads):
        w, q, k, v, a, b, s, sa = res
        dy, dht = grads
        dy = jnp.asarray(dy, jnp.bfloat16)  # kernel consumes bf16 gradients
        return _wkv7_bwd_kernel(w, q, k, v, a, b, dy, s, sa, dht)

    wk7_kernel.defvjp(_fwd, _bwd)

    def generalized_delta_rule(
        r: jnp.ndarray,
        w: jnp.ndarray,
        k: jnp.ndarray,
        v: jnp.ndarray,
        a: jnp.ndarray,
        b: jnp.ndarray,
        initial_state: Optional[jnp.ndarray] = None,
        output_final_state: bool = True,
        head_first: bool = False,
    ) -> Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:
        """Generalized delta rule; interface mirrors the Torch implementation.

        Args:
            r, w, k, v, a, b: tensors of shape (B, T, H, K), or (B, H, T, K)
                when ``head_first`` is True.
            initial_state: optional (B, H, K, K) state; zero-initialized if None.
            output_final_state: also return the final state when True.
            head_first: whether the head axis precedes the time axis.

        Returns:
            out: (B, T, H, K), same dtype as the input.
            last_state: (B, H, K, K) when ``output_final_state`` is True.

        Raises:
            ValueError: if T is not divisible by CHUNK_LEN (a fixed constant, 16).
        """
        # Normalize everything to (B, T, H, K) bf16.
        dtype = r.dtype
        r = _transpose_head(r, head_first)
        w = _transpose_head(w, head_first)
        k = _transpose_head(k, head_first)
        v = _transpose_head(v, head_first)
        a = _transpose_head(a, head_first)
        b = _transpose_head(b, head_first)

        B, T, H, K = r.shape
        if T % CHUNK_LEN:
            raise ValueError(
                f"Sequence length T={T} must be divisible by chunk_len={CHUNK_LEN}"
            )

        # Initial state handling.
        if initial_state is None:
            h0 = jnp.zeros((B, H, K, K), jnp.float32)
        else:
            h0 = jnp.asarray(initial_state, jnp.float32)

        # Run the kernel under remat so activations are recomputed, not stored.
        out, last_state = jax.checkpoint(
            wk7_kernel, policy=cp.save_anything_except_these_names(())
        )(w, r, k, v, a, b, h0)
        out = jnp.asarray(out, dtype)  # restore the caller's dtype

        if output_final_state:
            return out, last_state
        return out

    return generalized_delta_rule, _wkv7_kernel, _wkv7_bwd_kernel
|
rwkv_ops/rwkv7_kernel/jax_op.py
CHANGED
|
@@ -14,9 +14,11 @@ from .jax_kernel.chunk_o_fwd import chunk_dplr_fwd_o
|
|
|
14
14
|
from .jax_kernel.wy_fast_bwd import chunk_dplr_bwd_wy
|
|
15
15
|
from .jax_kernel.wy_fast_fwd import prepare_wy_repr_fwd
|
|
16
16
|
from .jax_kernel.cumsum import chunk_rwkv6_fwd_cumsum
|
|
17
|
-
from jax.ad_checkpoint import checkpoint_policies
|
|
17
|
+
from jax.ad_checkpoint import checkpoint_policies as cp
|
|
18
|
+
|
|
18
19
|
CHUNKSIZE = 16
|
|
19
20
|
|
|
21
|
+
|
|
20
22
|
def chunk_dplr_fwd(
|
|
21
23
|
q: jax.Array,
|
|
22
24
|
k: jax.Array,
|
|
@@ -307,7 +309,6 @@ def transpose_head(x, head_first):
|
|
|
307
309
|
return x
|
|
308
310
|
|
|
309
311
|
|
|
310
|
-
# @partial(jax.jit, static_argnames=['initial_state',"output_final_state","head_first","use_chunk"])
|
|
311
312
|
def generalized_delta_rule(
|
|
312
313
|
r: jax.Array,
|
|
313
314
|
w: jax.Array,
|
|
@@ -365,7 +366,9 @@ def generalized_delta_rule(
|
|
|
365
366
|
else:
|
|
366
367
|
assert log_w is not None, "Either w or log_w must be provided!"
|
|
367
368
|
log_w = transpose_head(log_w, head_first)
|
|
368
|
-
o, final_state =
|
|
369
|
+
o, final_state = jax.checkpoint(
|
|
370
|
+
chunk_dplr, policy=cp.save_anything_except_these_names(())
|
|
371
|
+
)(
|
|
369
372
|
r=r,
|
|
370
373
|
k=k,
|
|
371
374
|
v=v,
|
|
@@ -377,5 +380,3 @@ def generalized_delta_rule(
|
|
|
377
380
|
if output_final_state:
|
|
378
381
|
return jnp.asarray(o, DTYPE), final_state
|
|
379
382
|
return jnp.asarray(o, DTYPE)
|
|
380
|
-
|
|
381
|
-
|
|
@@ -1,4 +1,3 @@
|
|
|
1
|
-
import keras
|
|
2
1
|
from keras import ops
|
|
3
2
|
|
|
4
3
|
|
|
@@ -62,8 +61,9 @@ def generalized_delta_rule(
|
|
|
62
61
|
if ops.shape(state)[0] == 1:
|
|
63
62
|
state = ops.broadcast_to(state, (B, H, N, N))
|
|
64
63
|
else:
|
|
65
|
-
state = ops.zeros((B, H, N, N)
|
|
66
|
-
|
|
64
|
+
state = ops.zeros((B, H, N, N))
|
|
65
|
+
state = ops.cast(state, "float32")
|
|
66
|
+
out = ops.zeros((B, T, H, N), DTYPE)
|
|
67
67
|
|
|
68
68
|
def step(t, inputs):
|
|
69
69
|
"""
|
|
@@ -83,9 +83,8 @@ def generalized_delta_rule(
|
|
|
83
83
|
aa = ops.reshape(a[:, t, :], (B, H, N, 1))
|
|
84
84
|
bb = ops.reshape(b[:, t, :], (B, H, 1, N))
|
|
85
85
|
state = state * w[:, t, :, None, :] + state @ aa @ bb + vv @ kk
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
)
|
|
86
|
+
o = ops.cast((state @ rr), out.dtype)
|
|
87
|
+
out = ops.slice_update(out, [0, t, 0, 0], ops.reshape(o, (B, 1, H, N)))
|
|
89
88
|
return [state, out]
|
|
90
89
|
|
|
91
90
|
state, out = ops.fori_loop(0, T, step, [state, out])
|
|
@@ -0,0 +1,123 @@
|
|
|
1
|
+
"""
|
|
2
|
+
TensorFlow 版 generalized_delta_rule
|
|
3
|
+
前向用 tf.py_function 调 JAX CUDA 内核,反向同样走 JAX。
|
|
4
|
+
可 @tf.function 编译,可 tf.GradientTape 训练。
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import tensorflow as tf
|
|
8
|
+
from typing import Optional, Tuple
|
|
9
|
+
import jax.numpy as jnp
|
|
10
|
+
from .jax_cuda_kernel.wkv7_jax import get_jax_generalized_delta_rule
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def transpose_head(x, head_first: bool):
    """Cast to float32 and, when head_first, move (B, H, T, K) -> (B, T, H, K)."""
    casted = tf.cast(x, dtype=tf.float32)
    return tf.transpose(casted, (0, 2, 1, 3)) if head_first else casted
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def get_tf_generalized_delta_rule(HEAD_SIZE=64):
    """Build the TensorFlow-facing wkv7 op on top of the JAX CUDA kernel.

    Both the forward and backward passes hop into JAX via ``tf.py_function``,
    so every call round-trips tensors TF -> JAX -> TF on the same device.

    Returns:
        (generalized_delta_rule, _tf_wkv7_fwd, _tf_wkv7_bwd)
    """
    # Reuse the lazily-compiled JAX FFI wrappers (one shared .so per HEAD_SIZE).
    _, _wkv7_kernel, _wkv7_bwd_kernel = get_jax_generalized_delta_rule(HEAD_SIZE)

    # ---------- low-level kernel wrappers ----------
    # NOTE(review): decorator-style tf.py_function (passing only Tout) needs a
    # fairly recent TensorFlow release — confirm the minimum supported version.
    @tf.py_function(Tout=[tf.bfloat16, tf.float32, tf.float32])
    def _tf_wkv7_fwd(w, q, k, v, a, b, h0):
        """tf.py_function wrapper around the JAX forward kernel."""
        y, s, sa = _wkv7_kernel(
            jnp.asarray(w, jnp.bfloat16),
            jnp.asarray(q, jnp.bfloat16),
            jnp.asarray(k, jnp.bfloat16),
            jnp.asarray(v, jnp.bfloat16),
            jnp.asarray(a, jnp.bfloat16),
            jnp.asarray(b, jnp.bfloat16),
            jnp.asarray(h0, jnp.float32),
        )
        return (
            tf.convert_to_tensor(y, tf.bfloat16),
            tf.convert_to_tensor(s, tf.float32),
            tf.convert_to_tensor(sa, tf.float32),
        )

    @tf.py_function(Tout=[tf.bfloat16] * 6 + [tf.float32])
    def _tf_wkv7_bwd(w, q, k, v, a, b, dy, s, sa, dht):
        """tf.py_function wrapper around the JAX backward kernel."""
        dw, dq, dk, dv, da, db, dh0 = _wkv7_bwd_kernel(
            jnp.asarray(w, jnp.bfloat16),
            jnp.asarray(q, jnp.bfloat16),
            jnp.asarray(k, jnp.bfloat16),
            jnp.asarray(v, jnp.bfloat16),
            jnp.asarray(a, jnp.bfloat16),
            jnp.asarray(b, jnp.bfloat16),
            jnp.asarray(dy, jnp.bfloat16),
            jnp.asarray(s, jnp.float32),
            jnp.asarray(sa, jnp.float32),
            jnp.asarray(dht, jnp.bfloat16),
        )
        return tuple(
            tf.convert_to_tensor(g, dtype)
            for g, dtype in zip((dw, dq, dk, dv, da, db), [tf.bfloat16] * 6)
        ) + (tf.convert_to_tensor(dh0, tf.float32),)

    # ---------- differentiable forward ----------
    @tf.custom_gradient
    def _wk7_tf(w, q, k, v, a, b, h0):
        y, s, sa = _tf_wkv7_fwd(w, q, k, v, a, b, h0)

        def grad(dy, dht):
            # dy: upstream gradient of the loss w.r.t. y
            # dht: gradient w.r.t. the final state (zeros when not supplied)
            if dht is None:
                dht = tf.zeros_like(h0)
            grads = _tf_wkv7_bwd(w, q, k, v, a, b, dy, s, sa, dht)
            return grads  # (dw, dq, dk, dv, da, db, dh0)

        final_state = s[:, :, -1]  # (B, H, K, K) last chunk snapshot
        final_state = tf.transpose(final_state, [0, 1, 3, 2])  # align with JAX layout
        return (y, final_state), grad

    # ---------- public interface ----------
    def generalized_delta_rule(
        r: tf.Tensor,  # (B, T, H, K) or (B, H, T, K)
        w: tf.Tensor,
        k: tf.Tensor,
        v: tf.Tensor,
        a: tf.Tensor,
        b: tf.Tensor,
        initial_state: Optional[tf.Tensor] = None,
        output_final_state: bool = True,
        head_first: bool = False,
        chunk_len: int = 16,
    ) -> Tuple[tf.Tensor, tf.Tensor]:
        """1:1 with the JAX interface; returns (out, last_state).

        Intended for use under tf.GradientTape training.
        NOTE(review): when traced by @tf.function, T below is a symbolic
        tensor and the Python ``%`` check only works eagerly — confirm the
        op is called eagerly (as the module name suggests).
        """
        dtype = r.dtype

        r = transpose_head(r, head_first)
        w = transpose_head(w, head_first)
        k = transpose_head(k, head_first)
        v = transpose_head(v, head_first)
        a = transpose_head(a, head_first)
        b = transpose_head(b, head_first)

        B, T, H, K = tf.unstack(tf.shape(r), num=4)
        if T % chunk_len != 0:
            raise ValueError(f"T={T} must be divisible by chunk_len={chunk_len}")

        if initial_state is None:
            h0 = tf.zeros([B, H, K, K], dtype=tf.float32)
        else:
            h0 = tf.cast(initial_state, tf.float32)

        # Differentiable forward pass.
        out, last_state = _wk7_tf(w, r, k, v, a, b, h0)

        # Restore the caller's dtype.
        out = tf.cast(out, dtype)

        return (out, last_state) if output_final_state else out

    return generalized_delta_rule, _tf_wkv7_fwd, _tf_wkv7_bwd
|
|
@@ -0,0 +1,165 @@
|
|
|
1
|
+
#include <cuda_bf16.h>
|
|
2
|
+
#include <assert.h>
|
|
3
|
+
|
|
4
|
+
using bf = __nv_bfloat16;
|
|
5
|
+
|
|
6
|
+
// bf16 -> f32: all arithmetic in the kernels is accumulated in float.
__device__ inline float to_float(const bf & u) {
    return __bfloat162float(u);
}

// f32 -> bf16 (round-to-nearest-even) for kernel outputs.
__device__ inline bf to_bf(const float & u) {
    return __float2bfloat16_rn(u);
}
// Shorthand for the restrict-qualified bf16 pointers taken by the kernels.
typedef bf * __restrict__ F_;
|
|
14
|
+
|
|
15
|
+
// wkv7 forward pass. Launch geometry: grid (H, B), block of _C_ threads.
// Thread i keeps row i of the per-(batch, head) state matrix in registers
// (state[j] = S[i][j]). Inputs w,q,k,v,a,b are (B, T, H, _C_) bf16; outputs:
//   y_  (B, T, H, C) bf16          — attention output
//   s_  (B, H, T/_CHUNK_LEN_, C, C) f32 — end-of-chunk state snapshots
//   sa_ (B, T, H, C) f32           — per-step <a, state row> (reused by backward)
//   h0_ (B, H, C, C) f32           — initial state (read only)
__global__ void forward_kernel(int T, int H,
        F_ w_, F_ q_, F_ k_, F_ v_, F_ a_, F_ b_,
        bf* y_, float* s_, float* sa_, float* h0_) {
    constexpr int C = _C_;
    int bb = blockIdx.y, hh = blockIdx.x, i = threadIdx.x;
    float state[C] = {0};
    __shared__ float q[C], k[C], w[C], a[C], b[C];
    // Load this thread's row of the initial state.
    int h0_base = ((bb*H + hh)*C + i)*C;
#pragma unroll
    for (int j = 0; j < C; j++) {
        state[j] = h0_[h0_base + j];
    }
    for (int t = 0; t < T; t++) {
        int ind = bb*T*H*C + t*H*C + hh * C + i;
        __syncthreads();
        // Stage this timestep's vectors in shared memory, one element per thread.
        q[i] = to_float(q_[ind]);
        // Decay is parameterized as exp(-exp(w_raw)) -> always in (0, 1).
        w[i] = __expf(-__expf(to_float(w_[ind])));
        k[i] = to_float(k_[ind]);
        a[i] = to_float(a_[ind]);
        b[i] = to_float(b_[ind]);
        __syncthreads();
        // sa = <a, state row i>; cached to global memory for the backward pass.
        float sa = 0;
#pragma unroll
        for (int j = 0; j < C; j++) {
            sa += a[j] * state[j];
        }
        sa_[ind] = sa;
        float v = to_float(v_[ind]);
        float y = 0;
        // Fused state update and output: s_j = s_j*w_j + sa*b_j + k_j*v; y = <s, q>.
#pragma unroll
        for (int j = 0; j < C; j++) {
            float& s = state[j];
            s = s * w[j] + sa * b[j] + k[j] * v;
            y += s * q[j];
        }
        y_[ind] = to_bf(y);

        // Snapshot the state at each chunk boundary so backward can
        // reconstruct intermediate states. Note the j*C stride: the snapshot
        // is written transposed relative to the in-register layout.
        if ((t+1)%_CHUNK_LEN_ == 0) {
            int base = (bb*H+hh)*(T/_CHUNK_LEN_)*C*C + (t/_CHUNK_LEN_)*C*C + i;
            #pragma unroll
            for (int j = 0; j < C; j++) {
                s_[base + j*C] = state[j];
            }
        }
    }
}
|
|
61
|
+
// wkv7 backward pass. Same launch geometry as forward. Walks time in
// reverse, reloading the end-of-chunk snapshot from s_ at each chunk
// boundary and inverting the forward update step-by-step inside the chunk.
//   dht_ (B, H, C, C) f32 — incoming gradient w.r.t. the final state
//   dh0_ (B, H, C, C) f32 — outgoing gradient w.r.t. the initial state (t==0)
// NOTE(review): dstate and dstateT are both seeded from the same dht_ rows;
// this is only correct if the host side hands dht over in a matching layout
// (the Python wrapper transposes the final state before returning) — verify.
__global__ void backward_kernel(int T, int H,
        F_ w_, F_ q_, F_ k_, F_ v_, F_ a_, F_ b_, F_ dy_,
        float * __restrict__ s_, float * __restrict__ sa_,
        float * __restrict__ dht_, float * __restrict__ dh0_,
        bf* dw_, bf* dq_, bf* dk_, bf* dv_, bf* da_, bf* db_) {
    constexpr int C = _C_;
    int bb = blockIdx.y, hh = blockIdx.x, i = threadIdx.x;
    // stateT: reconstructed state in the transposed (snapshot) layout.
    // dstate / dstateT: running state gradient, kept in both layouts so each
    // thread can do its row- and column-wise reductions locally.
    float stateT[C] = {0}, dstate[C] = {0}, dstateT[C] = {0};
    int dht_base = ((bb*H + hh)*C + i)*C;
#pragma unroll
    for (int j = 0; j < C; j++) {
        dstate[j] = dht_[dht_base + j];
        dstateT[j] = dht_[dht_base + j];
    }
    __shared__ float w[C], q[C], k[C], v[C], a[C], b[C], dy[C], sa[C], dSb_shared[C];
    float qi, wi, ki, ai, bi, dyi;

    for (int t = T-1; t >= 0; t--) {
        int ind = bb*T*H*C + t*H*C + hh * C + i;
        __syncthreads();
        // Stage this timestep's values; each thread also keeps its own scalar copy.
        q[i] = qi = to_float(q_[ind]);
        float wi_fac = -__expf(to_float(w_[ind]));  // -exp(w_raw)
        w[i] = wi = __expf(wi_fac);                 // decay exp(-exp(w_raw))
        k[i] = ki = to_float(k_[ind]);
        a[i] = ai = to_float(a_[ind]);
        b[i] = bi = to_float(b_[ind]);
        v[i] = to_float(v_[ind]);
        dy[i] = dyi = to_float(dy_[ind]);
        sa[i] = sa_[ind];
        __syncthreads();

        // At a chunk boundary, reload the saved post-update state snapshot.
        if ((t+1)%_CHUNK_LEN_ == 0) {
            int base = (bb*H+hh)*(T/_CHUNK_LEN_)*C*C + (t/_CHUNK_LEN_)*C*C + i*C;
            #pragma unroll
            for (int j = 0; j < C; j++) {
                stateT[j] = s_[base + j];
            }
        }
        // dq uses the state *after* the update at time t (y was computed from it).
        float dq = 0;
#pragma unroll
        for (int j = 0; j < C; j++) {
            dq += stateT[j]*dy[j];
        }
        dq_[ind] = to_bf(dq);
        // Invert the forward update to recover the pre-step state; the small
        // epsilon guards against division by a fully-decayed (near-zero) w.
        float iwi = 1.0f/(wi+0.000001f);
#pragma unroll
        for (int j = 0; j < C; j++) {
            stateT[j] = (stateT[j] - ki*v[j] - bi*sa[j]) * iwi;
            // Accumulate the contribution of y_t into the state gradient.
            dstate[j] += dyi * q[j];
            dstateT[j] += qi * dy[j];
        }
        // Per-thread reductions for the input gradients at time t.
        float dw = 0, dk = 0, dv = 0, db = 0, dSb = 0;
#pragma unroll
        for (int j = 0; j < C; j++) {
            dw += dstateT[j]*stateT[j];
            dk += dstateT[j]*v[j];
            dv += dstate[j]*k[j];
            dSb += dstate[j]*b[j];
            db += dstateT[j]*sa[j];
        }
        // Chain rule through w = exp(-exp(w_raw)): d/dw_raw = dw * w * (-exp(w_raw)).
        dw_[ind] = to_bf(dw * wi * wi_fac);
        dk_[ind] = to_bf(dk);
        dv_[ind] = to_bf(dv);
        db_[ind] = to_bf(db);
        __syncthreads();
        // Share each thread's dSb so da can be reduced across rows.
        dSb_shared[i] = dSb;
        __syncthreads();
        float da = 0;
#pragma unroll
        for (int j = 0; j < C; j++) {
            da += stateT[j]*dSb_shared[j];
        }
        da_[ind] = to_bf(da);
        // Propagate the state gradient one step back in time; at t==0 the
        // result is exactly the gradient w.r.t. the initial state.
#pragma unroll
        for (int j = 0; j < C; j++) {
            dstate[j] = dstate[j]*w[j] + dSb * a[j];
            dstateT[j] = dstateT[j]*wi + ai * dSb_shared[j];
            if (t==0){
                dh0_[dht_base + j] = dstate[j];
            }
        }
    }
}
|
|
153
|
+
|
|
154
|
+
// Host launcher: grid (H, B), _C_ threads per block, default stream, no
// error checking — the caller owns synchronization and validation.
// NOTE(review): the parameter names z,a here map onto the kernel's a,b
// (naming drift inherited from the Torch-side signature) — confirm ordering.
void cuda_forward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*y, float*s, float*sa, float* h0) {
    forward_kernel<<<dim3(H,B), dim3(_C_)>>>(T,H,w,q,k,v,z,a,y,s,sa,h0);
}
|
|
157
|
+
|
|
158
|
+
// Host launcher for the backward kernel. T must be a multiple of
// _CHUNK_LEN_ (enforced by assert, which is compiled out under NDEBUG).
// Same z/a -> a/b naming drift as cuda_forward; dz/da map onto da_/db_.
void cuda_backward(int B, int T, int H,
        bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*dy,
        float*s, float*sa, float*dht, float*dh0,
        bf*dw, bf*dq, bf*dk, bf*dv, bf*dz, bf*da
) {
    assert(T%_CHUNK_LEN_ == 0);
    backward_kernel<<<dim3(H,B), dim3(_C_)>>>(T,H,w,q,k,v,z,a,dy,s,sa,dht,dh0,dw,dq,dk,dv,dz,da);
}
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
#include <torch/extension.h>
|
|
2
|
+
#include <cuda_bf16.h>
|
|
3
|
+
|
|
4
|
+
using bf = __nv_bfloat16;
|
|
5
|
+
void cuda_forward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*y, float*s, float*sa,float*h0);
|
|
6
|
+
|
|
7
|
+
// Torch-facing forward wrapper. Shapes are read from w as (B, T, H, C);
// all tensors are assumed to be contiguous CUDA tensors of the dtypes the
// kernel expects (bf16 inputs and y; f32 for s, sa, h0). y, s, sa are
// written in place; no shape/dtype validation is performed here.
void forward(torch::Tensor &w, torch::Tensor &q, torch::Tensor &k, torch::Tensor &v,
             torch::Tensor &z, torch::Tensor &a,
             torch::Tensor &y,
             torch::Tensor &s, torch::Tensor &sa, torch::Tensor &h0) {
    int B = w.sizes()[0], T = w.sizes()[1], H = w.sizes()[2];
    cuda_forward(B, T, H,
        (bf*)w.data_ptr(), (bf*)q.data_ptr(), (bf*)k.data_ptr(), (bf*)v.data_ptr(), (bf*)z.data_ptr(), (bf*)a.data_ptr(), (bf*)y.data_ptr(),
        (float*)s.data_ptr(), (float*)sa.data_ptr(), (float*)h0.data_ptr());
}
|
|
16
|
+
void cuda_backward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*dy, float*s, float*sa,float*dht,float*dh0, bf*dw, bf*dq, bf*dk, bf*dv, bf*dz, bf*da);
|
|
17
|
+
// Torch-facing backward wrapper. Same contiguity/dtype assumptions as
// forward; all gradient tensors (dh0, dw, dq, dk, dv, dz, da) are written
// in place.
void backward(torch::Tensor &w, torch::Tensor &q, torch::Tensor &k, torch::Tensor &v, torch::Tensor &z, torch::Tensor &a, torch::Tensor &dy,
              torch::Tensor &s, torch::Tensor &sa, torch::Tensor &dht, torch::Tensor &dh0,
              torch::Tensor &dw, torch::Tensor &dq, torch::Tensor &dk, torch::Tensor &dv, torch::Tensor &dz, torch::Tensor &da) {
    int B = w.sizes()[0], T = w.sizes()[1], H = w.sizes()[2];
    cuda_backward(B, T, H, (bf*)w.data_ptr(), (bf*)q.data_ptr(), (bf*)k.data_ptr(), (bf*)v.data_ptr(), (bf*)z.data_ptr(), (bf*)a.data_ptr(),
        (bf*)dy.data_ptr(),
        (float*)s.data_ptr(), (float*)sa.data_ptr(), (float*)dht.data_ptr(), (float*)dh0.data_ptr(),
        (bf*)dw.data_ptr(), (bf*)dq.data_ptr(), (bf*)dk.data_ptr(), (bf*)dv.data_ptr(), (bf*)dz.data_ptr(), (bf*)da.data_ptr());
}
|
|
26
|
+
|
|
27
|
+
// Op schema registration. The (x!) annotations mark tensors the ops mutate
// in place, so the dispatcher/autograd know about the aliasing.
// NOTE(review): h0 is annotated mutable (f!) in the forward schema, but the
// CUDA kernel only reads it — confirm whether the annotation is intentional.
TORCH_LIBRARY(wind_backstepping, m) {
    m.def("forward(Tensor w, Tensor q, Tensor k, Tensor v, Tensor z, Tensor a, Tensor(a!) y, Tensor(b!) s, Tensor(c!) sa, Tensor(f!) h0) -> ()");
    m.def("backward(Tensor w, Tensor q, Tensor k, Tensor v, Tensor z, Tensor a, Tensor dy, Tensor s, Tensor sa,Tensor dht,Tensor(a!) dh0, Tensor(b!) dw, Tensor(c!) dq, Tensor(d!) dk, Tensor(e!) dv, Tensor(f!) dz, Tensor(g!) da) -> ()");
}

// Bind the CUDA implementations to the declared schemas.
TORCH_LIBRARY_IMPL(wind_backstepping, CUDA, m) {
    m.impl("forward", &forward);
    m.impl("backward", &backward);
}
|