rwkv_ops-0.6.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. rwkv_ops/__init__.py +45 -0
  2. rwkv_ops/mhc_kernel/__init__.py +50 -0
  3. rwkv_ops/mhc_kernel/common_kernel/include/mhc_types.h +66 -0
  4. rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_post_op.cuh +197 -0
  5. rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_pre_op.cuh +212 -0
  6. rwkv_ops/mhc_kernel/common_kernel/kernels/rmsnorm.cuh +152 -0
  7. rwkv_ops/mhc_kernel/common_kernel/kernels/sinkhorn_knopp.cuh +158 -0
  8. rwkv_ops/mhc_kernel/common_kernel/kernels/stream_aggregate.cuh +141 -0
  9. rwkv_ops/mhc_kernel/common_kernel/kernels/stream_distribute.cuh +111 -0
  10. rwkv_ops/mhc_kernel/common_kernel/kernels/stream_mix.cuh +164 -0
  11. rwkv_ops/mhc_kernel/common_kernel/kernels/type_conversions.cuh +52 -0
  12. rwkv_ops/mhc_kernel/jax_kernel/CMakeLists.txt +47 -0
  13. rwkv_ops/mhc_kernel/jax_kernel/mhu_ffi.cu +652 -0
  14. rwkv_ops/mhc_kernel/jax_kernel/mhu_jax.py +939 -0
  15. rwkv_ops/mhc_kernel/native_keras_op.py +193 -0
  16. rwkv_ops/mhc_kernel/torch_kernel/mhc_cuda.cu +207 -0
  17. rwkv_ops/mhc_kernel/torch_kernel/mhc_op.cpp +296 -0
  18. rwkv_ops/mhc_kernel/torch_kernel/mhc_torch.py +306 -0
  19. rwkv_ops/rwkv6_kernel/__init__.py +120 -0
  20. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/gpu_ops.cpp +44 -0
  21. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernel_helpers.h +64 -0
  22. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernels.h +56 -0
  23. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/pybind11_kernel_helpers.h +41 -0
  24. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/rwkv_kernels.cu +512 -0
  25. rwkv_ops/rwkv6_kernel/jax_kernel_hip/gpu_ops.cpp +44 -0
  26. rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernel_helpers.h +64 -0
  27. rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernels.h +56 -0
  28. rwkv_ops/rwkv6_kernel/jax_kernel_hip/pybind11_kernel_helpers.h +41 -0
  29. rwkv_ops/rwkv6_kernel/jax_kernel_hip/rwkv_kernels.hip +514 -0
  30. rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py +722 -0
  31. rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py +90 -0
  32. rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_cuda.cu +397 -0
  33. rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_op.cpp +93 -0
  34. rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py +305 -0
  35. rwkv_ops/rwkv7_kernel/__init__.py +113 -0
  36. rwkv_ops/rwkv7_kernel/get_jax_devices_info.py +220 -0
  37. rwkv_ops/rwkv7_kernel/get_torch_devices_info.py +250 -0
  38. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/CMakeLists.txt +42 -0
  39. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_ffi.cu +399 -0
  40. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_jax.py +311 -0
  41. rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/CMakeLists.txt +42 -0
  42. rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_ffi.cu +172 -0
  43. rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_jax.py +190 -0
  44. rwkv_ops/rwkv7_kernel/jax_kernel/__init__.py +9 -0
  45. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_bwd.py +95 -0
  46. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_fwd.py +60 -0
  47. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_bwd.py +78 -0
  48. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_fwd.py +80 -0
  49. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_bwd.py +150 -0
  50. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_fwd.py +45 -0
  51. rwkv_ops/rwkv7_kernel/jax_kernel/cumsum.py +34 -0
  52. rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_bwd.py +61 -0
  53. rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_fwd.py +86 -0
  54. rwkv_ops/rwkv7_kernel/jax_op.py +382 -0
  55. rwkv_ops/rwkv7_kernel/mlx_op.py +118 -0
  56. rwkv_ops/rwkv7_kernel/native_keras_op.py +108 -0
  57. rwkv_ops/rwkv7_kernel/tf_eager_kernel.py +155 -0
  58. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_cuda.cu +235 -0
  59. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_op.cpp +63 -0
  60. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_torch.py +233 -0
  61. rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_cuda.cu +101 -0
  62. rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_op.cpp +56 -0
  63. rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_torch.py +112 -0
  64. rwkv_ops/rwkv7_kernel/torch_kernel/__init__.py +13 -0
  65. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_bwd.py +96 -0
  66. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_fwd.py +64 -0
  67. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_bwd.py +74 -0
  68. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_fwd.py +75 -0
  69. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_bwd.py +148 -0
  70. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_fwd.py +44 -0
  71. rwkv_ops/rwkv7_kernel/torch_kernel/cumsum.py +31 -0
  72. rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_bwd.py +63 -0
  73. rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_fwd.py +79 -0
  74. rwkv_ops/rwkv7_kernel/torch_op.py +504 -0
  75. rwkv_ops/rwkv7_kernel/triton_kernel/__init__.py +34 -0
  76. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_bwd.py +328 -0
  77. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_fwd.py +186 -0
  78. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_bwd.py +157 -0
  79. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_fwd.py +160 -0
  80. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_bwd.py +382 -0
  81. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_fwd.py +137 -0
  82. rwkv_ops/rwkv7_kernel/triton_kernel/cumsum.py +86 -0
  83. rwkv_ops/rwkv7_kernel/triton_kernel/utils.py +20 -0
  84. rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_bwd.py +193 -0
  85. rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_fwd.py +326 -0
  86. rwkv_ops-0.6.1.dist-info/METADATA +495 -0
  87. rwkv_ops-0.6.1.dist-info/RECORD +89 -0
  88. rwkv_ops-0.6.1.dist-info/WHEEL +4 -0
  89. rwkv_ops-0.6.1.dist-info/licenses/LICENSE.txt +201 -0
rwkv_ops/mhc_kernel/common_kernel/kernels/rmsnorm.cuh
@@ -0,0 +1,152 @@
+ #pragma once
+
+ #include <cuda_runtime.h>
+ #include <cuda_bf16.h>
+ #include <cooperative_groups.h>
+ #include <cooperative_groups/reduce.h>
+ #include "../include/mhc_types.h"
+
+ namespace cg = cooperative_groups;
+
+ namespace mhc {
+
+ /* -------------------- Forward pass (fixes index overflow) -------------------- */
+ template<int BLOCK_SIZE>
+ __global__ void rmsnorm_fwd_kernel(
+     floatX* __restrict__ out,
+     const floatX* __restrict__ inp,
+     int N, int C, float eps) {
+
+     cg::thread_block block = cg::this_thread_block();
+     cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
+
+     // Receive blockIdx.x as int64_t
+     int64_t row_idx = blockIdx.x;
+     if (row_idx >= N) return;
+
+     // Key fix: force size_t arithmetic so the N * C offset cannot overflow
+     size_t offset = (size_t)row_idx * C;
+     const floatX* x_ptr = inp + offset;
+     floatX* o_ptr = out + offset;
+
+     // Shared memory holds each warp's partial sum
+     extern __shared__ float s_reduce[];
+
+     // 1. Sum of squares (accumulated in FP32)
+     float thread_sum_sq = 0.0f;
+     for (int i = threadIdx.x; i < C; i += BLOCK_SIZE) {
+         float val = mhc::to_float(x_ptr[i]);
+         thread_sum_sq += val * val;
+     }
+
+     // 2. Warp-level reduction
+     float warp_sum = cg::reduce(warp, thread_sum_sq, cg::plus<float>());
+
+     int warp_id = threadIdx.x / 32;
+     int lane_id = threadIdx.x % 32;
+     if (lane_id == 0) s_reduce[warp_id] = warp_sum;
+     block.sync();
+
+     // 3. Block-level reduction
+     if (warp_id == 0) {
+         float b_sum = (lane_id < (BLOCK_SIZE / 32)) ? s_reduce[lane_id] : 0.0f;
+         b_sum = cg::reduce(warp, b_sum, cg::plus<float>());
+         if (lane_id == 0) s_reduce[0] = b_sum;
+     }
+     block.sync();
+
+     // 4. Inverse RMS
+     float rms_inv = rsqrtf((s_reduce[0] / (float)C) + eps);
+
+     // 5. Write the results back
+     for (int i = threadIdx.x; i < C; i += BLOCK_SIZE) {
+         float val = mhc::to_float(x_ptr[i]);
+         o_ptr[i] = mhc::to_bf(val * rms_inv);
+     }
+ }
+
+ /* -------------------- Backward pass (fixes index overflow) -------------------- */
+ template<int BLOCK_SIZE>
+ __global__ void rmsnorm_bwd_kernel(
+     floatX* __restrict__ dx,
+     const floatX* __restrict__ grad,
+     const floatX* __restrict__ x,
+     int N, int C, float eps) {
+
+     cg::thread_block block = cg::this_thread_block();
+     cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
+
+     int64_t row_idx = blockIdx.x;
+     if (row_idx >= N) return;
+
+     // Key fix: force size_t arithmetic
+     size_t offset = (size_t)row_idx * C;
+     const floatX* g_ptr = grad + offset;
+     const floatX* x_ptr = x + offset;
+     floatX* dx_ptr = dx + offset;
+
+     extern __shared__ float s_mem[];
+     int num_warps = BLOCK_SIZE / 32;
+     float* s_sum_sq = s_mem;
+     float* s_dot = s_mem + num_warps;
+
+     // 1. Per-thread accumulation
+     float t_sum_sq = 0.0f;
+     float t_dot = 0.0f;
+     for (int i = threadIdx.x; i < C; i += BLOCK_SIZE) {
+         float xv = mhc::to_float(x_ptr[i]);
+         float gv = mhc::to_float(g_ptr[i]);
+         t_sum_sq += xv * xv;
+         t_dot += gv * xv;
+     }
+
+     // 2. Warp-level reduction
+     float w_sum = cg::reduce(warp, t_sum_sq, cg::plus<float>());
+     float w_dot = cg::reduce(warp, t_dot, cg::plus<float>());
+
+     int warp_id = threadIdx.x / 32;
+     int lane_id = threadIdx.x % 32;
+     if (lane_id == 0) {
+         s_sum_sq[warp_id] = w_sum;
+         s_dot[warp_id] = w_dot;
+     }
+     block.sync();
+
+     // 3. Block-level reduction
+     if (warp_id == 0) {
+         float v1 = (lane_id < num_warps) ? s_sum_sq[lane_id] : 0.0f;
+         float v2 = (lane_id < num_warps) ? s_dot[lane_id] : 0.0f;
+         s_sum_sq[0] = cg::reduce(warp, v1, cg::plus<float>());
+         s_dot[0] = cg::reduce(warp, v2, cg::plus<float>());
+     }
+     block.sync();
+
+     // 4. Intermediate terms
+     float r2_inv = 1.0f / (s_sum_sq[0] / (float)C + eps);
+     float rms_inv = sqrtf(r2_inv);
+     float projection = s_dot[0] * (r2_inv * rms_inv) / (float)C;
+
+     // 5. Apply the formula and write back
+     for (int i = threadIdx.x; i < C; i += BLOCK_SIZE) {
+         float xv = mhc::to_float(x_ptr[i]);
+         float gv = mhc::to_float(g_ptr[i]);
+         dx_ptr[i] = mhc::to_bf(gv * rms_inv - xv * projection);
+     }
+ }
+
+ /* -------------------- Wrappers -------------------- */
+
+ inline void rmsnorm_forward(floatX* out, const floatX* inp, int N, int C, float eps, cudaStream_t stream) {
+     const int BLOCK_SIZE = 256;
+     size_t smem = (BLOCK_SIZE / 32) * sizeof(float);
+     // Grid size uses the int64-compatible N (one block per row)
+     rmsnorm_fwd_kernel<BLOCK_SIZE><<<N, BLOCK_SIZE, smem, stream>>>(out, inp, N, C, eps);
+ }
+
+ inline void rmsnorm_backward(floatX* dx, const floatX* grad, const floatX* x, int N, int C, float eps, cudaStream_t stream) {
+     const int BLOCK_SIZE = 256;
+     size_t smem = (BLOCK_SIZE / 32) * 2 * sizeof(float);
+     rmsnorm_bwd_kernel<BLOCK_SIZE><<<N, BLOCK_SIZE, smem, stream>>>(dx, grad, x, N, C, eps);
+ }
+
+ } // namespace mhc
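
For readers checking the backward kernel above: with the eps-inside-the-root convention used by both kernels, the gradient it writes is the standard RMSNorm derivative, restated below as a sketch (g is the incoming gradient; the second term corresponds to the projection variable):

\[
  \frac{\partial L}{\partial x_i} \;=\; g_i\, r \;-\; \frac{x_i\, r^{3}}{C} \sum_{j=1}^{C} g_j x_j,
  \qquad
  r \;=\; \Big(\tfrac{1}{C}\sum_{j=1}^{C} x_j^{2} + \varepsilon\Big)^{-1/2}
\]

Here r is exactly the rms_inv computed by the kernels (r2_inv = r^2), so projection = r^3 (sum_j g_j x_j) / C.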
rwkv_ops/mhc_kernel/common_kernel/kernels/sinkhorn_knopp.cuh
@@ -0,0 +1,158 @@
+ #pragma once
+
+ #include <cuda_runtime.h>
+ #include <cuda_bf16.h>
+ #include <cooperative_groups.h>
+ #include "../include/mhc_types.h"
+ #include "type_conversions.cuh"
+
+ namespace mhc {
+
+ /**
+  * 1. Forward kernel (unchanged, but keeps eps usage consistent)
+  */
+ template<int BLOCK_SIZE>
+ __global__ void sinkhorn_knopp_fwd_kernel(
+     float* __restrict__ out,
+     const float* __restrict__ inp,
+     int M, int N, int num_iters, float eps) {
+
+     extern __shared__ float smem[];
+     float* tile = smem;
+     float* row_sums = smem + M * N;
+     float* col_sums = row_sums + M;
+
+     int tid = threadIdx.x;
+     int64_t total = (int64_t)M * N;
+
+     for (int64_t i = tid; i < total; i += BLOCK_SIZE) {
+         tile[i] = __expf(inp[i]);
+     }
+     __syncthreads();
+
+     for (int iter = 0; iter < num_iters; iter++) {
+         // Row normalization
+         for (int r = tid; r < M; r += BLOCK_SIZE) {
+             float sum = 0.0f;
+             for (int c = 0; c < N; c++) sum += tile[(int64_t)r * N + c];
+             row_sums[r] = 1.0f / (sum + eps);
+         }
+         __syncthreads();
+
+         for (int64_t i = tid; i < total; i += BLOCK_SIZE) {
+             tile[i] *= row_sums[i / N];
+         }
+         __syncthreads();
+
+         // Column normalization
+         for (int c = tid; c < N; c += BLOCK_SIZE) {
+             float sum = 0.0f;
+             for (int r = 0; r < M; r++) sum += tile[(int64_t)r * N + c];
+             col_sums[c] = 1.0f / (sum + eps);
+         }
+         __syncthreads();
+
+         for (int64_t i = tid; i < total; i += BLOCK_SIZE) {
+             tile[i] *= col_sums[i % N];
+         }
+         __syncthreads();
+     }
+
+     for (int64_t i = tid; i < total; i += BLOCK_SIZE) out[i] = tile[i];
+ }
+
+ /**
+  * 2. Backward kernel (fix: matches the automatic-differentiation iteration)
+  */
+ template<int BLOCK_SIZE>
+ __global__ void sinkhorn_knopp_bwd_kernel(
+     float* __restrict__ d_inp,
+     const float* __restrict__ grad,
+     const float* __restrict__ out_fwd,
+     int M, int N, int num_iters, float eps) {
+
+     extern __shared__ float smem[];
+     float* P = smem;                  // forward output P [M*N]
+     float* dP = smem + M * N;         // incoming gradient G [M*N]
+     float* alpha = smem + 2 * M * N;  // auxiliary variable alpha [M]
+     float* beta = alpha + M;          // auxiliary variable beta [N]
+
+     int tid = threadIdx.x;
+     int64_t total = (int64_t)M * N;
+
+     // Load data
+     for (int64_t i = tid; i < total; i += BLOCK_SIZE) {
+         P[i] = out_fwd[i];
+         dP[i] = grad[i];
+     }
+     __syncthreads();
+
+     // Core: iterative solver for the Sinkhorn gradient (mirrors the unrolled PyTorch loop).
+     // We solve for alpha and beta satisfying:
+     //   alpha_i = sum_j (P_ij * (dP_ij - beta_j))
+     //   beta_j  = sum_i (P_ij * (dP_ij - alpha_i))
+
+     // Initialize alpha and beta to 0
+     for (int i = tid; i < M; i += BLOCK_SIZE) alpha[i] = 0.0f;
+     for (int j = tid; j < N; j += BLOCK_SIZE) beta[j] = 0.0f;
+     __syncthreads();
+
+     for (int iter = 0; iter < num_iters; iter++) {
+         // Update beta
+         for (int j = tid; j < N; j += BLOCK_SIZE) {
+             float sum = 0.0f;
+             for (int i = 0; i < M; i++) {
+                 sum += P[(int64_t)i * N + j] * (dP[(int64_t)i * N + j] - alpha[i]);
+             }
+             beta[j] = sum;
+         }
+         __syncthreads();
+
+         // Update alpha
+         for (int i = tid; i < M; i += BLOCK_SIZE) {
+             float sum = 0.0f;
+             for (int j = 0; j < N; j++) {
+                 sum += P[(int64_t)i * N + j] * (dP[(int64_t)i * N + j] - beta[j]);
+             }
+             alpha[i] = sum;
+         }
+         __syncthreads();
+     }
+
+     // Final gradient formula: dL/dA = P * (G - alpha - beta)
+     for (int64_t i = tid; i < total; i += BLOCK_SIZE) {
+         int64_t r = i / N;
+         int64_t c = i % N;
+         d_inp[i] = P[i] * (dP[i] - alpha[r] - beta[c]);
+     }
+ }
+
+ /* -------------------- API -------------------- */
+
+ inline void sinkhorn_knopp_forward(
+     float* out, const float* inp,
+     int M, int N, int num_iters, float eps,
+     cudaStream_t stream = nullptr) {
+
+     const int BLOCK_SIZE = 256;
+     size_t smem_size = ((size_t)M * N + M + N) * sizeof(float);
+     sinkhorn_knopp_fwd_kernel<BLOCK_SIZE><<<1, BLOCK_SIZE, smem_size, stream>>>(
+         out, inp, M, N, num_iters, eps
+     );
+ }
+
+ inline void sinkhorn_knopp_backward(
+     float* d_inp, const float* grad, const float* M_out, const float* M_inp,
+     int N, int num_iters, float eps,
+     cudaStream_t stream = nullptr) {
+
+     const int BLOCK_SIZE = 256;
+     // Shared-memory footprint: P(N*N) + dP(N*N) + alpha(N) + beta(N); M_inp is not needed by the kernel
+     size_t smem_size = (2 * (size_t)N * N + 2 * N) * sizeof(float);
+
+     sinkhorn_knopp_bwd_kernel<BLOCK_SIZE><<<1, BLOCK_SIZE, smem_size, stream>>>(
+         d_inp, grad, M_out, N, N, num_iters, eps
+     );
+ }
+
+ } // namespace mhc
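
As a correctness aid for the backward kernel above, a minimal host-side mirror of the same fixed-point iteration is sketched below. It follows the kernel's update order (beta from the current alpha, then alpha from the new beta) and the final dL/dA = P * (G - alpha - beta) step; the function name and the use of std::vector are illustrative only and are not part of the package.

// Host-side reference for sinkhorn_knopp_bwd_kernel, intended for small unit tests.
// Inputs are row-major [M x N] float arrays, matching the kernel's layout.
#include <cstddef>
#include <vector>

inline void sinkhorn_backward_reference(std::vector<float>& d_inp,
                                        const std::vector<float>& grad,     // incoming gradient G
                                        const std::vector<float>& out_fwd,  // forward output P
                                        int M, int N, int num_iters) {
    std::vector<float> alpha(M, 0.0f), beta(N, 0.0f);
    for (int iter = 0; iter < num_iters; ++iter) {
        // beta_j = sum_i P_ij * (G_ij - alpha_i)
        for (int j = 0; j < N; ++j) {
            float s = 0.0f;
            for (int i = 0; i < M; ++i) s += out_fwd[(size_t)i * N + j] * (grad[(size_t)i * N + j] - alpha[i]);
            beta[j] = s;
        }
        // alpha_i = sum_j P_ij * (G_ij - beta_j)
        for (int i = 0; i < M; ++i) {
            float s = 0.0f;
            for (int j = 0; j < N; ++j) s += out_fwd[(size_t)i * N + j] * (grad[(size_t)i * N + j] - beta[j]);
            alpha[i] = s;
        }
    }
    // dL/dA = P * (G - alpha - beta)
    d_inp.assign((size_t)M * N, 0.0f);
    for (int i = 0; i < M; ++i)
        for (int j = 0; j < N; ++j)
            d_inp[(size_t)i * N + j] = out_fwd[(size_t)i * N + j] * (grad[(size_t)i * N + j] - alpha[i] - beta[j]);
}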
rwkv_ops/mhc_kernel/common_kernel/kernels/stream_aggregate.cuh
@@ -0,0 +1,141 @@
+ #ifndef MHC_STREAM_AGGREGATE_CUH
+ #define MHC_STREAM_AGGREGATE_CUH
+
+ #include <cuda_runtime.h>
+ #include <cuda_bf16.h>
+ #include <cooperative_groups.h>
+ #include <cooperative_groups/reduce.h>
+ #include "../include/mhc_types.h"
+
+ namespace cg = cooperative_groups;
+
+ namespace mhc {
+
+ /**
+  * Forward pass (high-precision version)
+  * Out = sum(H_i * X_i)
+  */
+ template<bool PER_TOKEN_H>
+ __global__ void stream_aggregate_fwd_kernel(
+     floatX* __restrict__ out,
+     const floatX* __restrict__ inp,
+     const float* __restrict__ H_pre,
+     int64_t BT, int n, int64_t C) {
+
+     // Cast blockIdx.x to int64_t to avoid 32-bit overflow
+     int64_t btc = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;
+     if (btc >= BT * C) return;
+
+     int64_t bt = btc / C;
+     int64_t c = btc % C;
+
+     // Core: do all multiply-accumulate work in FP32 registers
+     float sum = 0.0f;
+
+     #pragma unroll
+     for (int i = 0; i < 8; i++) { // assumes n <= 8, as set by the mHC paper
+         if (i < n) {
+             // [Fix]: bt * n can exceed 32 bits and n is an int, so cast n to int64_t explicitly
+             float h_val = PER_TOKEN_H ? H_pre[bt * (int64_t)n + i] : H_pre[i];
+             // Key: convert bf16 to fp32 right after the load, before any arithmetic
+             // [Fix]: cast n to int64_t so bt * n * C stays in 64-bit arithmetic throughout
+             float x_val = to_float(inp[bt * (int64_t)n * C + (int64_t)i * C + c]);
+             sum += h_val * x_val;
+         }
+     }
+     // Convert back to bf16 once at the end
+     out[btc] = to_bf(sum);
+ }
+
+ /**
+  * Backward dx: d_inp = d_out * H_pre
+  */
+ template<bool PER_TOKEN_H>
+ __global__ void stream_aggregate_bwd_dx_kernel(
+     floatX* __restrict__ d_inp,
+     const float* __restrict__ d_out,
+     const float* __restrict__ H_pre,
+     int64_t BT, int n, int64_t C) {
+
+     // Cast blockIdx.x to int64_t
+     int64_t btc = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;
+     int i = blockIdx.y;
+
+     if (btc < BT * C && i < n) {
+         int64_t bt = btc / C;
+         // [Fix]: cast n to int64_t in bt * n
+         float h_val = PER_TOKEN_H ? H_pre[bt * (int64_t)n + i] : H_pre[i];
+         float grad_val = d_out[btc]; // incoming gradient is FP32
+
+         // [Fix]: cast n to int64_t in bt * n * C
+         d_inp[bt * (int64_t)n * C + (int64_t)i * C + (btc % C)] = to_bf(grad_val * h_val);
+     }
+ }
+
+ /**
+  * Backward dH: d_H = sum_over_C(d_out * inp)
+  * Uses a parallel reduction to preserve precision
+  */
+ template<int BLOCK_SIZE>
+ __global__ void stream_aggregate_bwd_dh_kernel(
+     float* __restrict__ d_H_pre,
+     const float* __restrict__ d_out,
+     const floatX* __restrict__ inp,
+     int64_t BT, int n, int64_t C) {
+
+     cg::thread_block block = cg::this_thread_block();
+     cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
+
+     // blockIdx.x corresponds to BT on the host side; int64_t is needed here when BT is large
+     int64_t bt = blockIdx.x;
+     int i = blockIdx.y;
+     if (bt >= BT || i >= n) return;
+
+     extern __shared__ float s_reduce[];
+
+     float thread_sum = 0.0f;
+     for (int64_t c = threadIdx.x; c < C; c += BLOCK_SIZE) {
+         float g_val = d_out[bt * C + c];
+         // [Fix]: cast n to int64_t in bt * n * C
+         float x_val = to_float(inp[bt * (int64_t)n * C + (int64_t)i * C + c]);
+         thread_sum += g_val * x_val;
+     }
+
+     float warp_sum = cg::reduce(warp, thread_sum, cg::plus<float>());
+     int warp_id = threadIdx.x / 32;
+     int lane_id = threadIdx.x % 32;
+
+     if (lane_id == 0) s_reduce[warp_id] = warp_sum;
+     block.sync();
+
+     if (warp_id == 0) {
+         float val = (lane_id < (BLOCK_SIZE / 32)) ? s_reduce[lane_id] : 0.0f;
+         float block_sum = cg::reduce(warp, val, cg::plus<float>());
+         // [Fix]: cast n to int64_t in bt * n
+         if (lane_id == 0) d_H_pre[bt * (int64_t)n + i] = block_sum;
+     }
+ }
+
+ inline void stream_aggregate_forward(floatX* out, const floatX* inp, const float* H_pre, int64_t BT, int n, int64_t C, bool per_token, cudaStream_t stream) {
+     dim3 threads(256);
+     // Note: if BT * C is extremely large and exceeds the uint32 range, dim3.x truncates; this is a CUDA hardware limit.
+     // The arithmetic inside the kernel keeps BT*C in int64, so it does not overflow there.
+     dim3 blocks((BT * C + 255) / 256);
+     if (per_token) stream_aggregate_fwd_kernel<true><<<blocks, threads, 0, stream>>>(out, inp, H_pre, BT, n, C);
+     else stream_aggregate_fwd_kernel<false><<<blocks, threads, 0, stream>>>(out, inp, H_pre, BT, n, C);
+ }
+
+ inline void stream_aggregate_backward(floatX* d_inp, float* d_H_pre, const float* d_out, const floatX* inp, const float* H_pre, int64_t BT, int n, int64_t C, bool per_token, cudaStream_t stream) {
+     dim3 threads_dx(256);
+     dim3 blocks_dx((BT * C + 255) / 256, n);
+     if (per_token) stream_aggregate_bwd_dx_kernel<true><<<blocks_dx, threads_dx, 0, stream>>>(d_inp, d_out, H_pre, BT, n, C);
+     else stream_aggregate_bwd_dx_kernel<false><<<blocks_dx, threads_dx, 0, stream>>>(d_inp, d_out, H_pre, BT, n, C);
+
+     constexpr int DH_BLOCK_SIZE = 256;
+     dim3 grid_dh(BT, n);
+     size_t smem_size = (DH_BLOCK_SIZE / 32) * sizeof(float);
+     stream_aggregate_bwd_dh_kernel<DH_BLOCK_SIZE><<<grid_dh, DH_BLOCK_SIZE, smem_size, stream>>>(d_H_pre, d_out, inp, BT, n, C);
+ }
+
+ } // namespace mhc
+ #endif
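
A hedged usage sketch for the two wrappers above follows; it is not part of the wheel. It assumes floatX resolves to __nv_bfloat16 (suggested by the cuda_bf16.h include and the to_bf/to_float helpers, but ultimately decided by mhc_types.h, which is not shown in this diff), and it uses the layouts the kernels index: inp as [BT, n, C], out as [BT, C], and H_pre as [BT, n] when per_token is true.

// Illustrative host-side launcher for stream_aggregate_forward / stream_aggregate_backward.
#include <cstdint>
#include <cuda_runtime.h>
#include <cuda_bf16.h>
#include "stream_aggregate.cuh"  // under rwkv_ops/mhc_kernel/common_kernel/kernels/

void stream_aggregate_example(cudaStream_t stream) {
    const int64_t BT = 8 * 1024;  // batch * sequence length
    const int     n  = 4;         // number of streams (the kernels unroll up to 8)
    const int64_t C  = 2048;      // channels
    __nv_bfloat16 *inp = nullptr, *out = nullptr, *d_inp = nullptr;
    float *H_pre = nullptr, *d_out = nullptr, *d_H_pre = nullptr;
    cudaMalloc((void**)&inp,     sizeof(__nv_bfloat16) * BT * n * C);  // X: [BT, n, C]
    cudaMalloc((void**)&out,     sizeof(__nv_bfloat16) * BT * C);      // Y: [BT, C]
    cudaMalloc((void**)&H_pre,   sizeof(float) * BT * n);              // per-token mixing weights
    cudaMalloc((void**)&d_out,   sizeof(float) * BT * C);              // upstream gradient, FP32
    cudaMalloc((void**)&d_inp,   sizeof(__nv_bfloat16) * BT * n * C);
    cudaMalloc((void**)&d_H_pre, sizeof(float) * BT * n);
    // ... fill inp, H_pre and d_out ...
    mhc::stream_aggregate_forward(out, inp, H_pre, BT, n, C, /*per_token=*/true, stream);
    mhc::stream_aggregate_backward(d_inp, d_H_pre, d_out, inp, H_pre, BT, n, C, /*per_token=*/true, stream);
    cudaStreamSynchronize(stream);
    cudaFree(inp); cudaFree(out); cudaFree(H_pre);
    cudaFree(d_out); cudaFree(d_inp); cudaFree(d_H_pre);
}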
rwkv_ops/mhc_kernel/common_kernel/kernels/stream_distribute.cuh
@@ -0,0 +1,111 @@
+ #ifndef MHC_STREAM_DISTRIBUTE_CUH
+ #define MHC_STREAM_DISTRIBUTE_CUH
+
+ #include <cuda_runtime.h>
+ #include <cuda_bf16.h>
+ #include <cooperative_groups.h>
+ #include <cooperative_groups/reduce.h>
+ #include "../include/mhc_types.h"
+
+ namespace cg = cooperative_groups;
+
+ namespace mhc {
+
+ /**
+  * Forward: Out = Inp * H_post
+  * Shape: Inp [B, T, C], H_post [B, T, n] -> Out [B, T, n, C]
+  */
+ __global__ void stream_distribute_fwd_kernel(
+     floatX* __restrict__ out,
+     const floatX* __restrict__ inp,
+     const float* __restrict__ H_post,
+     int64_t B, int64_t T, int n, int64_t C) {
+
+     // Index hardening: force int64_t so (B*T*C) beyond ~2.1 billion cannot overflow
+     int64_t btc = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;
+     int i = blockIdx.y; // stream index
+
+     if (btc < B * T * C && i < n) {
+         int64_t bt = btc / C;
+         int64_t c = btc % C;
+
+         float val = to_float(inp[btc]);
+         float weight = H_post[bt * n + i];
+
+         // Compute the 64-bit offset
+         int64_t target_idx = bt * n * C + (int64_t)i * C + c;
+         out[target_idx] = to_bf(val * weight);
+     }
+ }
+
+ /**
+  * Backward dx: dx = sum_i(grad_i * H_post_i)
+  * Precision: FP32 accumulation
+  */
+ __global__ void stream_distribute_bwd_dx_kernel(
+     floatX* __restrict__ dx,
+     const floatX* __restrict__ grad,
+     const float* __restrict__ H_post,
+     int64_t B, int64_t T, int n, int64_t C) {
+
+     int64_t btc = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;
+     if (btc >= B * T * C) return;
+
+     int64_t bt = btc / C;
+     int64_t c = btc % C;
+
+     float sum = 0.0f;
+     for (int i = 0; i < n; i++) {
+         // Force 64-bit offset arithmetic
+         float g = to_float(grad[bt * n * C + (int64_t)i * C + c]);
+         float w = H_post[bt * n + i];
+         sum += g * w;
+     }
+     dx[btc] = to_bf(sum);
+ }
+
+ /**
+  * Backward dH: dH = sum_c(grad * inp)
+  * Precision: full-precision reduction over the channel dimension C via warp shuffles
+  */
+ template<int BLOCK_SIZE>
+ __global__ void stream_distribute_bwd_dh_kernel(
+     float* __restrict__ d_H_post,
+     const floatX* __restrict__ grad,
+     const floatX* __restrict__ inp,
+     int64_t B, int64_t T, int n, int64_t C) {
+
+     cg::thread_block block = cg::this_thread_block();
+     cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
+
+     int64_t bt = blockIdx.x;
+     int i = blockIdx.y;
+     if (bt >= B * T || i >= n) return;
+
+     int64_t base_grad = bt * n * C + (int64_t)i * C;
+     int64_t base_inp = bt * C;
+
+     float thread_sum = 0.0f;
+     for (int64_t c = threadIdx.x; c < C; c += BLOCK_SIZE) {
+         float g = to_float(grad[base_grad + c]);
+         float x = to_float(inp[base_inp + c]);
+         thread_sum += g * x;
+     }
+
+     // Block-level reduction
+     float sum = cg::reduce(warp, thread_sum, cg::plus<float>());
+     static __shared__ float s_reduce[32];
+     int warp_id = threadIdx.x / 32;
+     int lane_id = threadIdx.x % 32;
+     if (lane_id == 0) s_reduce[warp_id] = sum;
+     block.sync();
+
+     if (warp_id == 0) {
+         float val = (lane_id < (BLOCK_SIZE / 32)) ? s_reduce[lane_id] : 0.0f;
+         float block_sum = cg::reduce(warp, val, cg::plus<float>());
+         if (lane_id == 0) d_H_post[bt * n + i] = block_sum;
+     }
+ }
+
+ } // namespace mhc
+ #endif
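
Taken together, the three kernels above implement the following map and its gradients; the sketch below simply restates the per-element arithmetic, with G denoting the incoming gradient of the output and bt ranging over the flattened B*T dimension:

\[
  Y_{bt,i,c} = H^{\mathrm{post}}_{bt,i}\, X_{bt,c},
  \qquad
  \frac{\partial L}{\partial X_{bt,c}} = \sum_{i=0}^{n-1} G_{bt,i,c}\, H^{\mathrm{post}}_{bt,i},
  \qquad
  \frac{\partial L}{\partial H^{\mathrm{post}}_{bt,i}} = \sum_{c=0}^{C-1} G_{bt,i,c}\, X_{bt,c}
\]

The dH kernel accumulates the inner channel sum in FP32 through the warp-and-block reduction, matching the last formula.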