rwkv-ops 0.6.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rwkv_ops/__init__.py +45 -0
- rwkv_ops/mhc_kernel/__init__.py +50 -0
- rwkv_ops/mhc_kernel/common_kernel/include/mhc_types.h +66 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_post_op.cuh +197 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_pre_op.cuh +212 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/rmsnorm.cuh +152 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/sinkhorn_knopp.cuh +158 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/stream_aggregate.cuh +141 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/stream_distribute.cuh +111 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/stream_mix.cuh +164 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/type_conversions.cuh +52 -0
- rwkv_ops/mhc_kernel/jax_kernel/CMakeLists.txt +47 -0
- rwkv_ops/mhc_kernel/jax_kernel/mhu_ffi.cu +652 -0
- rwkv_ops/mhc_kernel/jax_kernel/mhu_jax.py +939 -0
- rwkv_ops/mhc_kernel/native_keras_op.py +193 -0
- rwkv_ops/mhc_kernel/torch_kernel/mhc_cuda.cu +207 -0
- rwkv_ops/mhc_kernel/torch_kernel/mhc_op.cpp +296 -0
- rwkv_ops/mhc_kernel/torch_kernel/mhc_torch.py +306 -0
- rwkv_ops/rwkv6_kernel/__init__.py +120 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/gpu_ops.cpp +44 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernel_helpers.h +64 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernels.h +56 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/pybind11_kernel_helpers.h +41 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/rwkv_kernels.cu +512 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/gpu_ops.cpp +44 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernel_helpers.h +64 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernels.h +56 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/pybind11_kernel_helpers.h +41 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/rwkv_kernels.hip +514 -0
- rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py +722 -0
- rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py +90 -0
- rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_cuda.cu +397 -0
- rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_op.cpp +93 -0
- rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py +305 -0
- rwkv_ops/rwkv7_kernel/__init__.py +113 -0
- rwkv_ops/rwkv7_kernel/get_jax_devices_info.py +220 -0
- rwkv_ops/rwkv7_kernel/get_torch_devices_info.py +250 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel/CMakeLists.txt +42 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_ffi.cu +399 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_jax.py +311 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/CMakeLists.txt +42 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_ffi.cu +172 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_jax.py +190 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/__init__.py +9 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_bwd.py +95 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_fwd.py +60 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_bwd.py +78 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_fwd.py +80 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_bwd.py +150 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_fwd.py +45 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/cumsum.py +34 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_bwd.py +61 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_fwd.py +86 -0
- rwkv_ops/rwkv7_kernel/jax_op.py +382 -0
- rwkv_ops/rwkv7_kernel/mlx_op.py +118 -0
- rwkv_ops/rwkv7_kernel/native_keras_op.py +108 -0
- rwkv_ops/rwkv7_kernel/tf_eager_kernel.py +155 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_cuda.cu +235 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_op.cpp +63 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_torch.py +233 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_cuda.cu +101 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_op.cpp +56 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_torch.py +112 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/__init__.py +13 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_bwd.py +96 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_fwd.py +64 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_bwd.py +74 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_fwd.py +75 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_bwd.py +148 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_fwd.py +44 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/cumsum.py +31 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_bwd.py +63 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_fwd.py +79 -0
- rwkv_ops/rwkv7_kernel/torch_op.py +504 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/__init__.py +34 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_bwd.py +328 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_fwd.py +186 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_bwd.py +157 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_fwd.py +160 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_bwd.py +382 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_fwd.py +137 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/cumsum.py +86 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/utils.py +20 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_bwd.py +193 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_fwd.py +326 -0
- rwkv_ops-0.6.1.dist-info/METADATA +495 -0
- rwkv_ops-0.6.1.dist-info/RECORD +89 -0
- rwkv_ops-0.6.1.dist-info/WHEEL +4 -0
- rwkv_ops-0.6.1.dist-info/licenses/LICENSE.txt +201 -0
|
@@ -0,0 +1,193 @@
|
|
|
1
|
+
from keras import ops
|
|
2
|
+
|
|
3
|
+
# --- 辅助函数:确保在 fp32 下计算以保证数值稳定性 ---
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def fp32_sigmoid(x):
    """Apply a sigmoid in float32 and return the result in ``x``'s dtype.

    Upcasting first avoids evaluating the sigmoid in a low-precision dtype
    such as bfloat16; only the final result is cast back.
    """
    x_f32 = ops.cast(x, "float32")
    activated = ops.nn.sigmoid(x_f32)
    return ops.cast(activated, x.dtype)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
