rwkv_ops-0.6.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. rwkv_ops/__init__.py +45 -0
  2. rwkv_ops/mhc_kernel/__init__.py +50 -0
  3. rwkv_ops/mhc_kernel/common_kernel/include/mhc_types.h +66 -0
  4. rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_post_op.cuh +197 -0
  5. rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_pre_op.cuh +212 -0
  6. rwkv_ops/mhc_kernel/common_kernel/kernels/rmsnorm.cuh +152 -0
  7. rwkv_ops/mhc_kernel/common_kernel/kernels/sinkhorn_knopp.cuh +158 -0
  8. rwkv_ops/mhc_kernel/common_kernel/kernels/stream_aggregate.cuh +141 -0
  9. rwkv_ops/mhc_kernel/common_kernel/kernels/stream_distribute.cuh +111 -0
  10. rwkv_ops/mhc_kernel/common_kernel/kernels/stream_mix.cuh +164 -0
  11. rwkv_ops/mhc_kernel/common_kernel/kernels/type_conversions.cuh +52 -0
  12. rwkv_ops/mhc_kernel/jax_kernel/CMakeLists.txt +47 -0
  13. rwkv_ops/mhc_kernel/jax_kernel/mhu_ffi.cu +652 -0
  14. rwkv_ops/mhc_kernel/jax_kernel/mhu_jax.py +939 -0
  15. rwkv_ops/mhc_kernel/native_keras_op.py +193 -0
  16. rwkv_ops/mhc_kernel/torch_kernel/mhc_cuda.cu +207 -0
  17. rwkv_ops/mhc_kernel/torch_kernel/mhc_op.cpp +296 -0
  18. rwkv_ops/mhc_kernel/torch_kernel/mhc_torch.py +306 -0
  19. rwkv_ops/rwkv6_kernel/__init__.py +120 -0
  20. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/gpu_ops.cpp +44 -0
  21. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernel_helpers.h +64 -0
  22. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernels.h +56 -0
  23. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/pybind11_kernel_helpers.h +41 -0
  24. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/rwkv_kernels.cu +512 -0
  25. rwkv_ops/rwkv6_kernel/jax_kernel_hip/gpu_ops.cpp +44 -0
  26. rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernel_helpers.h +64 -0
  27. rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernels.h +56 -0
  28. rwkv_ops/rwkv6_kernel/jax_kernel_hip/pybind11_kernel_helpers.h +41 -0
  29. rwkv_ops/rwkv6_kernel/jax_kernel_hip/rwkv_kernels.hip +514 -0
  30. rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py +722 -0
  31. rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py +90 -0
  32. rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_cuda.cu +397 -0
  33. rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_op.cpp +93 -0
  34. rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py +305 -0
  35. rwkv_ops/rwkv7_kernel/__init__.py +113 -0
  36. rwkv_ops/rwkv7_kernel/get_jax_devices_info.py +220 -0
  37. rwkv_ops/rwkv7_kernel/get_torch_devices_info.py +250 -0
  38. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/CMakeLists.txt +42 -0
  39. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_ffi.cu +399 -0
  40. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_jax.py +311 -0
  41. rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/CMakeLists.txt +42 -0
  42. rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_ffi.cu +172 -0
  43. rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_jax.py +190 -0
  44. rwkv_ops/rwkv7_kernel/jax_kernel/__init__.py +9 -0
  45. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_bwd.py +95 -0
  46. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_fwd.py +60 -0
  47. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_bwd.py +78 -0
  48. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_fwd.py +80 -0
  49. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_bwd.py +150 -0
  50. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_fwd.py +45 -0
  51. rwkv_ops/rwkv7_kernel/jax_kernel/cumsum.py +34 -0
  52. rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_bwd.py +61 -0
  53. rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_fwd.py +86 -0
  54. rwkv_ops/rwkv7_kernel/jax_op.py +382 -0
  55. rwkv_ops/rwkv7_kernel/mlx_op.py +118 -0
  56. rwkv_ops/rwkv7_kernel/native_keras_op.py +108 -0
  57. rwkv_ops/rwkv7_kernel/tf_eager_kernel.py +155 -0
  58. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_cuda.cu +235 -0
  59. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_op.cpp +63 -0
  60. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_torch.py +233 -0
  61. rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_cuda.cu +101 -0
  62. rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_op.cpp +56 -0
  63. rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_torch.py +112 -0
  64. rwkv_ops/rwkv7_kernel/torch_kernel/__init__.py +13 -0
  65. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_bwd.py +96 -0
  66. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_fwd.py +64 -0
  67. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_bwd.py +74 -0
  68. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_fwd.py +75 -0
  69. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_bwd.py +148 -0
  70. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_fwd.py +44 -0
  71. rwkv_ops/rwkv7_kernel/torch_kernel/cumsum.py +31 -0
  72. rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_bwd.py +63 -0
  73. rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_fwd.py +79 -0
  74. rwkv_ops/rwkv7_kernel/torch_op.py +504 -0
  75. rwkv_ops/rwkv7_kernel/triton_kernel/__init__.py +34 -0
  76. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_bwd.py +328 -0
  77. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_fwd.py +186 -0
  78. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_bwd.py +157 -0
  79. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_fwd.py +160 -0
  80. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_bwd.py +382 -0
  81. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_fwd.py +137 -0
  82. rwkv_ops/rwkv7_kernel/triton_kernel/cumsum.py +86 -0
  83. rwkv_ops/rwkv7_kernel/triton_kernel/utils.py +20 -0
  84. rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_bwd.py +193 -0
  85. rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_fwd.py +326 -0
  86. rwkv_ops-0.6.1.dist-info/METADATA +495 -0
  87. rwkv_ops-0.6.1.dist-info/RECORD +89 -0
  88. rwkv_ops-0.6.1.dist-info/WHEEL +4 -0
  89. rwkv_ops-0.6.1.dist-info/licenses/LICENSE.txt +201 -0
rwkv_ops/mhc_kernel/native_keras_op.py
@@ -0,0 +1,193 @@
+ from keras import ops
+
+ # --- Helper: ensure computation runs in fp32 for numerical stability ---
+
+
+ def fp32_sigmoid(x):
+     dtype = x.dtype
+     return ops.cast(ops.nn.sigmoid(ops.cast(x, "float32")), dtype)
+
+
+ # --- Core MHC operators ---
+
+
+ def sinkhorn_knopp(inp, num_iters=20, eps=1e-8):
+     """
+     Project the input matrix onto a doubly stochastic matrix.
+     `inp` is typically a matrix in the log domain (H_res_raw).
+     """
+     dtype = inp.dtype
+     # Cast to fp32 and apply exp (the step preceding Eq. 9 in the paper)
+     x = ops.cast(inp, "float32")
+     # Overflow guard: subtract the maximum
+     x = x - ops.max(x, axis=(-1, -2), keepdims=True)
+     P = ops.exp(x)
+
+     for _ in range(num_iters):
+         # Row normalization
+         P = P / (ops.sum(P, axis=-1, keepdims=True) + eps)
+         # Column normalization
+         P = P / (ops.sum(P, axis=-2, keepdims=True) + eps)
+
+     return ops.cast(P, dtype)
+
+
+ def rmsnorm(inp, eps=1e-5):
+     """
+     Standard RMSNorm operator (no learned weight).
+     inp: [..., C]
+     """
+     dtype = inp.dtype
+     x = ops.cast(inp, "float32")
+     # Compute the root mean square
+     rms = ops.sqrt(ops.mean(ops.square(x), axis=-1, keepdims=True) + eps)
+     x_normed = x / rms
+     # No learned weight is applied here
+     return ops.cast(x_normed, dtype)
+
+
+ def stream_aggregate(inp, H_pre):
+     # 1. Cast to float32 for high-precision computation
+     inp_f32 = ops.cast(inp, "float32")
+     H_f32 = ops.cast(H_pre, "float32")
+
+     # 2. Multiply and accumulate in float32
+     out_f32 = ops.sum(inp_f32 * ops.expand_dims(H_f32, -1), axis=-2)
+
+     # 3. Cast back to the original dtype (e.g. bf16)
+     return ops.cast(out_f32, inp.dtype)
+
+
+ def stream_distribute(inp, H_post):
+     """
+     Distribute (1 -> n): fan the single-stream output back out to multiple streams.
+     Precision-matched version: the broadcast multiply is forced into FP32.
+
+     inp: [B, T, C] (BF16)
+     H_post: [B, T, n] (FP32)
+     """
+     # 1. Record the original dtype
+     original_dtype = inp.dtype
+
+     # 2. Promote to FP32 (matches the to_float logic inside the CUDA kernel)
+     # [B, T, 1, C]
+     x_fp32 = ops.cast(ops.expand_dims(inp, -2), "float32")
+
+     # [B, T, n, 1]
+     w_fp32 = ops.cast(ops.expand_dims(H_post, -1), "float32")
+
+     # 3. Broadcast multiply
+     # Result: [B, T, n, C]
+     res_fp32 = x_fp32 * w_fp32
+
+     # 4. Cast back to the original dtype (matches the to_bf logic at the end of the CUDA kernel)
+     return ops.cast(res_fp32, original_dtype)
+
+
+ def stream_mix(inp, M):
+     """
+     Mix (n -> n): linear interaction between the residual streams.
+     inp: [B, T, n, C]
+     M: [B, T, n, n] or [n, n] (the square matrix produced by sinkhorn_knopp)
+     """
+     # Express the matrix product M @ inp with einsum:
+     # i, j are stream indices, k is the channel index
+     dtype = inp.dtype
+     inp = ops.cast(inp, M.dtype)
+     if len(ops.shape(M)) == 2:
+         out = ops.einsum("ij,btjk->btik", M, inp)
+     else:
+         out = ops.einsum("btij,btjk->btik", M, inp)
+     return ops.cast(out, dtype)
+
+
+ def stream_mix_fp32(x_expanded, H_res):
+     """Stream mix with the computation forced into FP32 internally."""
+     # x_expanded: [B, T, n, C], H_res: [B, T, n, n]
+     x_f32 = ops.cast(x_expanded, "float32")
+     h_f32 = ops.cast(H_res, "float32")
+     # Matrix multiply: [B, T, n, n] @ [B, T, n, C] -> [B, T, n, C]
+     return ops.matmul(h_f32, x_f32)
+
+
+ def stream_distribute_fp32(layer_out, H_post):
+     """Distribute with the computation forced into FP32 internally."""
+     # layer_out: [B, T, C], H_post: [B, T, n]
+     l_f32 = ops.cast(layer_out, "float32")
+     h_f32 = ops.cast(H_post, "float32")
+
+     # [B, T, 1, C] * [B, T, n, 1] -> [B, T, n, C]
+     return ops.expand_dims(l_f32, -2) * ops.expand_dims(h_f32, -1)
+
+
+ def mhc_post_op(layer_out, x_expanded, H_post, H_res):
+     """
+     mHC fused post-op.
+     Inputs:
+         layer_out: [B, T, C] - output of the core layer (Attention/FFN)
+         x_expanded: [B, T, n, C] - the expanded residual streams (state before the pre-op)
+         H_post: [B, T, n] - distribution weights (from the pre-op), already passed through 2*sigmoid
+         H_res: [B, T, n, n] - stream-mixing matrix (from the pre-op)
+     Returns:
+         x_next: [B, T, n, C] - the updated expanded residual streams
+     """
+     # 1. Compute the mixing path in FP32
+     x_mixed_f32 = stream_mix_fp32(x_expanded, H_res)
+
+     # 2. Compute the delta path in FP32
+     x_delta_f32 = stream_distribute_fp32(layer_out, H_post)
+
+     # 3. Final residual addition, still in FP32
+     x_next_f32 = x_mixed_f32 + x_delta_f32
+
+     # 4. A single BF16 cast at the very end,
+     #    matching the final to_bf() in the CUDA kernel
+     return ops.cast(x_next_f32, x_expanded.dtype)
+
+
+ def mhc_pre_op(x_expanded, h_pre_raw, h_post_raw, h_res_raw, num_iters=20, eps=1e-8):
+     """
+     mHC fused pre-op.
+     Inputs:
+         x_expanded: [B, T, n, C] - the current expanded residual streams
+         h_pre_raw, h_post_raw: [B, T, n] - raw activations after the linear projection
+         h_res_raw: [B, T, n*n] - raw values used to build the Sinkhorn matrix
+     Returns:
+         x_layer_in: [B, T, C] - aggregated input ready to enter the layer (Attention/FFN)
+         H_post: [B, T, n] - activated distribution weights
+         H_res: [B, T, n, n] - mixing matrix after the manifold constraint
+     """
+     original_dtype = x_expanded.dtype
+     B, T, n, C = ops.shape(x_expanded)
+
+     # --- 0. Promote precision ---
+     x_exp_f32 = ops.cast(x_expanded, "float32")
+     h_pre_f32 = ops.cast(h_pre_raw, "float32")
+     h_post_f32 = ops.cast(h_post_raw, "float32")
+     h_res_f32 = ops.cast(h_res_raw, "float32")
+
+     # --- 1. Stream Aggregate (n -> 1) ---
+     # H_pre = sigmoid(h_pre)
+     H_pre_f32 = ops.nn.sigmoid(h_pre_f32)
+
+     # x_layer_in = sum(H_pre_i * x_expanded_i)
+     # expand_dims ensures correct broadcasting:
+     # [B, T, n, 1] * [B, T, n, C] -> [B, T, n, C] -> sum -> [B, T, C]
+     x_layer_in_f32 = ops.sum(ops.expand_dims(H_pre_f32, -1) * x_exp_f32, axis=-2)
+
+     # --- 2. Compute H_post (with the 2.0 scale) ---
+     # Per the paper, 2.0 * sigmoid keeps the identity-mapping initialization stable
+     H_post_f32 = 2.0 * ops.nn.sigmoid(h_post_f32)
+
+     # --- 3. Sinkhorn-Knopp (n x n projection) ---
+     h_res_reshaped_f32 = ops.reshape(h_res_f32, (B, T, n, n))
+
+     # Run the Sinkhorn iterations (sinkhorn_knopp itself also computes in float32)
+     H_res_f32 = sinkhorn_knopp(h_res_reshaped_f32, num_iters=num_iters, eps=eps)
+
+     # --- 4. Cast back to the original dtype ---
+     # Mirrors the cast the CUDA kernel performs when writing back to device memory
+     return (
+         ops.cast(x_layer_in_f32, original_dtype),
+         ops.cast(H_post_f32, "float32"),  # H weights typically stay in FP32 in the model
+         ops.cast(H_res_f32, "float32"),
+     )
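
Taken together, mhc_pre_op and mhc_post_op above implement one full mHC residual update around a core layer: aggregate the n streams into one, run the layer on the single stream, then mix the old streams and add back the distributed layer output. A minimal composition sketch, assuming a stand-in layer_fn and hypothetical tensor sizes; none of the names below beyond the two ops are part of the package:

import numpy as np
from keras import ops
# mhc_pre_op / mhc_post_op as defined in native_keras_op.py above

B, T, n, C = 2, 4, 3, 8  # hypothetical sizes
rng = np.random.default_rng(0)
x_expanded = ops.convert_to_tensor(rng.standard_normal((B, T, n, C)).astype("float32"))
h_pre_raw = ops.convert_to_tensor(rng.standard_normal((B, T, n)).astype("float32"))
h_post_raw = ops.convert_to_tensor(rng.standard_normal((B, T, n)).astype("float32"))
h_res_raw = ops.convert_to_tensor(rng.standard_normal((B, T, n * n)).astype("float32"))

def layer_fn(x):
    # Stand-in for the real Attention/FFN block: any [B, T, C] -> [B, T, C] map
    return ops.tanh(x)

# Pre-op: aggregate the n streams and build the distribution/mixing weights
x_layer_in, H_post, H_res = mhc_pre_op(x_expanded, h_pre_raw, h_post_raw, h_res_raw)
# Core layer runs on the single aggregated stream
layer_out = layer_fn(x_layer_in)
# Post-op: mix the old streams and add back the distributed layer output
x_next = mhc_post_op(layer_out, x_expanded, H_post, H_res)  # [B, T, n, C]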
rwkv_ops/mhc_kernel/torch_kernel/mhc_cuda.cu
@@ -0,0 +1,207 @@
+ #include <cuda_runtime.h>
+ #include <cuda_bf16.h>
+ #include "../common_kernel/include/mhc_types.h"
+ #include "../common_kernel/kernels/sinkhorn_knopp.cuh"
+ #include "../common_kernel/kernels/rmsnorm.cuh"
+ #include "../common_kernel/kernels/stream_mix.cuh"
+ #include "../common_kernel/kernels/stream_aggregate.cuh"
+ #include "../common_kernel/kernels/stream_distribute.cuh"
+ #include "../common_kernel/kernels/mhc_post_op.cuh"
+ #include "../common_kernel/kernels/mhc_pre_op.cuh"
+
+ namespace mhc {
+ // --- Fused post-op forward ---
+ void cuda_mhc_post_op_fwd(
+     nv_bfloat16* out,               // [B, T, n, C]
+     const nv_bfloat16* layer_out,   // [B, T, C]
+     const nv_bfloat16* x_expanded,  // [B, T, n, C]
+     const float* H_post,            // [B, T, n]
+     const float* H_res,             // [B, T, n, n]
+     int64_t B, int64_t T, int n, int64_t C,
+     cudaStream_t stream)
+ {
+     // Call the inline wrapper from the .cuh header
+     mhc::mhc_post_op_forward(
+         reinterpret_cast<mhc::floatX*>(out),
+         reinterpret_cast<const mhc::floatX*>(layer_out),
+         reinterpret_cast<const mhc::floatX*>(x_expanded),
+         H_post,
+         H_res,
+         B, T, n, C,
+         stream
+     );
+ }
+
+ // --- Backward wrapper (fused version) ---
+ void cuda_mhc_post_op_bwd(
+     nv_bfloat16* d_layer_out,       // [B, T, C]
+     nv_bfloat16* d_x_expanded,      // [B, T, n, C]
+     float* d_H_post,                // [B, T, n]
+     float* d_H_res,                 // [B, T, n, n]
+     const nv_bfloat16* grad_next,   // [B, T, n, C]
+     const nv_bfloat16* layer_out,   // [B, T, C]
+     const nv_bfloat16* x_expanded,  // [B, T, n, C]
+     const float* H_post,
+     const float* H_res,
+     int64_t B, int64_t T, int n, int64_t C,
+     cudaStream_t stream)
+ {
+     // Call the fully fused backward kernel from the .cuh header;
+     // it handles all of the dx, dl, dH_post and dH_res logic internally
+     mhc::mhc_post_op_backward_full(
+         reinterpret_cast<mhc::floatX*>(d_layer_out),
+         reinterpret_cast<mhc::floatX*>(d_x_expanded),
+         d_H_post,
+         d_H_res,
+         reinterpret_cast<const mhc::floatX*>(grad_next),
+         reinterpret_cast<const mhc::floatX*>(layer_out),
+         reinterpret_cast<const mhc::floatX*>(x_expanded),
+         H_post,
+         H_res,
+         B, T, n, C,
+         stream
+     );
+ }
+
+ // --- Sinkhorn wrappers ---
+ void cuda_sinkhorn_fwd(float* out, const float* inp, int64_t B, int64_t M, int64_t N, int iters, float eps, cudaStream_t stream) {
+     for (int64_t b = 0; b < B; b++) {
+         mhc::sinkhorn_knopp_forward(out + b * M * N, inp + b * M * N, (int)M, (int)N, iters, eps, stream);
+     }
+ }
+
+ void cuda_sinkhorn_bwd(float* d_inp, const float* grad, const float* M_out, const float* M_inp, int64_t B, int64_t N, int iters, float eps, cudaStream_t stream) {
+     for (int64_t b = 0; b < B; b++) {
+         mhc::sinkhorn_knopp_backward(d_inp + b * N * N, grad + b * N * N, M_out + b * N * N, M_inp + b * N * N, (int)N, iters, eps, stream);
+     }
+ }
+
+ // --- RMSNorm wrappers ---
+ void cuda_rmsnorm_fwd(nv_bfloat16* out, const nv_bfloat16* inp, int64_t N, int64_t C, float eps, cudaStream_t stream) {
+     mhc::rmsnorm_forward(reinterpret_cast<mhc::floatX*>(out), reinterpret_cast<const mhc::floatX*>(inp), N, C, eps, stream);
+ }
+
+ void cuda_rmsnorm_bwd(nv_bfloat16* dx, const nv_bfloat16* grad, const nv_bfloat16* x, int64_t N, int64_t C, float eps, cudaStream_t stream) {
+     mhc::rmsnorm_backward(reinterpret_cast<mhc::floatX*>(dx), reinterpret_cast<const mhc::floatX*>(grad), reinterpret_cast<const mhc::floatX*>(x), N, C, eps, stream);
+ }
+
+ // --- Stream mix wrappers ---
+ void cuda_stream_mix_fwd(nv_bfloat16* out, const nv_bfloat16* inp, const float* M, int64_t B, int64_t T, int n, int64_t C, cudaStream_t stream) {
+     mhc::stream_mix_forward(reinterpret_cast<mhc::floatX*>(out), reinterpret_cast<const mhc::floatX*>(inp), M, B, T, n, C, stream);
+ }
+
+ void cuda_stream_mix_bwd(nv_bfloat16* d_inp, float* d_M, const float* grad, const nv_bfloat16* inp, const float* M, int64_t B, int64_t T, int n, int64_t C, cudaStream_t stream) {
+     mhc::stream_mix_backward(reinterpret_cast<mhc::floatX*>(d_inp), d_M, grad, reinterpret_cast<const mhc::floatX*>(inp), M, B, T, n, C, stream);
+ }
+
+ // --- New: stream aggregate wrappers ---
+ void cuda_stream_aggregate_fwd(nv_bfloat16* out, const nv_bfloat16* inp, const float* H_pre, int64_t B, int64_t T, int n, int64_t C, bool per_token, cudaStream_t stream) {
+     mhc::stream_aggregate_forward(reinterpret_cast<mhc::floatX*>(out), reinterpret_cast<const mhc::floatX*>(inp), H_pre, B * T, n, C, per_token, stream);
+ }
+
+ void cuda_stream_aggregate_bwd(nv_bfloat16* d_inp, float* d_H_pre, const float* grad, const nv_bfloat16* inp, const float* H_pre, int64_t B, int64_t T, int n, int64_t C, bool per_token, cudaStream_t stream) {
+     mhc::stream_aggregate_backward(reinterpret_cast<mhc::floatX*>(d_inp), d_H_pre, grad, reinterpret_cast<const mhc::floatX*>(inp), H_pre, B * T, n, C, per_token, stream);
+ }
+
+ void cuda_stream_distribute_fwd(
+     nv_bfloat16* out,
+     const nv_bfloat16* inp,
+     const float* H,
+     int64_t B, int64_t T, int n, int64_t C,
+     cudaStream_t stream)
+ {
+     // Total element count for the single-stream input (B*T*C)
+     int64_t total_btc = B * T * C;
+     dim3 threads(256);
+     // The x axis covers all elements, the y axis covers the stream index n
+     dim3 blocks((total_btc + 255) / 256, (unsigned int)n);
+
+     mhc::stream_distribute_fwd_kernel<<<blocks, threads, 0, stream>>>(
+         reinterpret_cast<mhc::floatX*>(out),
+         reinterpret_cast<const mhc::floatX*>(inp),
+         H, B, T, n, C
+     );
+ }
+
+ void cuda_stream_distribute_bwd(
+     nv_bfloat16* d_inp,
+     float* d_H,
+     const nv_bfloat16* grad,
+     const nv_bfloat16* inp,
+     const float* H,
+     int64_t B, int64_t T, int n, int64_t C,
+     cudaStream_t stream)
+ {
+     // 1. Compute dx [B, T, C]
+     int64_t total_btc = B * T * C;
+     dim3 threads_dx(256);
+     dim3 blocks_dx((total_btc + 255) / 256);
+
+     mhc::stream_distribute_bwd_dx_kernel<<<blocks_dx, threads_dx, 0, stream>>>(
+         reinterpret_cast<mhc::floatX*>(d_inp),
+         reinterpret_cast<const mhc::floatX*>(grad),
+         H, B, T, n, C
+     );
+
+     // 2. Compute dH [B, T, n]
+     // One block per (bt, i) pair performs the reduction over the channel dimension C
+     dim3 threads_dh(256);
+     dim3 blocks_dh((unsigned int)(B * T), (unsigned int)n);
+
+     mhc::stream_distribute_bwd_dh_kernel<256><<<blocks_dh, threads_dh, 0, stream>>>(
+         d_H,
+         reinterpret_cast<const mhc::floatX*>(grad),
+         reinterpret_cast<const mhc::floatX*>(inp),
+         B, T, n, C
+     );
+ }
+
+ void cuda_mhc_pre_op_fwd(
+     nv_bfloat16* x_layer_in,
+     float* H_pre,
+     float* H_post,
+     float* H_res,
+     const nv_bfloat16* x_expanded,
+     const float* h_pre_raw,
+     const float* h_post_raw,
+     const float* h_res_raw,
+     int64_t B, int64_t T, int n, int64_t C,
+     int sinkhorn_iters, float eps, cudaStream_t stream)
+ {
+     // Call the fused interface from the .cuh header
+     mhc::mhc_pre_op_forward(
+         reinterpret_cast<mhc::floatX*>(x_layer_in),
+         H_pre, H_post, H_res,
+         reinterpret_cast<const mhc::floatX*>(x_expanded),
+         h_pre_raw, h_post_raw, h_res_raw,
+         B, T, n, C, sinkhorn_iters, eps, stream
+     );
+ }
+
+ void cuda_mhc_pre_op_bwd(
+     nv_bfloat16* d_x_expanded,
+     float* d_h_pre_raw,
+     float* d_h_post_raw,
+     float* d_h_res_raw,
+     const nv_bfloat16* grad_layer_in,
+     const float* grad_H_post,
+     const float* grad_H_res,
+     const nv_bfloat16* x_expanded,
+     const float* H_pre,
+     const float* H_post,
+     const float* H_res_out,
+     const float* H_res_in_raw,
+     int64_t B, int64_t T, int n, int64_t C,
+     int sinkhorn_iters, float eps, cudaStream_t stream)
+ {
+     mhc::mhc_pre_op_backward(
+         reinterpret_cast<mhc::floatX*>(d_x_expanded),
+         d_h_pre_raw, d_h_post_raw, d_h_res_raw,
+         reinterpret_cast<const mhc::floatX*>(grad_layer_in),
+         grad_H_post, grad_H_res,
+         reinterpret_cast<const mhc::floatX*>(x_expanded),
+         H_pre, H_post, H_res_out, H_res_in_raw,
+         B, T, n, C, sinkhorn_iters, eps, stream
+     );
+ }
+
+ } // namespace mhc
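
A quick way to sanity-check the property that cuda_sinkhorn_fwd is expected to enforce is to verify the doubly stochastic constraint on the Keras reference sinkhorn_knopp from native_keras_op.py above, which the CUDA wrapper is meant to match. A minimal sketch of that check, assuming NumPy inputs, a hypothetical [B, T, n, n] shape, and illustrative tolerances:

import numpy as np
from keras import ops
# sinkhorn_knopp as defined in native_keras_op.py above

n = 4
raw = ops.convert_to_tensor(
    np.random.default_rng(0).standard_normal((2, 3, n, n)).astype("float32")
)  # hypothetical [B, T, n, n] logits
P = ops.convert_to_numpy(sinkhorn_knopp(raw, num_iters=20))

# After enough iterations, every row and every column should sum to ~1
np.testing.assert_allclose(P.sum(axis=-1), 1.0, atol=1e-3)
np.testing.assert_allclose(P.sum(axis=-2), 1.0, atol=1e-3)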