rwkv-ops 0.6.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rwkv_ops/__init__.py +45 -0
- rwkv_ops/mhc_kernel/__init__.py +50 -0
- rwkv_ops/mhc_kernel/common_kernel/include/mhc_types.h +66 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_post_op.cuh +197 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_pre_op.cuh +212 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/rmsnorm.cuh +152 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/sinkhorn_knopp.cuh +158 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/stream_aggregate.cuh +141 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/stream_distribute.cuh +111 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/stream_mix.cuh +164 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/type_conversions.cuh +52 -0
- rwkv_ops/mhc_kernel/jax_kernel/CMakeLists.txt +47 -0
- rwkv_ops/mhc_kernel/jax_kernel/mhu_ffi.cu +652 -0
- rwkv_ops/mhc_kernel/jax_kernel/mhu_jax.py +939 -0
- rwkv_ops/mhc_kernel/native_keras_op.py +193 -0
- rwkv_ops/mhc_kernel/torch_kernel/mhc_cuda.cu +207 -0
- rwkv_ops/mhc_kernel/torch_kernel/mhc_op.cpp +296 -0
- rwkv_ops/mhc_kernel/torch_kernel/mhc_torch.py +306 -0
- rwkv_ops/rwkv6_kernel/__init__.py +120 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/gpu_ops.cpp +44 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernel_helpers.h +64 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernels.h +56 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/pybind11_kernel_helpers.h +41 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/rwkv_kernels.cu +512 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/gpu_ops.cpp +44 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernel_helpers.h +64 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernels.h +56 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/pybind11_kernel_helpers.h +41 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/rwkv_kernels.hip +514 -0
- rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py +722 -0
- rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py +90 -0
- rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_cuda.cu +397 -0
- rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_op.cpp +93 -0
- rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py +305 -0
- rwkv_ops/rwkv7_kernel/__init__.py +113 -0
- rwkv_ops/rwkv7_kernel/get_jax_devices_info.py +220 -0
- rwkv_ops/rwkv7_kernel/get_torch_devices_info.py +250 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel/CMakeLists.txt +42 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_ffi.cu +399 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_jax.py +311 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/CMakeLists.txt +42 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_ffi.cu +172 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_jax.py +190 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/__init__.py +9 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_bwd.py +95 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_fwd.py +60 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_bwd.py +78 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_fwd.py +80 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_bwd.py +150 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_fwd.py +45 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/cumsum.py +34 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_bwd.py +61 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_fwd.py +86 -0
- rwkv_ops/rwkv7_kernel/jax_op.py +382 -0
- rwkv_ops/rwkv7_kernel/mlx_op.py +118 -0
- rwkv_ops/rwkv7_kernel/native_keras_op.py +108 -0
- rwkv_ops/rwkv7_kernel/tf_eager_kernel.py +155 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_cuda.cu +235 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_op.cpp +63 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_torch.py +233 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_cuda.cu +101 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_op.cpp +56 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_torch.py +112 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/__init__.py +13 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_bwd.py +96 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_fwd.py +64 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_bwd.py +74 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_fwd.py +75 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_bwd.py +148 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_fwd.py +44 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/cumsum.py +31 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_bwd.py +63 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_fwd.py +79 -0
- rwkv_ops/rwkv7_kernel/torch_op.py +504 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/__init__.py +34 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_bwd.py +328 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_fwd.py +186 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_bwd.py +157 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_fwd.py +160 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_bwd.py +382 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_fwd.py +137 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/cumsum.py +86 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/utils.py +20 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_bwd.py +193 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_fwd.py +326 -0
- rwkv_ops-0.6.1.dist-info/METADATA +495 -0
- rwkv_ops-0.6.1.dist-info/RECORD +89 -0
- rwkv_ops-0.6.1.dist-info/WHEEL +4 -0
- rwkv_ops-0.6.1.dist-info/licenses/LICENSE.txt +201 -0
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
import keras
|
|
2
|
+
from keras import ops
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def transpose_head(x, head_first):
    """Cast ``x`` to float32 and optionally swap its axes 1 and 2.

    Args:
        x: a rank-4 input tensor.
        head_first: when True, axes 1 and 2 are swapped, i.e.
            (B, H, T, N) -> (B, T, H, N); when False ``x`` keeps its layout.

    Returns:
        The float32 tensor, transposed only if ``head_first`` is True.
    """
    as_float = ops.cast(x, "float32")
    if not head_first:
        return as_float
    return ops.transpose(as_float, (0, 2, 1, 3))
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def generalized_delta_rule(
    r,
    w,
    k,
    v,
    a,
    b,
    initial_state=None,
    output_final_state: bool = True,
    head_first: bool = False,
):
    """Reference RWKV-7 generalized delta rule written with Keras ops.

    The recurrence is evaluated one time step at a time with
    ``ops.fori_loop``. For a state ``S`` of shape ``(B, H, N, N)`` each step
    computes::

        S   = S * w_t[..., None, :] + (S @ a_t) @ b_t^T + v_t @ k_t^T
        o_t = S @ r_t

    where ``w_t = exp(-exp(w[:, t]))``.

    Args:
        r, w, k, v, a, b: input tensors of shape ``(B, T, H, N)``, or
            ``(B, H, T, N)`` when ``head_first`` is True. ``w`` is passed
            through ``exp(-exp(w))`` before use.
        initial_state: optional initial state of shape ``(B, H, N, N)``; a
            leading dimension of 1 is broadcast to ``B``. Defaults to zeros.
        output_final_state: when True, also return the final state.
        head_first: whether the inputs use the ``(B, H, T, N)`` layout.

    Returns:
        ``out`` of shape ``(B, T, H, N)`` cast to ``r``'s dtype (always in
        time-major layout, also when ``head_first`` is True). When
        ``output_final_state`` is True, returns ``(out, state)`` where
        ``state`` is the float32 final state ``(B, H, N, N)``.
    """
    DTYPE = r.dtype
    r = transpose_head(r, head_first)
    k = transpose_head(k, head_first)
    v = transpose_head(v, head_first)
    a = transpose_head(a, head_first)
    b = transpose_head(b, head_first)
    w = transpose_head(w, head_first)
    # Read the dimensions only after the transpose: for head_first inputs the
    # raw shape is (B, H, T, N), and unpacking it as (B, T, H, N) would swap
    # the time and head sizes.
    B, T, H, N = ops.shape(r)
    w = ops.exp(-ops.exp(w))

    if initial_state is not None:
        state = initial_state
        if ops.shape(state)[0] == 1:
            state = ops.broadcast_to(state, (B, H, N, N))
    else:
        state = ops.zeros((B, H, N, N))
    state = ops.cast(state, "float32")

    keras_backend = keras.config.backend()

    def step(t, inputs):
        """Advance the recurrence by one time step.

        Args:
            t: current time index.
            inputs: ``[state, out]`` loop carry.

        Returns:
            The updated ``[state, out]``.
        """
        state, out = inputs
        kk = ops.reshape(k[:, t, :], (B, H, 1, N))
        rr = ops.reshape(r[:, t, :], (B, H, N, 1))
        vv = ops.reshape(v[:, t, :], (B, H, N, 1))
        aa = ops.reshape(a[:, t, :], (B, H, N, 1))
        bb = ops.reshape(b[:, t, :], (B, H, 1, N))
        state = state * w[:, t, :, None, :] + state @ aa @ bb + vv @ kk
        o = ops.cast((state @ rr), out.dtype)
        # Each backend needs its own way of writing one time slice of `out`.
        if keras_backend == "tensorflow":
            out = out.write(t, ops.reshape(o, (B, H, N)))
        elif keras_backend == "torch":
            out[:, t : t + 1] = ops.reshape(o, (B, 1, H, N))
        else:
            out = ops.slice_update(out, [0, t, 0, 0], ops.reshape(o, (B, 1, H, N)))
        return [state, out]

    if keras_backend == "tensorflow":
        import tensorflow as tf

        out = tf.TensorArray(DTYPE, size=T)
    else:
        out = ops.zeros((B, T, H, N), DTYPE)
    state, out = ops.fori_loop(0, T, step, [state, out])
    if keras_backend == "tensorflow":
        # TensorArray.stack() yields (T, B, H, N); restore (B, T, H, N).
        out = ops.transpose(out.stack(), [1, 0, 2, 3])
    if output_final_state:
        return ops.cast(out, DTYPE), state
    return ops.cast(out, DTYPE)