# --- 核心 MHC 算子 ---
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def sinkhorn_knopp(inp, num_iters=20, eps=1e-8):
    """Push ``inp`` towards a doubly stochastic matrix with Sinkhorn-Knopp.

    ``inp`` is treated as log-domain scores (typically the raw ``H_res``
    values). The last two axes are exponentiated. They are then alternately
    row- and column-normalised ``num_iters`` times. Any leading axes are
    treated as batch axes. All arithmetic is done in float32.

    Args:
        inp: tensor whose last two axes form the matrix to project.
        num_iters: number of (row, column) normalisation rounds.
        eps: added to every row/column sum before dividing, so an all-zero
            row or column cannot cause a division by zero.

    Returns:
        The normalised matrix, cast back to ``inp``'s dtype. Because the last
        step of each round normalises columns, columns sum to ~1. Rows sum to
        ~1 only as far as the iteration has converged.
    """
    scores = ops.cast(inp, "float32")
    # Subtract the per-matrix maximum so exp() cannot overflow. This scales the
    # whole matrix uniformly, and the row normalisation removes that scale
    # (up to the tiny eps term).
    scores = scores - ops.max(scores, axis=(-1, -2), keepdims=True)
    P = ops.exp(scores)

    for _ in range(num_iters):
        row_totals = ops.sum(P, axis=-1, keepdims=True)
        P = P / (row_totals + eps)
        col_totals = ops.sum(P, axis=-2, keepdims=True)
        P = P / (col_totals + eps)

    return ops.cast(P, inp.dtype)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def rmsnorm(inp, eps=1e-5):
    """Weightless RMSNorm over the last axis.

    Computes ``x / sqrt(mean(x ** 2, axis=-1) + eps)`` in float32 and casts
    the result back to ``inp``'s dtype.

    No learnable per-channel weight (gain) is applied, and none is accepted.
    Callers that need one must multiply the result themselves.

    Args:
        inp: tensor of shape ``[..., C]``; normalisation runs over ``C``.
        eps: added to the mean square before the square root, for
            numerical stability.

    Returns:
        Tensor with the same shape and dtype as ``inp``.
    """
    dtype = inp.dtype
    x = ops.cast(inp, "float32")
    # Root-mean-square over the channel axis, kept for broadcasting.
    rms = ops.sqrt(ops.mean(ops.square(x), axis=-1, keepdims=True) + eps)
    x_normed = x / rms
    # Only the final result is rounded back to the (possibly low-precision)
    # input dtype.
    return ops.cast(x_normed, dtype)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def stream_aggregate(inp, H_pre):
    """Aggregate (n -> 1): weighted sum over the residual streams.

    Computes ``sum_i H_pre[..., i] * inp[..., i, :]`` over the second-to-last
    axis of ``inp``. The computation runs in float32, and the result is cast
    back to ``inp``'s dtype (e.g. bf16).

    Args:
        inp: multi-stream tensor; the second-to-last axis indexes the streams
            and the last axis the channels (presumably ``[B, T, n, C]``).
        H_pre: per-stream weights. They are broadcast against ``inp`` after a
            trailing channel axis is added, so e.g. ``[B, T, n]`` or ``[n]``
            both work.

    Returns:
        ``inp`` with the stream axis reduced away, in ``inp.dtype``.
    """
    weights = ops.expand_dims(ops.cast(H_pre, "float32"), -1)
    weighted = ops.cast(inp, "float32") * weights
    return ops.cast(ops.sum(weighted, axis=-2), inp.dtype)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def stream_distribute(inp, H_post):
    """Distribute (1 -> n): broadcast a single-stream output back to n streams.

    The broadcast multiply runs in float32 and is cast back to ``inp``'s
    dtype only at the end. This is intended to mirror the CUDA kernel's
    float-internal computation. NOTE(review): alignment with
    stream_distribute.cuh is not verifiable from this file.

    Args:
        inp: ``[B, T, C]`` single-stream tensor (typically bf16).
        H_post: ``[B, T, n]`` per-stream weights (typically fp32).

    Returns:
        ``[B, T, n, C]`` tensor ``inp[..., None, :] * H_post[..., None]`` in
        ``inp.dtype``.
    """
    original_dtype = inp.dtype

    row = ops.cast(ops.expand_dims(inp, -2), "float32")         # [B, T, 1, C]
    weight = ops.cast(ops.expand_dims(H_post, -1), "float32")   # [B, T, n, 1]

    spread = row * weight                                       # [B, T, n, C]
    return ops.cast(spread, original_dtype)
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def stream_mix(inp, M):
    """Mix (n -> n): linear interaction between the residual streams.

    Computes ``out[..., i, :] = sum_j M[..., i, j] * inp[..., j, :]``, i.e.
    ``M @ inp``, where i and j index streams and the last axis indexes
    channels.

    The product is always accumulated in float32, whatever ``M``'s dtype is,
    and is cast back to ``inp``'s dtype only at the end. This matches the
    other ops in this module and the CUDA wrapper, whose ``M`` is ``float*``.
    A bf16 ``M`` (for example ``sinkhorn_knopp`` applied to bf16 logits)
    therefore no longer drags the whole mix down to bf16.

    Args:
        inp: ``[B, T, n, C]`` multi-stream tensor.
        M: ``[n, n]`` mixing matrix shared by all tokens, or ``[B, T, n, n]``
            per-token matrices (e.g. produced by ``sinkhorn_knopp``).

    Returns:
        ``[B, T, n, C]`` tensor in ``inp``'s dtype.
    """
    out_dtype = inp.dtype
    inp_f32 = ops.cast(inp, "float32")
    M_f32 = ops.cast(M, "float32")

    if len(ops.shape(M_f32)) == 2:
        # One matrix shared by every (b, t) position.
        out = ops.einsum("ij,btjk->btik", M_f32, inp_f32)
    else:
        # A separate matrix per (b, t) position.
        out = ops.einsum("btij,btjk->btik", M_f32, inp_f32)
    return ops.cast(out, out_dtype)
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def stream_mix_fp32(x_expanded, H_res):
    """Stream mix that computes, and returns, in float32.

    Args:
        x_expanded: ``[B, T, n, C]`` multi-stream tensor.
        H_res: ``[B, T, n, n]`` mixing matrices.

    Returns:
        float32 ``[B, T, n, C]`` tensor ``H_res @ x_expanded``. It is
        deliberately not cast back, so callers can keep accumulating in fp32.
    """
    mixer = ops.cast(H_res, "float32")
    streams = ops.cast(x_expanded, "float32")
    return ops.matmul(mixer, streams)
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def stream_distribute_fp32(layer_out, H_post):
    """Distribute (1 -> n) that computes, and returns, in float32.

    Args:
        layer_out: ``[B, T, C]`` single-stream tensor.
        H_post: ``[B, T, n]`` per-stream weights.

    Returns:
        float32 ``[B, T, n, C]`` tensor ``layer_out[..., None, :] * H_post[..., None]``.
    """
    channels = ops.expand_dims(ops.cast(layer_out, "float32"), -2)   # [B, T, 1, C]
    gains = ops.expand_dims(ops.cast(H_post, "float32"), -1)         # [B, T, n, 1]
    return channels * gains
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def mhc_post_op(layer_out, x_expanded, H_post, H_res):
    """Fused mHC post-processing: ``x_next = H_res @ x_expanded + H_post (x) layer_out``.

    Both paths and their sum are computed in float32. There is a single cast
    back to ``x_expanded``'s dtype at the very end, which stands in for the
    final to-bf16 conversion of the CUDA kernel.

    Args:
        layer_out: ``[B, T, C]`` output of the core layer (Attention/FFN).
        x_expanded: ``[B, T, n, C]`` expanded residual stream as it was before
            the pre-op.
        H_post: ``[B, T, n]`` distribution weights from the pre-op (already
            ``2 * sigmoid``-activated).
        H_res: ``[B, T, n, n]`` stream-mixing matrices from the pre-op.

    Returns:
        ``[B, T, n, C]`` updated expanded residual stream in
        ``x_expanded.dtype``.
    """
    residual_path = stream_mix_fp32(x_expanded, H_res)
    update_path = stream_distribute_fp32(layer_out, H_post)
    return ops.cast(residual_path + update_path, x_expanded.dtype)
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def mhc_pre_op(x_expanded, h_pre_raw, h_post_raw, h_res_raw, num_iters=20, eps=1e-8):
    """Fused mHC pre-processing.

    Steps, all in float32:
      1. ``H_pre = sigmoid(h_pre_raw)``; aggregate the n streams into one
         layer input with those weights.
      2. ``H_post = 2 * sigmoid(h_post_raw)``.
      3. Reshape ``h_res_raw`` to ``[B, T, n, n]`` and project it with
         Sinkhorn-Knopp to get ``H_res``.

    Args:
        x_expanded: ``[B, T, n, C]`` current expanded residual stream. It must
            be rank 4; the shape is unpacked into four values.
        h_pre_raw, h_post_raw: ``[B, T, n]`` raw (pre-activation) projections.
        h_res_raw: ``[B, T, n*n]`` raw values for the Sinkhorn matrix.
        num_iters: Sinkhorn iterations.
        eps: Sinkhorn denominator epsilon.

    Returns:
        Tuple ``(x_layer_in, H_post, H_res)``:
          - ``x_layer_in``: ``[B, T, C]`` in ``x_expanded``'s dtype, the input
            for the Attention/FFN layer.
          - ``H_post``: ``[B, T, n]`` float32 distribution weights.
          - ``H_res``: ``[B, T, n, n]`` float32 mixing matrices.
        ``H_pre`` is computed internally but not returned.
    """
    out_dtype = x_expanded.dtype
    batch, steps, n_streams, _ = ops.shape(x_expanded)

    streams = ops.cast(x_expanded, "float32")

    # 1. Aggregate (n -> 1): [B, T, n, 1] * [B, T, n, C] summed over n -> [B, T, C].
    pre_weights = ops.nn.sigmoid(ops.cast(h_pre_raw, "float32"))
    layer_in = ops.sum(ops.expand_dims(pre_weights, -1) * streams, axis=-2)

    # 2. 2 * sigmoid maps into (0, 2) and equals exactly 1 at h_post_raw == 0.
    post_weights = 2.0 * ops.nn.sigmoid(ops.cast(h_post_raw, "float32"))

    # 3. Flat n*n logits -> per-token n x n matrices -> Sinkhorn projection.
    res_logits = ops.reshape(
        ops.cast(h_res_raw, "float32"),
        (batch, steps, n_streams, n_streams),
    )
    res_matrices = sinkhorn_knopp(res_logits, num_iters=num_iters, eps=eps)

    # Only the activation goes back to the input dtype (like the CUDA kernel's
    # final store); the H weights stay in float32.
    return (
        ops.cast(layer_in, out_dtype),
        ops.cast(post_weights, "float32"),
        ops.cast(res_matrices, "float32"),
    )
|
|
@@ -0,0 +1,207 @@
|
|
|
1
|
+
#include <cuda_runtime.h>
|
|
2
|
+
#include <cuda_bf16.h>
|
|
3
|
+
#include "../common_kernel/include/mhc_types.h"
|
|
4
|
+
#include "../common_kernel/kernels/sinkhorn_knopp.cuh"
|
|
5
|
+
#include "../common_kernel/kernels/rmsnorm.cuh"
|
|
6
|
+
#include "../common_kernel/kernels/stream_mix.cuh"
|
|
7
|
+
#include "../common_kernel/kernels/stream_aggregate.cuh"
|
|
8
|
+
#include "../common_kernel/kernels/stream_distribute.cuh"
|
|
9
|
+
#include "../common_kernel/kernels/mhc_post_op.cuh"
|
|
10
|
+
#include "../common_kernel/kernels/mhc_pre_op.cuh"
|
|
11
|
+
|
|
12
|
+
namespace mhc {
|
|
13
|
+
// --- Fused mHC post-op: forward ---
// Host-side wrapper: reinterprets the nv_bfloat16 buffers as mhc::floatX and
// forwards every argument to the inline mhc::mhc_post_op_forward.
void cuda_mhc_post_op_fwd(
    nv_bfloat16* out,               // [B, T, n, C]
    const nv_bfloat16* layer_out,   // [B, T, C]
    const nv_bfloat16* x_expanded,  // [B, T, n, C]
    const float* H_post,            // [B, T, n]
    const float* H_res,             // [B, T, n, n]
    int64_t B, int64_t T, int n, int64_t C,
    cudaStream_t stream)
{
    auto* out_x = reinterpret_cast<mhc::floatX*>(out);
    const auto* layer_out_x = reinterpret_cast<const mhc::floatX*>(layer_out);
    const auto* x_expanded_x = reinterpret_cast<const mhc::floatX*>(x_expanded);

    mhc::mhc_post_op_forward(out_x, layer_out_x, x_expanded_x,
                             H_post, H_res,
                             B, T, n, C, stream);
}
|
|
34
|
+
|
|
35
|
+
// --- Fused mHC post-op: backward ---
// Host-side wrapper around the fused backward kernel
// mhc::mhc_post_op_backward_full. That kernel is meant to produce all four
// gradients (d_x_expanded, d_layer_out, d_H_post, d_H_res) in one call.
// bf16 buffers are reinterpreted as mhc::floatX; float buffers pass through.
void cuda_mhc_post_op_bwd(
    nv_bfloat16* d_layer_out,       // [B, T, C]
    nv_bfloat16* d_x_expanded,      // [B, T, n, C]
    float* d_H_post,                // [B, T, n]
    float* d_H_res,                 // [B, T, n, n]
    const nv_bfloat16* grad_next,   // [B, T, n, C]
    const nv_bfloat16* layer_out,   // [B, T, C]
    const nv_bfloat16* x_expanded,  // [B, T, n, C]
    const float* H_post,
    const float* H_res,
    int64_t B, int64_t T, int n, int64_t C,
    cudaStream_t stream)
{
    auto* d_layer_out_x = reinterpret_cast<mhc::floatX*>(d_layer_out);
    auto* d_x_expanded_x = reinterpret_cast<mhc::floatX*>(d_x_expanded);
    const auto* grad_next_x = reinterpret_cast<const mhc::floatX*>(grad_next);
    const auto* layer_out_x = reinterpret_cast<const mhc::floatX*>(layer_out);
    const auto* x_expanded_x = reinterpret_cast<const mhc::floatX*>(x_expanded);

    mhc::mhc_post_op_backward_full(
        d_layer_out_x, d_x_expanded_x,
        d_H_post, d_H_res,
        grad_next_x, layer_out_x, x_expanded_x,
        H_post, H_res,
        B, T, n, C, stream);
}
|
|
65
|
+
|
|
66
|
+
// --- Sinkhorn wrappers ---
// Forward: treats `inp` and `out` as B contiguous M x N float matrices and
// calls mhc::sinkhorn_knopp_forward once per matrix, all on `stream`.
void cuda_sinkhorn_fwd(float* out, const float* inp, int64_t B, int64_t M, int64_t N, int iters, float eps, cudaStream_t stream) {
    const int64_t matrix_elems = M * N;
    const int rows = static_cast<int>(M);
    const int cols = static_cast<int>(N);

    for (int64_t b = 0; b < B; ++b) {
        const int64_t offset = b * matrix_elems;
        mhc::sinkhorn_knopp_forward(out + offset, inp + offset, rows, cols, iters, eps, stream);
    }
}
|
|
72
|
+
|
|
73
|
+
// Backward: treats every buffer as B contiguous N x N float matrices and calls
// mhc::sinkhorn_knopp_backward once per matrix. M_out / M_inp are presumably
// the forward output and forward input (TODO confirm against the .cuh).
void cuda_sinkhorn_bwd(float* d_inp, const float* grad, const float* M_out, const float* M_inp, int64_t B, int64_t N, int iters, float eps, cudaStream_t stream) {
    const int64_t matrix_elems = N * N;
    const int dim = static_cast<int>(N);

    for (int64_t b = 0; b < B; ++b) {
        const int64_t offset = b * matrix_elems;
        mhc::sinkhorn_knopp_backward(d_inp + offset,
                                     grad + offset,
                                     M_out + offset,
                                     M_inp + offset,
                                     dim, iters, eps, stream);
    }
}
|
|
78
|
+
|
|
79
|
+
// --- RMSNorm wrappers ---
// Forward: reinterprets the bf16 buffers as mhc::floatX and forwards to
// mhc::rmsnorm_forward (N rows of C channels, as passed through).
void cuda_rmsnorm_fwd(nv_bfloat16* out, const nv_bfloat16* inp, int64_t N, int64_t C, float eps, cudaStream_t stream) {
    auto* out_x = reinterpret_cast<mhc::floatX*>(out);
    const auto* inp_x = reinterpret_cast<const mhc::floatX*>(inp);
    mhc::rmsnorm_forward(out_x, inp_x, N, C, eps, stream);
}
|
|
83
|
+
|
|
84
|
+
// Backward: reinterprets the bf16 buffers as mhc::floatX and forwards to
// mhc::rmsnorm_backward, which writes the input gradient into `dx`.
void cuda_rmsnorm_bwd(nv_bfloat16* dx, const nv_bfloat16* grad, const nv_bfloat16* x, int64_t N, int64_t C, float eps, cudaStream_t stream) {
    auto* dx_x = reinterpret_cast<mhc::floatX*>(dx);
    const auto* grad_x = reinterpret_cast<const mhc::floatX*>(grad);
    const auto* x_x = reinterpret_cast<const mhc::floatX*>(x);
    mhc::rmsnorm_backward(dx_x, grad_x, x_x, N, C, eps, stream);
}
|
|
87
|
+
|
|
88
|
+
// --- Stream Mix wrappers ---
// Forward: bf16 activations are reinterpreted as mhc::floatX. The float
// mixing matrix M passes through unchanged to mhc::stream_mix_forward.
void cuda_stream_mix_fwd(nv_bfloat16* out, const nv_bfloat16* inp, const float* M, int64_t B, int64_t T, int n, int64_t C, cudaStream_t stream) {
    auto* out_x = reinterpret_cast<mhc::floatX*>(out);
    const auto* inp_x = reinterpret_cast<const mhc::floatX*>(inp);
    mhc::stream_mix_forward(out_x, inp_x, M, B, T, n, C, stream);
}
|
|
92
|
+
|
|
93
|
+
// Backward: only the bf16 buffers (d_inp, inp) are reinterpreted as
// mhc::floatX. The upstream gradient `grad` is float here, as are d_M and M.
void cuda_stream_mix_bwd(nv_bfloat16* d_inp, float* d_M, const float* grad, const nv_bfloat16* inp, const float* M, int64_t B, int64_t T, int n, int64_t C, cudaStream_t stream) {
    auto* d_inp_x = reinterpret_cast<mhc::floatX*>(d_inp);
    const auto* inp_x = reinterpret_cast<const mhc::floatX*>(inp);
    mhc::stream_mix_backward(d_inp_x, d_M, grad, inp_x, M, B, T, n, C, stream);
}
|
|
96
|
+
|
|
97
|
+
// --- Stream Aggregate wrappers ---
// Forward: flattens (B, T) into B * T tokens for mhc::stream_aggregate_forward.
// `per_token` is forwarded unchanged.
void cuda_stream_aggregate_fwd(nv_bfloat16* out, const nv_bfloat16* inp, const float* H_pre, int64_t B, int64_t T, int n, int64_t C, bool per_token, cudaStream_t stream) {
    const int64_t num_tokens = B * T;
    auto* out_x = reinterpret_cast<mhc::floatX*>(out);
    const auto* inp_x = reinterpret_cast<const mhc::floatX*>(inp);
    mhc::stream_aggregate_forward(out_x, inp_x, H_pre, num_tokens, n, C, per_token, stream);
}
|
|
101
|
+
|
|
102
|
+
// Backward: same (B, T) -> B * T flattening as the forward. The upstream
// gradient `grad` is float; only d_inp / inp are reinterpreted as mhc::floatX.
void cuda_stream_aggregate_bwd(nv_bfloat16* d_inp, float* d_H_pre, const float* grad, const nv_bfloat16* inp, const float* H_pre, int64_t B, int64_t T, int n, int64_t C, bool per_token, cudaStream_t stream) {
    const int64_t num_tokens = B * T;
    auto* d_inp_x = reinterpret_cast<mhc::floatX*>(d_inp);
    const auto* inp_x = reinterpret_cast<const mhc::floatX*>(inp);
    mhc::stream_aggregate_backward(d_inp_x, d_H_pre, grad, inp_x, H_pre, num_tokens, n, C, per_token, stream);
}
|
|
105
|
+
|
|
106
|
+
// --- Stream Distribute wrappers ---
// Forward (1 -> n): launches mhc::stream_distribute_fwd_kernel on a 2-D grid.
// grid.x covers all B*T*C elements of the single-stream input in blocks of
// 256 threads; grid.y indexes the stream (0..n-1).
void cuda_stream_distribute_fwd(
    nv_bfloat16* out,
    const nv_bfloat16* inp,
    const float* H,
    int64_t B, int64_t T, int n, int64_t C,
    cudaStream_t stream)
{
    constexpr int kThreads = 256;
    const int64_t total_btc = B * T * C;

    // An empty input/output has nothing to compute. Launching with a zero-sized
    // grid would instead fail with cudaErrorInvalidConfiguration, and that error
    // stays pending until some unrelated later cudaGetLastError() reports it.
    if (total_btc <= 0 || n <= 0) {
        return;
    }

    dim3 threads(kThreads);
    dim3 blocks(static_cast<unsigned int>((total_btc + kThreads - 1) / kThreads),
                static_cast<unsigned int>(n));

    mhc::stream_distribute_fwd_kernel<<<blocks, threads, 0, stream>>>(
        reinterpret_cast<mhc::floatX*>(out),
        reinterpret_cast<const mhc::floatX*>(inp),
        H, B, T, n, C
    );
}
|
|
125
|
+
|
|
126
|
+
// Backward (1 -> n): two independent launches on `stream`.
//   1. dx [B, T, C]: one thread per single-stream element.
//   2. dH [B, T, n]: one 256-thread block per (token, stream) pair, reducing
//      over the C channels.
// A launch is skipped only when its own grid would be empty. A zero-sized grid
// fails with cudaErrorInvalidConfiguration, and that error would otherwise be
// reported later by an unrelated error check. The dH launch still happens when
// C == 0, exactly as before.
void cuda_stream_distribute_bwd(
    nv_bfloat16* d_inp,
    float* d_H,
    const nv_bfloat16* grad,
    const nv_bfloat16* inp,
    const float* H,
    int64_t B, int64_t T, int n, int64_t C,
    cudaStream_t stream)
{
    constexpr int kThreads = 256;

    // 1. dx [B, T, C]
    const int64_t total_btc = B * T * C;
    if (total_btc > 0) {
        dim3 threads_dx(kThreads);
        dim3 blocks_dx(static_cast<unsigned int>((total_btc + kThreads - 1) / kThreads));

        mhc::stream_distribute_bwd_dx_kernel<<<blocks_dx, threads_dx, 0, stream>>>(
            reinterpret_cast<mhc::floatX*>(d_inp),
            reinterpret_cast<const mhc::floatX*>(grad),
            H, B, T, n, C
        );
    }

    // 2. dH [B, T, n]
    const int64_t num_tokens = B * T;
    if (num_tokens > 0 && n > 0) {
        dim3 threads_dh(kThreads);
        dim3 blocks_dh(static_cast<unsigned int>(num_tokens), static_cast<unsigned int>(n));

        mhc::stream_distribute_bwd_dh_kernel<kThreads><<<blocks_dh, threads_dh, 0, stream>>>(
            d_H,
            reinterpret_cast<const mhc::floatX*>(grad),
            reinterpret_cast<const mhc::floatX*>(inp),
            B, T, n, C
        );
    }
}
|
|
158
|
+
// --- Fused mHC pre-op: forward ---
// Host-side wrapper around the inline mhc::mhc_pre_op_forward. Only the bf16
// activation buffers are reinterpreted as mhc::floatX. The float H_* outputs
// and raw h_* inputs pass through unchanged, together with the Sinkhorn
// iteration count and eps.
void cuda_mhc_pre_op_fwd(
    nv_bfloat16* x_layer_in,
    float* H_pre,
    float* H_post,
    float* H_res,
    const nv_bfloat16* x_expanded,
    const float* h_pre_raw,
    const float* h_post_raw,
    const float* h_res_raw,
    int64_t B, int64_t T, int n, int64_t C,
    int sinkhorn_iters, float eps, cudaStream_t stream)
{
    auto* layer_in_x = reinterpret_cast<mhc::floatX*>(x_layer_in);
    const auto* x_expanded_x = reinterpret_cast<const mhc::floatX*>(x_expanded);

    mhc::mhc_pre_op_forward(
        layer_in_x,
        H_pre, H_post, H_res,
        x_expanded_x,
        h_pre_raw, h_post_raw, h_res_raw,
        B, T, n, C,
        sinkhorn_iters, eps, stream);
}
|
|
179
|
+
|
|
180
|
+
// --- Fused mHC pre-op: backward ---
// Host-side wrapper around the inline mhc::mhc_pre_op_backward. The bf16
// buffers (d_x_expanded, grad_layer_in, x_expanded) are reinterpreted as
// mhc::floatX; all float buffers pass through unchanged. H_res_out and
// H_res_in_raw are presumably the Sinkhorn output and its raw input saved
// from the forward pass (TODO confirm against mhc_pre_op.cuh).
void cuda_mhc_pre_op_bwd(
    nv_bfloat16* d_x_expanded,
    float* d_h_pre_raw,
    float* d_h_post_raw,
    float* d_h_res_raw,
    const nv_bfloat16* grad_layer_in,
    const float* grad_H_post,
    const float* grad_H_res,
    const nv_bfloat16* x_expanded,
    const float* H_pre,
    const float* H_post,
    const float* H_res_out,
    const float* H_res_in_raw,
    int64_t B, int64_t T, int n, int64_t C,
    int sinkhorn_iters, float eps, cudaStream_t stream)
{
    auto* d_x_expanded_x = reinterpret_cast<mhc::floatX*>(d_x_expanded);
    const auto* grad_layer_in_x = reinterpret_cast<const mhc::floatX*>(grad_layer_in);
    const auto* x_expanded_x = reinterpret_cast<const mhc::floatX*>(x_expanded);

    mhc::mhc_pre_op_backward(
        d_x_expanded_x,
        d_h_pre_raw, d_h_post_raw, d_h_res_raw,
        grad_layer_in_x,
        grad_H_post, grad_H_res,
        x_expanded_x,
        H_pre, H_post, H_res_out, H_res_in_raw,
        B, T, n, C,
        sinkhorn_iters, eps, stream);
}
|
|
206
|
+
|
|
207
|
+
} // namespace mhc
|