|
|
@@ -0,0 +1,155 @@
|
|
|
1
|
+
"""
|
|
2
|
+
TensorFlow 版 generalized_delta_rule
|
|
3
|
+
前向用 tf.py_function 调 JAX CUDA 内核,反向同样走 JAX。
|
|
4
|
+
可 @tf.function 编译,可 tf.GradientTape 训练。
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import tensorflow as tf
|
|
8
|
+
from typing import Optional, Tuple
|
|
9
|
+
import jax.numpy as jnp
|
|
10
|
+
from .jax_cuda_kernel.wkv7_jax import get_jax_generalized_delta_rule
|
|
11
|
+
from .jax_cuda_kernel_single.wkv7_single_step_jax import (
|
|
12
|
+
get_jax_generalized_delta_rule_single_step,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def transpose_head(x, head_first: bool):
    """Cast ``x`` to float32; if ``head_first``, swap axes 1 and 2.

    Converts between the (B, T, H, K) and (B, H, T, K) layouts.
    """
    x_f32 = tf.cast(x, dtype=tf.float32)
    return tf.transpose(x_f32, (0, 2, 1, 3)) if head_first else x_f32
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def get_tf_generalized_delta_rule(HEAD_SIZE=64):
    """Build a TensorFlow ``generalized_delta_rule`` backed by the JAX CUDA kernel.

    Args:
        HEAD_SIZE: head dimension forwarded to ``get_jax_generalized_delta_rule``.

    Returns:
        The ``generalized_delta_rule`` callable defined below.

    Note:
        Only the inference entry (element ``[1]``) returned by
        ``get_jax_generalized_delta_rule`` is wrapped, through
        ``tf.py_function``. TensorFlow cannot differentiate through the JAX
        call, so this op does not provide gradients to ``tf.GradientTape``.
    """
    generalized_delta_rule_inference = get_jax_generalized_delta_rule(HEAD_SIZE)[1]

    # ---------- low-level kernel wrapper ----------
    @tf.py_function(Tout=[tf.bfloat16, tf.float32])
    def _tf_wkv7_fwd(w, q, k, v, a, b, h0):
        """Run the JAX inference kernel eagerly; returns (y as bf16, state as fp32)."""
        y, s = generalized_delta_rule_inference(
            w=jnp.asarray(w, jnp.bfloat16),
            r=jnp.asarray(q, jnp.bfloat16),
            k=jnp.asarray(k, jnp.bfloat16),
            v=jnp.asarray(v, jnp.bfloat16),
            a=jnp.asarray(a, jnp.bfloat16),
            b=jnp.asarray(b, jnp.bfloat16),
            initial_state=jnp.asarray(h0, jnp.float32),
        )
        return (
            tf.convert_to_tensor(y, tf.bfloat16),
            tf.convert_to_tensor(s, tf.float32),
        )

    def _dims(x):
        """Return the 4 dims of ``x``: Python ints where static, scalar tensors otherwise."""
        dynamic = tf.unstack(tf.shape(x), num=4)
        static = x.shape.as_list() if x.shape.rank is not None else [None] * 4
        return [st if st is not None else dyn for st, dyn in zip(static, dynamic)]

    # ---------- user-facing interface ----------
    def generalized_delta_rule(
        r: tf.Tensor,  # (B, T, H, K) or (B, H, T, K)
        w: tf.Tensor,
        k: tf.Tensor,
        v: tf.Tensor,
        a: tf.Tensor,
        b: tf.Tensor,
        initial_state: Optional[tf.Tensor] = None,
        output_final_state: bool = True,
        head_first: bool = False,
        chunk_len: int = 16,
    ) -> Tuple[tf.Tensor, tf.Tensor]:
        """Generalized delta rule forward pass (interface mirrors the JAX version).

        Args:
            r, w, k, v, a, b: inputs in (B, T, H, K), or (B, H, T, K) when
                ``head_first`` is True; they are converted to (B, T, H, K).
            initial_state: optional initial state; defaults to zeros of
                shape (B, H, K, K). Cast to float32.
            output_final_state: when True return ``(out, last_state)``,
                otherwise only ``out``.
            head_first: whether the inputs use the (B, H, T, K) layout.
            chunk_len: T must be a multiple of this value.

        Returns:
            ``out`` cast to ``r``'s dtype (and ``last_state`` as float32 when
            ``output_final_state`` is True). Their layout is whatever the JAX
            kernel returns for (B, T, H, K) inputs (presumably (B, T, H, K) and
            (B, H, K, K) — TODO confirm against wkv7_jax).

        Raises:
            ValueError: if the static T is not divisible by ``chunk_len``
                (when T is only known at run time, a TF assertion fails instead).
        """
        dtype = r.dtype

        r = transpose_head(r, head_first)
        w = transpose_head(w, head_first)
        k = transpose_head(k, head_first)
        v = transpose_head(v, head_first)
        a = transpose_head(a, head_first)
        b = transpose_head(b, head_first)

        B, T, H, K = _dims(r)
        # A Python `if` on a symbolic tensor cannot raise inside tf.function,
        # so only use Python control flow when T is statically known.
        if isinstance(T, int):
            if T % chunk_len != 0:
                raise ValueError(f"T={T} must be divisible by chunk_len={chunk_len}")
        else:
            tf.debugging.assert_equal(
                T % chunk_len,
                0,
                message=f"T must be divisible by chunk_len={chunk_len}",
            )

        if initial_state is None:
            h0 = tf.zeros([B, H, K, K], dtype=tf.float32)
        else:
            h0 = tf.cast(initial_state, tf.float32)

        # Forward pass only: no gradient is propagated through the JAX kernel.
        out, last_state = _tf_wkv7_fwd(w, r, k, v, a, b, h0)

        # Cast back to the caller's dtype.
        out = tf.cast(out, dtype)

        return (out, last_state) if output_final_state else out

    return generalized_delta_rule
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def get_tf_generalized_delta_rule_single_step(HEAD_SIZE=64):
    """Build a TensorFlow single-step generalized delta rule backed by JAX.

    Args:
        HEAD_SIZE: head dimension forwarded to
            ``get_jax_generalized_delta_rule_single_step``.

    Returns:
        The ``generalized_delta_rule_single_step`` callable defined below.

    Note:
        The JAX kernel runs inside ``tf.py_function``; TensorFlow cannot
        differentiate through it, so this op is forward-only.
    """
    # JAX single-step generalized delta rule kernel.
    _wkv7_single_step_kernel = get_jax_generalized_delta_rule_single_step(HEAD_SIZE)

    # ---------- low-level kernel wrapper ----------
    @tf.py_function(Tout=[tf.bfloat16, tf.float32])
    def _tf_wkv7_single_step_fwd(w, r, k, v, a, b, h0):
        """Run the JAX single-step kernel eagerly; returns (y as bf16, state as fp32)."""
        y, s = _wkv7_single_step_kernel(
            w=jnp.asarray(w, jnp.bfloat16),
            r=jnp.asarray(r, jnp.bfloat16),
            k=jnp.asarray(k, jnp.bfloat16),
            v=jnp.asarray(v, jnp.bfloat16),
            a=jnp.asarray(a, jnp.bfloat16),
            b=jnp.asarray(b, jnp.bfloat16),
            initial_state=jnp.asarray(h0, jnp.float32),
        )
        return (
            tf.convert_to_tensor(y, tf.bfloat16),
            tf.convert_to_tensor(s, tf.float32),
        )

    def _dims(x):
        """Return the 4 dims of ``x``: Python ints where static, scalar tensors otherwise."""
        dynamic = tf.unstack(tf.shape(x), num=4)
        static = x.shape.as_list() if x.shape.rank is not None else [None] * 4
        return [st if st is not None else dyn for st, dyn in zip(static, dynamic)]

    # ---------- user-facing interface ----------
    def generalized_delta_rule_single_step(
        r: tf.Tensor,  # (B, 1, H, K) or (B, H, 1, K)
        w: tf.Tensor,
        k: tf.Tensor,
        v: tf.Tensor,
        a: tf.Tensor,
        b: tf.Tensor,
        initial_state: Optional[tf.Tensor] = None,
        output_final_state: bool = True,
        head_first: bool = False,
    ) -> Tuple[tf.Tensor, tf.Tensor]:
        """Single-step generalized delta rule (interface mirrors the JAX version).

        Args:
            r, w, k, v, a, b: inputs in (B, 1, H, K), or (B, H, 1, K) when
                ``head_first`` is True.
            initial_state: optional initial state; defaults to zeros of
                shape (B, H, K, K). Cast to float32.
            output_final_state: when True return ``(out, last_state)``,
                otherwise only ``out``.
            head_first: whether the inputs use the (B, H, T, K) layout.

        Returns:
            ``out`` cast to ``r``'s dtype, plus the float32 state when
            ``output_final_state`` is True.

        Raises:
            ValueError: if the static T is not 1 (when T is only known at run
                time, a TF assertion fails instead).
        """
        dtype = r.dtype

        r = transpose_head(r, head_first)
        w = transpose_head(w, head_first)
        k = transpose_head(k, head_first)
        v = transpose_head(v, head_first)
        a = transpose_head(a, head_first)
        b = transpose_head(b, head_first)

        B, T, H, K = _dims(r)
        # A Python `if` on a symbolic tensor cannot raise inside tf.function,
        # so only use Python control flow when T is statically known.
        if isinstance(T, int):
            if T != 1:
                raise ValueError(f"Single-step kernel requires T=1, but got T={T}")
        else:
            tf.debugging.assert_equal(
                T, 1, message="Single-step kernel requires T=1"
            )

        if initial_state is None:
            h0 = tf.zeros([B, H, K, K], dtype=tf.float32)
        else:
            h0 = tf.cast(initial_state, tf.float32)

        # Forward computation.
        y, s = _tf_wkv7_single_step_fwd(w, r, k, v, a, b, h0)

        # Cast back to the caller's dtype.
        out = tf.cast(y, dtype)

        return (out, s) if output_final_state else out

    return generalized_delta_rule_single_step
|
|
@@ -0,0 +1,235 @@
|
|
|
1
|
+
#include <cuda_bf16.h>
|
|
2
|
+
#include <assert.h>
|
|
3
|
+
#include <cstdint>
|
|
4
|
+
// ref link:https://github.com/BlinkDL/RWKV-CUDA/tree/main/rwkv7_fast_fused
|
|
5
|
+
using bf = __nv_bfloat16;

// Widen a bfloat16 value to float32 (used when staging kernel inputs).
__device__ inline float to_float(const bf &value)
{
    return __bfloat162float(value);
}

// Narrow a float32 value to bfloat16 with round-to-nearest (used for outputs).
__device__ inline bf to_bf(const float &value)
{
    return __float2bfloat16_rn(value);
}

// Restrict-qualified bf16 pointer type for read-only kernel inputs.
typedef bf * __restrict__ F_;
|
|
15
|
+
|
|
16
|
+
/* -------------------- Forward (training) kernel -------------------- */
// Launched by cuda_forward with grid (H, B) and C threads per block.
// Thread i keeps row i of the per-head state in registers: state[j] = S[i][j],
// initialised from h0_ (row-major C x C per (batch, head)).
// Per-timestep inputs are indexed as contiguous (B, T, H, C).
// Outputs: y_ (bf16), sa_ (per-step a.S row products, reused by the backward
// pass) and s_, a transposed snapshot of the state every _CHUNK_LEN_ steps.
// NOTE(review): _C_ / _CHUNK_LEN_ are macros not defined in this file —
// presumably supplied as compile flags; confirm in the build script.
// Optimization 1: explicit launch bounds (C threads, >= 2 blocks per SM) to improve occupancy.
template<int C> __launch_bounds__(C, 2)
__global__ void forward_kernel(int T, int H,
                               F_ w_, F_ q_, F_ k_, F_ v_, F_ a_, F_ b_,
                               bf* y_, float* s_, float* sa_, float* h0_) {
    int bb = blockIdx.y, hh = blockIdx.x, i = threadIdx.x;
    float state[C] = {0};
    // Current timestep's q, decay, k, a, b, shared by all threads of the head.
    __shared__ float q[C], k[C], w[C], a[C], b[C];

    // Load row i of the initial state.
    int64_t h0_base = ((int64_t)bb*H + hh)*C*C + i*C;
#pragma unroll
    for (int j = 0; j < C; j++) {
        state[j] = h0_[h0_base + j];
    }

    for (int t = 0; t < T; t++) {
        int64_t ind = (int64_t)bb*T*H*C + (int64_t)t*H*C + hh * C + i;

        // First barrier: all threads are done reading the previous step's
        // shared values before they are overwritten.
        __syncthreads();
        q[i] = to_float(q_[ind]);
        // Decay in (0, 1): w = exp(-exp(w_raw)).
        w[i] = __expf(-__expf(to_float(w_[ind])));
        k[i] = to_float(k_[ind]);
        a[i] = to_float(a_[ind]);
        b[i] = to_float(b_[ind]);
        __syncthreads();

        // sa = (S a)_i, saved for the backward pass.
        float sa = 0;
#pragma unroll
        for (int j = 0; j < C; j++) {
            sa += a[j] * state[j];
        }
        sa_[ind] = sa;

        // S[i][j] = S[i][j]*w_j + sa_i*b_j + v_i*k_j ;  y_i = sum_j S[i][j]*q_j
        float v_val = to_float(v_[ind]);
        float y = 0;
#pragma unroll
        for (int j = 0; j < C; j++) {
            float &s = state[j];
            s = s * w[j] + sa * b[j] + k[j] * v_val;
            y += s * q[j];
        }
        y_[ind] = to_bf(y);

        // At the end of each chunk store the state transposed (element [j][i]),
        // so the backward kernel can read a column of S as a contiguous row.
        if ((t+1)%_CHUNK_LEN_ == 0) {
            int64_t base = ((int64_t)bb*H+hh)*(T/_CHUNK_LEN_)*C*C + ((int64_t)t/_CHUNK_LEN_)*C*C + i;
#pragma unroll
            for (int j = 0; j < C; j++) {
                s_[base + j*C] = state[j];
            }
        }
    }
}
|
|
68
|
+
|
|
69
|
+
/* -------------------- Backward kernel -------------------- */
// Launched by cuda_backward with grid (H, B) and C threads per block.
// Walks time backwards. The state is rebuilt by inverting each forward step,
// starting from the per-chunk snapshots in s_ and using the saved sa_ values.
// Gradients are accumulated for all inputs and for h0 (dh0_).
// Per thread i:
//   stateT[j]  = S[j][i]      (column i of the state)
//   dstate[j]  = dL/dS[i][j]  (row i of the state gradient)
//   dstateT[j] = dL/dS[j][i]  (column i of the state gradient)
// Optimization 1: explicit launch bounds.
template<int C> __launch_bounds__(C, 2)
__global__ void backward_kernel(int T, int H,
                                F_ w_, F_ q_, F_ k_, F_ v_, F_ a_, F_ b_, F_ dy_,
                                float * __restrict__ s_, float * __restrict__ sa_,
                                float * __restrict__ dht_, float * __restrict__ dh0_,
                                bf* dw_, bf* dq_, bf* dk_, bf* dv_, bf* da_, bf* db_) {
    // The float4 loads of the saved state rows below need C to split into float4s.
    static_assert(C % 4 == 0, "backward_kernel requires C to be a multiple of 4");

    int bb = blockIdx.y, hh = blockIdx.x, i = threadIdx.x;
    float stateT[C] = {0}, dstate[C] = {0}, dstateT[C] = {0};

    // dht_ / dh0_ use h0's row-major C x C layout per (batch, head):
    // element [r][c] lives at head_base + r*C + c.
    int64_t head_base = ((int64_t)bb*H + hh)*C*C;
    int64_t dht_base = head_base + i*C;
#pragma unroll
    for (int j = 0; j < C; j++) {
        // Row i of the incoming state gradient.
        dstate[j] = dht_[dht_base + j];
        // dstateT is the transposed view (column i): it must read element
        // [j][i]. Re-reading row i is only correct for a symmetric dht.
        dstateT[j] = dht_[head_base + (int64_t)j*C + i];
    }

    __shared__ float w[C], q[C], k[C], v[C], a[C], b[C], dy[C], sa[C], dSb_shared[C];
    float qi, wi, ki, ai, bi, dyi;

    for (int t = T-1; t >= 0; t--) {
        int64_t ind = (int64_t)bb*T*H*C + (int64_t)t*H*C + hh * C + i;

        __syncthreads();
        q[i] = qi = to_float(q_[ind]);
        float wi_fac = -__expf(to_float(w_[ind]));
        w[i] = wi = __expf(wi_fac);
        k[i] = ki = to_float(k_[ind]);
        v[i] = to_float(v_[ind]);
        a[i] = ai = to_float(a_[ind]);
        b[i] = bi = to_float(b_[ind]);
        dy[i] = dyi = to_float(dy_[ind]);
        sa[i] = sa_[ind];
        __syncthreads();

        // At a chunk boundary, reload column i of S from the saved snapshot
        // (stored transposed by the forward kernel).
        if ((t+1)%_CHUNK_LEN_ == 0) {
            int64_t base = ((int64_t)bb*H+hh)*(T/_CHUNK_LEN_)*C*C + ((int64_t)t/_CHUNK_LEN_)*C*C + i*C;

            // Optimization 2: load the row with float4 vector loads (4 floats per load).
            const float4* s4 = (const float4*)(s_ + base);
#pragma unroll
            for (int j4 = 0; j4 < C/4; j4++) {
                float4 q_vec = s4[j4];
                const int j = j4 * 4;
                stateT[j+0] = q_vec.x;
                stateT[j+1] = q_vec.y;
                stateT[j+2] = q_vec.z;
                stateT[j+3] = q_vec.w;
            }
        }

        // dq_i = sum_j S[j][i] * dy_j
        float dq_val = 0;
#pragma unroll
        for (int j = 0; j < C; j++) {
            dq_val += stateT[j] * dy[j];
        }
        dq_[ind] = to_bf(dq_val);

        // Invert the forward step to get the previous state column, and add
        // this step's output contribution to the state gradient.
        float iwi = 1.0f/(wi + 0.000001f);
#pragma unroll
        for (int j = 0; j < C; j++) {
            stateT[j] = (stateT[j] - ki*v[j] - bi*sa[j]) * iwi;
            dstate[j] += dyi * q[j];
            dstateT[j] += qi * dy[j];
        }

        float dw = 0, dk = 0, dv = 0, db = 0, dSb = 0;
#pragma unroll
        for (int j = 0; j < C; j++) {
            dw += dstateT[j] * stateT[j];
            dk += dstateT[j] * v[j];
            dv += dstate[j] * k[j];
            dSb += dstate[j] * b[j];
            db += dstateT[j] * sa[j];
        }
        // Chain rule through w = exp(-exp(w_raw)): dw_raw = dw * w * (-exp(w_raw)).
        dw_[ind] = to_bf(dw * wi * wi_fac);
        dk_[ind] = to_bf(dk);
        dv_[ind] = to_bf(dv);
        db_[ind] = to_bf(db);

        // Share every thread's dSb (row-wise dL/dS . b) before computing da.
        __syncthreads();
        dSb_shared[i] = dSb;
        __syncthreads();

        float da = 0;
#pragma unroll
        for (int j = 0; j < C; j++) {
            da += stateT[j] * dSb_shared[j];
        }
        da_[ind] = to_bf(da);

        // Propagate the state gradient to the previous step; at t == 0 this is
        // the gradient w.r.t. the initial state.
#pragma unroll
        for (int j = 0; j < C; j++) {
            dstate[j] = dstate[j] * w[j] + dSb * a[j];
            dstateT[j] = dstateT[j] * wi + ai * dSb_shared[j];
            if (t == 0) {
                dh0_[dht_base + j] = dstate[j];
            }
        }
    }
}
|
|
171
|
+
|
|
172
|
+
/* -------------------- Inference-only kernel -------------------- */
// Same recurrence as forward_kernel, but it keeps no per-step sa_ values and
// no chunk snapshots; only the final state is written.
// Launched by cuda_forward_inference with grid (H, B) and C threads per block.
// Thread i keeps row i of the state: state[j] = S[i][j].
// Optimization 1: same launch-bounds hint as the training kernels.
template<int C> __launch_bounds__(C, 2)
__global__ void forward_inference_kernel(int T, int H,
                                         F_ w_, F_ q_, F_ k_, F_ v_, F_ a_, F_ b_,
                                         bf *y_, float *s_, float *h0_) {
    int bb = blockIdx.y, hh = blockIdx.x, i = threadIdx.x;
    float state[C] = {0};
    __shared__ float q[C], k[C], w[C], a[C], b[C];

    // Load row i of the initial state (row-major C x C per (batch, head)).
    int64_t h0_base = ((int64_t)bb * H + hh) * C * C + i * C;
#pragma unroll
    for (int j = 0; j < C; ++j) state[j] = h0_[h0_base + j];

    for (int t = 0; t < T; ++t) {
        int64_t ind = (int64_t)bb * T * H * C + (int64_t)t * H * C + hh * C + i;

        // Barrier so the previous step's shared values are no longer in use.
        __syncthreads();
        q[i] = to_float(q_[ind]);
        // Decay w = exp(-exp(w_raw)).
        w[i] = __expf(-__expf(to_float(w_[ind])));
        k[i] = to_float(k_[ind]);
        a[i] = to_float(a_[ind]);
        b[i] = to_float(b_[ind]);
        __syncthreads();

        // sa = (S a)_i
        float sa = 0.f;
#pragma unroll
        for (int j = 0; j < C; ++j) sa += a[j] * state[j];

        // S[i][j] = S[i][j]*w_j + sa_i*b_j + v_i*k_j ;  y_i = sum_j S[i][j]*q_j
        float v_val = to_float(v_[ind]);
        float y = 0.f;
#pragma unroll
        for (int j = 0; j < C; ++j) {
            float &s = state[j];
            s = s * w[j] + sa * b[j] + k[j] * v_val;
            y += s * q[j];
        }
        y_[ind] = to_bf(y);
    }

    // Write only the final state, in the same (non-transposed) layout as h0_.
    int64_t base = ((int64_t)bb * H + hh) * C * C + i * C;
#pragma unroll
    for (int j = 0; j < C; ++j) s_[base + j] = state[j];
}
|
|
216
|
+
|
|
217
|
+
/* -------------------- Host-side launchers -------------------- */
// Training forward: one block per (head, batch) pair, one thread per channel.
void cuda_forward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*y, float*s, float*sa, float* h0) {
    const dim3 grid(H, B);
    const dim3 block(_C_);
    forward_kernel<_C_><<<grid, block>>>(T, H, w, q, k, v, z, a, y, s, sa, h0);
}
|
|
221
|
+
|
|
222
|
+
// Training backward: same launch geometry as cuda_forward.
void cuda_backward(int B, int T, int H,
                   bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*dy,
                   float*s, float*sa, float*dht, float*dh0,
                   bf*dw, bf*dq, bf*dk, bf*dv, bf*dz, bf*da
) {
    // The backward pass restarts from the per-chunk state snapshots, so the
    // sequence must consist of whole chunks.
    assert(T % _CHUNK_LEN_ == 0);
    const dim3 grid(H, B), block(_C_);
    backward_kernel<_C_><<<grid, block>>>(T, H, w, q, k, v, z, a, dy,
                                          s, sa, dht, dh0,
                                          dw, dq, dk, dv, dz, da);
}
|
|
230
|
+
|
|
231
|
+
// Inference forward: same launch geometry, writes only y and the final state.
void cuda_forward_inference(int B, int T, int H,
                            bf* w, bf* q, bf* k, bf* v, bf* a, bf* b,
                            bf* y, float* s, float* h0) {
    const dim3 grid(H, B), block(_C_);
    forward_inference_kernel<_C_><<<grid, block>>>(T, H, w, q, k, v, a, b, y, s, h0);
}
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
#include <torch/extension.h>
|
|
2
|
+
#include <cuda_bf16.h>
|
|
3
|
+
|
|
4
|
+
using bf = __nv_bfloat16;

// ---------- Training launcher declarations (defined in the companion CUDA source) ----------
void cuda_forward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*y, float*s, float*sa, float* h0);
void cuda_backward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*dy, float*s, float*sa,float*dht,float*dh0, bf*dw, bf*dq, bf*dk, bf*dv, bf*dz, bf*da);

// ---------- Inference launcher declaration (required!) ----------
void cuda_forward_inference(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*a, bf*b, bf*y, float*s, float* h0);
|
|
12
|
+
|
|
13
|
+
// ---------- 原有forward函数 ----------
|
|
14
|
+
void forward(torch::Tensor &w, torch::Tensor &q, torch::Tensor &k, torch::Tensor &v,
|
|
15
|
+
torch::Tensor &z, torch::Tensor &a,
|
|
16
|
+
torch::Tensor &y,
|
|
17
|
+
torch::Tensor &s, torch::Tensor &sa,torch::Tensor &h0) {
|
|
18
|
+
int B = w.sizes()[0], T = w.sizes()[1], H = w.sizes()[2];
|
|
19
|
+
cuda_forward(B, T, H,
|
|
20
|
+
(bf*)w.data_ptr(), (bf*)q.data_ptr(), (bf*)k.data_ptr(), (bf*)v.data_ptr(), (bf*)z.data_ptr(), (bf*)a.data_ptr(), (bf*)y.data_ptr(),
|
|
21
|
+
(float*)s.data_ptr(), (float*)sa.data_ptr(),(float*)h0.data_ptr());
|
|
22
|
+
}
|
|
23
|
+
|
|
24
|
+
// ---------- 原有backward函数 ----------
|
|
25
|
+
void backward(torch::Tensor &w, torch::Tensor &q, torch::Tensor &k, torch::Tensor &v, torch::Tensor &z, torch::Tensor &a, torch::Tensor &dy,
|
|
26
|
+
torch::Tensor &s, torch::Tensor &sa,torch::Tensor &dht,torch::Tensor &dh0,
|
|
27
|
+
torch::Tensor &dw, torch::Tensor &dq, torch::Tensor &dk, torch::Tensor &dv, torch::Tensor &dz, torch::Tensor &da) {
|
|
28
|
+
int B = w.sizes()[0], T = w.sizes()[1], H = w.sizes()[2];
|
|
29
|
+
cuda_backward(B, T, H, (bf*)w.data_ptr(), (bf*)q.data_ptr(), (bf*)k.data_ptr(), (bf*)v.data_ptr(), (bf*)z.data_ptr(), (bf*)a.data_ptr(),
|
|
30
|
+
(bf*)dy.data_ptr(),
|
|
31
|
+
(float*)s.data_ptr(), (float*)sa.data_ptr(),(float*)dht.data_ptr(),(float*)dh0.data_ptr(),
|
|
32
|
+
(bf*)dw.data_ptr(), (bf*)dq.data_ptr(), (bf*)dk.data_ptr(), (bf*)dv.data_ptr(), (bf*)dz.data_ptr(), (bf*)da.data_ptr());
|
|
33
|
+
}
|
|
34
|
+
|
|
35
|
+
// ---------- 新增推理forward函数 ----------
|
|
36
|
+
void forward_inference(torch::Tensor &w, torch::Tensor &q, torch::Tensor &k, torch::Tensor &v,
|
|
37
|
+
torch::Tensor &a, torch::Tensor &b,
|
|
38
|
+
torch::Tensor &y,
|
|
39
|
+
torch::Tensor &s, torch::Tensor &h0) {
|
|
40
|
+
int B = w.sizes()[0], T = w.sizes()[1], H = w.sizes()[2];
|
|
41
|
+
cuda_forward_inference(B, T, H,
|
|
42
|
+
(bf*)w.data_ptr(), (bf*)q.data_ptr(), (bf*)k.data_ptr(), (bf*)v.data_ptr(),
|
|
43
|
+
(bf*)a.data_ptr(), (bf*)b.data_ptr(),
|
|
44
|
+
(bf*)y.data_ptr(),
|
|
45
|
+
(float*)s.data_ptr(), (float*)h0.data_ptr());
|
|
46
|
+
}
|
|
47
|
+
|
|
48
|
+
// ---------- Combined operator registration (do not split into several blocks!) ----------
// All schemas of the wind_backstepping namespace are declared in this single
// TORCH_LIBRARY block; further declarations elsewhere would need
// TORCH_LIBRARY_FRAGMENT.
// NOTE(review): h0 is annotated as mutable (f! / c!) although the kernels in
// the companion CUDA source only read h0_; confirm whether the annotation is intended.
TORCH_LIBRARY(wind_backstepping, m) {
    // Training ops
    m.def("forward(Tensor w, Tensor q, Tensor k, Tensor v, Tensor z, Tensor a, Tensor(a!) y, Tensor(b!) s, Tensor(c!) sa, Tensor(f!) h0) -> ()");
    m.def("backward(Tensor w, Tensor q, Tensor k, Tensor v, Tensor z, Tensor a, Tensor dy, Tensor s, Tensor sa,Tensor dht,Tensor(a!) dh0, Tensor(b!) dw, Tensor(c!) dq, Tensor(d!) dk, Tensor(e!) dv, Tensor(f!) dz, Tensor(g!) da) -> ()");

    // Inference op (added to the same block)
    m.def("forward_inference(Tensor w, Tensor q, Tensor k, Tensor v, Tensor a, Tensor b, Tensor(a!) y, Tensor(b!) s, Tensor(c!) h0) -> ()");
}
|
|
57
|
+
|
|
58
|
+
// ---------- Combined CUDA implementation registration ----------
// Binds each schema declared above to its host wrapper for the CUDA dispatch key.
TORCH_LIBRARY_IMPL(wind_backstepping, CUDA, m) {
    m.impl("forward", &forward);
    m.impl("backward", &backward);
    m.impl("forward_inference", &forward_inference);
}
|