rwkv_ops-0.6.1-py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their public registries.
Files changed (89)
  1. rwkv_ops/__init__.py +45 -0
  2. rwkv_ops/mhc_kernel/__init__.py +50 -0
  3. rwkv_ops/mhc_kernel/common_kernel/include/mhc_types.h +66 -0
  4. rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_post_op.cuh +197 -0
  5. rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_pre_op.cuh +212 -0
  6. rwkv_ops/mhc_kernel/common_kernel/kernels/rmsnorm.cuh +152 -0
  7. rwkv_ops/mhc_kernel/common_kernel/kernels/sinkhorn_knopp.cuh +158 -0
  8. rwkv_ops/mhc_kernel/common_kernel/kernels/stream_aggregate.cuh +141 -0
  9. rwkv_ops/mhc_kernel/common_kernel/kernels/stream_distribute.cuh +111 -0
  10. rwkv_ops/mhc_kernel/common_kernel/kernels/stream_mix.cuh +164 -0
  11. rwkv_ops/mhc_kernel/common_kernel/kernels/type_conversions.cuh +52 -0
  12. rwkv_ops/mhc_kernel/jax_kernel/CMakeLists.txt +47 -0
  13. rwkv_ops/mhc_kernel/jax_kernel/mhu_ffi.cu +652 -0
  14. rwkv_ops/mhc_kernel/jax_kernel/mhu_jax.py +939 -0
  15. rwkv_ops/mhc_kernel/native_keras_op.py +193 -0
  16. rwkv_ops/mhc_kernel/torch_kernel/mhc_cuda.cu +207 -0
  17. rwkv_ops/mhc_kernel/torch_kernel/mhc_op.cpp +296 -0
  18. rwkv_ops/mhc_kernel/torch_kernel/mhc_torch.py +306 -0
  19. rwkv_ops/rwkv6_kernel/__init__.py +120 -0
  20. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/gpu_ops.cpp +44 -0
  21. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernel_helpers.h +64 -0
  22. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernels.h +56 -0
  23. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/pybind11_kernel_helpers.h +41 -0
  24. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/rwkv_kernels.cu +512 -0
  25. rwkv_ops/rwkv6_kernel/jax_kernel_hip/gpu_ops.cpp +44 -0
  26. rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernel_helpers.h +64 -0
  27. rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernels.h +56 -0
  28. rwkv_ops/rwkv6_kernel/jax_kernel_hip/pybind11_kernel_helpers.h +41 -0
  29. rwkv_ops/rwkv6_kernel/jax_kernel_hip/rwkv_kernels.hip +514 -0
  30. rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py +722 -0
  31. rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py +90 -0
  32. rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_cuda.cu +397 -0
  33. rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_op.cpp +93 -0
  34. rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py +305 -0
  35. rwkv_ops/rwkv7_kernel/__init__.py +113 -0
  36. rwkv_ops/rwkv7_kernel/get_jax_devices_info.py +220 -0
  37. rwkv_ops/rwkv7_kernel/get_torch_devices_info.py +250 -0
  38. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/CMakeLists.txt +42 -0
  39. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_ffi.cu +399 -0
  40. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_jax.py +311 -0
  41. rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/CMakeLists.txt +42 -0
  42. rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_ffi.cu +172 -0
  43. rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_jax.py +190 -0
  44. rwkv_ops/rwkv7_kernel/jax_kernel/__init__.py +9 -0
  45. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_bwd.py +95 -0
  46. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_fwd.py +60 -0
  47. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_bwd.py +78 -0
  48. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_fwd.py +80 -0
  49. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_bwd.py +150 -0
  50. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_fwd.py +45 -0
  51. rwkv_ops/rwkv7_kernel/jax_kernel/cumsum.py +34 -0
  52. rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_bwd.py +61 -0
  53. rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_fwd.py +86 -0
  54. rwkv_ops/rwkv7_kernel/jax_op.py +382 -0
  55. rwkv_ops/rwkv7_kernel/mlx_op.py +118 -0
  56. rwkv_ops/rwkv7_kernel/native_keras_op.py +108 -0
  57. rwkv_ops/rwkv7_kernel/tf_eager_kernel.py +155 -0
  58. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_cuda.cu +235 -0
  59. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_op.cpp +63 -0
  60. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_torch.py +233 -0
  61. rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_cuda.cu +101 -0
  62. rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_op.cpp +56 -0
  63. rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_torch.py +112 -0
  64. rwkv_ops/rwkv7_kernel/torch_kernel/__init__.py +13 -0
  65. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_bwd.py +96 -0
  66. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_fwd.py +64 -0
  67. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_bwd.py +74 -0
  68. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_fwd.py +75 -0
  69. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_bwd.py +148 -0
  70. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_fwd.py +44 -0
  71. rwkv_ops/rwkv7_kernel/torch_kernel/cumsum.py +31 -0
  72. rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_bwd.py +63 -0
  73. rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_fwd.py +79 -0
  74. rwkv_ops/rwkv7_kernel/torch_op.py +504 -0
  75. rwkv_ops/rwkv7_kernel/triton_kernel/__init__.py +34 -0
  76. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_bwd.py +328 -0
  77. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_fwd.py +186 -0
  78. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_bwd.py +157 -0
  79. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_fwd.py +160 -0
  80. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_bwd.py +382 -0
  81. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_fwd.py +137 -0
  82. rwkv_ops/rwkv7_kernel/triton_kernel/cumsum.py +86 -0
  83. rwkv_ops/rwkv7_kernel/triton_kernel/utils.py +20 -0
  84. rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_bwd.py +193 -0
  85. rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_fwd.py +326 -0
  86. rwkv_ops-0.6.1.dist-info/METADATA +495 -0
  87. rwkv_ops-0.6.1.dist-info/RECORD +89 -0
  88. rwkv_ops-0.6.1.dist-info/WHEEL +4 -0
  89. rwkv_ops-0.6.1.dist-info/licenses/LICENSE.txt +201 -0
rwkv_ops/mhc_kernel/torch_kernel/mhc_op.cpp
@@ -0,0 +1,296 @@
+ #include <torch/extension.h>
+ #include <ATen/cuda/CUDAContext.h>
+ #include <cuda_runtime.h>
+ #include "../common_kernel/include/mhc_types.h"
+
+ namespace mhc {
+ // Sinkhorn interface
+ void cuda_sinkhorn_fwd(float* out, const float* inp, int64_t B, int64_t M, int64_t N, int iters, float eps, cudaStream_t stream);
+ void cuda_sinkhorn_bwd(float* d_inp, const float* grad, const float* M_out, const float* M_inp, int64_t B, int64_t N, int iters, float eps, cudaStream_t stream);
+
+ // RMSNorm interface
+ void cuda_rmsnorm_fwd(nv_bfloat16* out, const nv_bfloat16* inp, int64_t N, int64_t C, float eps, cudaStream_t stream);
+ void cuda_rmsnorm_bwd(nv_bfloat16* dx, const nv_bfloat16* grad, const nv_bfloat16* x, int64_t N, int64_t C, float eps, cudaStream_t stream);
+
+ // Stream Mix interface
+ void cuda_stream_mix_fwd(nv_bfloat16* out, const nv_bfloat16* inp, const float* M, int64_t B, int64_t T, int n, int64_t C, cudaStream_t stream);
+ void cuda_stream_mix_bwd(nv_bfloat16* d_inp, float* d_M, const float* grad, const nv_bfloat16* inp, const float* M, int64_t B, int64_t T, int n, int64_t C, cudaStream_t stream);
+
+ // New: Stream Aggregate interface
+ void cuda_stream_aggregate_fwd(nv_bfloat16* out, const nv_bfloat16* inp, const float* H_pre, int64_t B, int64_t T, int n, int64_t C, bool per_token, cudaStream_t stream);
+ void cuda_stream_aggregate_bwd(nv_bfloat16* d_inp, float* d_H_pre, const float* grad, const nv_bfloat16* inp, const float* H_pre, int64_t B, int64_t T, int n, int64_t C, bool per_token, cudaStream_t stream);
+
+ void cuda_stream_distribute_fwd(nv_bfloat16* out, const nv_bfloat16* inp, const float* H, int64_t B, int64_t T, int n, int64_t C, cudaStream_t stream);
+ void cuda_stream_distribute_bwd(nv_bfloat16* d_inp, float* d_H, const nv_bfloat16* grad, const nv_bfloat16* inp, const float* H, int64_t B, int64_t T, int n, int64_t C, cudaStream_t stream);
+
+ void cuda_mhc_post_op_fwd(nv_bfloat16* out, const nv_bfloat16* layer_out, const nv_bfloat16* x_expanded,
+                           const float* H_post, const float* H_res, int64_t B, int64_t T, int n, int64_t C, cudaStream_t stream);
+
+ void cuda_mhc_post_op_bwd(nv_bfloat16* d_layer_out, nv_bfloat16* d_x_expanded, float* d_H_post, float* d_H_res,
+                           const nv_bfloat16* grad_next, const nv_bfloat16* layer_out, const nv_bfloat16* x_expanded,
+                           const float* H_post, const float* H_res, int64_t B, int64_t T, int n, int64_t C, cudaStream_t stream);
+
+ void cuda_mhc_pre_op_fwd(nv_bfloat16* x_layer_in, float* H_pre, float* H_post, float* H_res,
+                          const nv_bfloat16* x_expanded, const float* h_pre_raw, const float* h_post_raw, const float* h_res_raw,
+                          int64_t B, int64_t T, int n, int64_t C, int sinkhorn_iters, float eps, cudaStream_t stream);
+
+ void cuda_mhc_pre_op_bwd(nv_bfloat16* d_x_expanded, float* d_h_pre_raw, float* d_h_post_raw, float* d_h_res_raw,
+                          const nv_bfloat16* grad_layer_in, const float* grad_H_post, const float* grad_H_res,
+                          const nv_bfloat16* x_expanded, const float* H_pre, const float* H_post, const float* H_res_out, const float* H_res_in_raw,
+                          int64_t B, int64_t T, int n, int64_t C, int sinkhorn_iters, float eps, cudaStream_t stream);
+ }
+
+ // --- Sinkhorn bindings ---
+ torch::Tensor sinkhorn_forward(torch::Tensor inp, int iters, float eps) {
+     auto out = torch::empty_like(inp);
+     int64_t B = inp.numel() / (inp.size(-1) * inp.size(-2));
+     mhc::cuda_sinkhorn_fwd(out.data_ptr<float>(), inp.contiguous().data_ptr<float>(), B, inp.size(-2), inp.size(-1), iters, eps, at::cuda::getCurrentCUDAStream());
+     return out;
+ }
+
+ torch::Tensor sinkhorn_backward(torch::Tensor grad, torch::Tensor out, torch::Tensor inp, int iters, float eps) {
+     auto d_inp = torch::empty_like(grad);
+     int64_t B = grad.numel() / (grad.size(-1) * grad.size(-1));
+     mhc::cuda_sinkhorn_bwd(d_inp.data_ptr<float>(), grad.contiguous().data_ptr<float>(), out.contiguous().data_ptr<float>(), inp.contiguous().data_ptr<float>(), B, grad.size(-1), iters, eps, at::cuda::getCurrentCUDAStream());
+     return d_inp;
+ }
+
+ // --- RMSNorm bindings ---
+ torch::Tensor rmsnorm_forward(torch::Tensor inp, float eps) {
+     auto out = torch::empty_like(inp);
+     int64_t C = inp.size(-1);
+     int64_t N = inp.numel() / C;
+     mhc::cuda_rmsnorm_fwd((nv_bfloat16*)out.data_ptr<at::BFloat16>(), (nv_bfloat16*)inp.contiguous().data_ptr<at::BFloat16>(), N, C, eps, at::cuda::getCurrentCUDAStream());
+     return out;
+ }
+
+ torch::Tensor rmsnorm_backward(torch::Tensor grad, torch::Tensor x, float eps) {
+     auto dx = torch::empty_like(x);
+     int64_t C = x.size(-1);
+     int64_t N = x.numel() / C;
+     mhc::cuda_rmsnorm_bwd((nv_bfloat16*)dx.data_ptr<at::BFloat16>(), (nv_bfloat16*)grad.contiguous().data_ptr<at::BFloat16>(), (nv_bfloat16*)x.contiguous().data_ptr<at::BFloat16>(), N, C, eps, at::cuda::getCurrentCUDAStream());
+     return dx;
+ }
+
+ // --- Stream Mix bindings ---
+ torch::Tensor stream_mix_fwd(torch::Tensor inp, torch::Tensor M) {
+     auto B = inp.size(0); auto T = inp.size(1); auto n = inp.size(2); auto C = inp.size(3);
+     auto out = torch::empty_like(inp);
+     mhc::cuda_stream_mix_fwd((nv_bfloat16*)out.data_ptr<at::BFloat16>(), (nv_bfloat16*)inp.contiguous().data_ptr<at::BFloat16>(), M.contiguous().data_ptr<float>(), B, T, n, C, at::cuda::getCurrentCUDAStream());
+     return out;
+ }
+
+ std::vector<torch::Tensor> stream_mix_backward(torch::Tensor grad, torch::Tensor inp, torch::Tensor M) {
+     int64_t B = inp.size(0); int64_t T = inp.size(1); int n = inp.size(2); int64_t C = inp.size(3);
+     auto d_inp = torch::empty_like(inp);
+     auto d_M = torch::empty_like(M);
+     mhc::cuda_stream_mix_bwd((nv_bfloat16*)d_inp.data_ptr<at::BFloat16>(), d_M.data_ptr<float>(), grad.contiguous().data_ptr<float>(), (nv_bfloat16*)inp.contiguous().data_ptr<at::BFloat16>(), M.contiguous().data_ptr<float>(), B, T, n, C, at::cuda::getCurrentCUDAStream());
+     return {d_inp, d_M};
+ }
+
+ // --- New: Stream Aggregate bindings ---
+ torch::Tensor stream_aggregate_fwd(torch::Tensor inp, torch::Tensor H_pre, bool per_token) {
+     int64_t B = inp.size(0); int64_t T = inp.size(1); int n = inp.size(2); int64_t C = inp.size(3);
+     auto out = torch::empty({B, T, C}, inp.options());
+     mhc::cuda_stream_aggregate_fwd((nv_bfloat16*)out.data_ptr<at::BFloat16>(), (nv_bfloat16*)inp.contiguous().data_ptr<at::BFloat16>(), H_pre.contiguous().data_ptr<float>(), B, T, n, C, per_token, at::cuda::getCurrentCUDAStream());
+     return out;
+ }
+
+ std::vector<torch::Tensor> stream_aggregate_bwd(torch::Tensor grad, torch::Tensor inp, torch::Tensor H_pre, bool per_token) {
+     int64_t B = inp.size(0); int64_t T = inp.size(1); int n = inp.size(2); int64_t C = inp.size(3);
+     auto d_inp = torch::empty_like(inp);
+     auto d_H_pre = torch::empty_like(H_pre);
+     mhc::cuda_stream_aggregate_bwd((nv_bfloat16*)d_inp.data_ptr<at::BFloat16>(), d_H_pre.data_ptr<float>(), grad.contiguous().data_ptr<float>(), (nv_bfloat16*)inp.contiguous().data_ptr<at::BFloat16>(), H_pre.contiguous().data_ptr<float>(), B, T, n, C, per_token, at::cuda::getCurrentCUDAStream());
+     return {d_inp, d_H_pre};
+ }
+ torch::Tensor stream_distribute_fwd(torch::Tensor inp, torch::Tensor H) {
+     // inp: [B, T, C], H: [B, T, n]
+     int64_t B = inp.size(0);
+     int64_t T = inp.size(1);
+     int64_t C = inp.size(2);
+     int n = H.size(2);
+
+     auto out = torch::empty({B, T, n, C}, inp.options());
+
+     mhc::cuda_stream_distribute_fwd(
+         (nv_bfloat16*)out.data_ptr<at::BFloat16>(),
+         (nv_bfloat16*)inp.contiguous().data_ptr<at::BFloat16>(),
+         H.contiguous().data_ptr<float>(),
+         B, T, n, C,
+         at::cuda::getCurrentCUDAStream()
+     );
+     return out;
+ }
+
+ std::vector<torch::Tensor> stream_distribute_backward(torch::Tensor grad, torch::Tensor inp, torch::Tensor H) {
+     int64_t B = inp.size(0);
+     int64_t T = inp.size(1);
+     int64_t C = inp.size(2);
+     int n = H.size(2);
+
+     auto d_inp = torch::empty_like(inp);
+     auto d_H = torch::empty_like(H);
+
+     mhc::cuda_stream_distribute_bwd(
+         (nv_bfloat16*)d_inp.data_ptr<at::BFloat16>(),
+         d_H.data_ptr<float>(),
+         (nv_bfloat16*)grad.contiguous().data_ptr<at::BFloat16>(),
+         (nv_bfloat16*)inp.contiguous().data_ptr<at::BFloat16>(),
+         H.contiguous().data_ptr<float>(),
+         B, T, n, C,
+         at::cuda::getCurrentCUDAStream()
+     );
+     return {d_inp, d_H};
+ }
+ torch::Tensor mhc_post_op_forward(torch::Tensor layer_out, torch::Tensor x_expanded, torch::Tensor H_post, torch::Tensor H_res) {
+     int64_t B = layer_out.size(0);
+     int64_t T = layer_out.size(1);
+     int64_t C = layer_out.size(2);
+     int n = H_post.size(2);
+
+     auto out = torch::empty_like(x_expanded);
+     mhc::cuda_mhc_post_op_fwd(
+         (nv_bfloat16*)out.data_ptr<at::BFloat16>(),
+         (nv_bfloat16*)layer_out.contiguous().data_ptr<at::BFloat16>(),
+         (nv_bfloat16*)x_expanded.contiguous().data_ptr<at::BFloat16>(),
+         H_post.contiguous().data_ptr<float>(),
+         H_res.contiguous().data_ptr<float>(),
+         B, T, n, C, at::cuda::getCurrentCUDAStream()
+     );
+     return out;
+ }
+
+ // Backward Torch interface (fully fused)
+ std::vector<torch::Tensor> mhc_post_op_backward(torch::Tensor grad_next, torch::Tensor layer_out, torch::Tensor x_expanded, torch::Tensor H_post, torch::Tensor H_res) {
+     int64_t B = layer_out.size(0);
+     int64_t T = layer_out.size(1);
+     int64_t C = layer_out.size(2);
+     int n = H_post.size(2);
+
+     auto d_layer_out = torch::empty_like(layer_out);
+     auto d_x_expanded = torch::empty_like(x_expanded);
+     // Parameter gradients start from zeros because the kernel accumulates them atomically
+     auto d_H_post = torch::zeros_like(H_post);
+     auto d_H_res = torch::zeros_like(H_res);
+
+     mhc::cuda_mhc_post_op_bwd(
+         (nv_bfloat16*)d_layer_out.data_ptr<at::BFloat16>(),
+         (nv_bfloat16*)d_x_expanded.data_ptr<at::BFloat16>(),
+         d_H_post.data_ptr<float>(),
+         d_H_res.data_ptr<float>(),
+         (nv_bfloat16*)grad_next.contiguous().data_ptr<at::BFloat16>(),
+         (nv_bfloat16*)layer_out.contiguous().data_ptr<at::BFloat16>(),
+         (nv_bfloat16*)x_expanded.contiguous().data_ptr<at::BFloat16>(),
+         H_post.contiguous().data_ptr<float>(),
+         H_res.contiguous().data_ptr<float>(),
+         B, T, n, C, at::cuda::getCurrentCUDAStream()
+     );
+
+     return {d_layer_out, d_x_expanded, d_H_post, d_H_res};
+ }
+ #include <torch/extension.h>
+ #include <c10/cuda/CUDAStream.h>
+ #include <vector>
+
+ // Declare the CUDA wrapper functions (defined in mhc_cuda.cu)
+ namespace mhc {
+ void cuda_mhc_pre_op_fwd(nv_bfloat16* x_layer_in, float* H_pre, float* H_post, float* H_res,
+                          const nv_bfloat16* x_expanded, const float* h_pre_raw, const float* h_post_raw, const float* h_res_raw,
+                          int64_t B, int64_t T, int n, int64_t C, int sinkhorn_iters, float eps, cudaStream_t stream);
+
+ void cuda_mhc_pre_op_bwd(nv_bfloat16* d_x_expanded, float* d_h_pre_raw, float* d_h_post_raw, float* d_h_res_raw,
+                          const nv_bfloat16* grad_layer_in, const float* grad_H_post, const float* grad_H_res,
+                          const nv_bfloat16* x_expanded, const float* H_pre, const float* H_post, const float* H_res_out, const float* H_res_in_raw,
+                          int64_t B, int64_t T, int n, int64_t C, int sinkhorn_iters, float eps, cudaStream_t stream);
+ }
+
+ // ----------------------------------------------------------------------------
+ // 1. Forward interface: all outputs are allocated with zeros so they start clean
+ // ----------------------------------------------------------------------------
+ std::vector<torch::Tensor> mhc_pre_op_forward(
+     torch::Tensor x_expanded, torch::Tensor h_pre_raw, torch::Tensor h_post_raw, torch::Tensor h_res_raw,
+     int sinkhorn_iters, float eps)
+ {
+     int64_t B = x_expanded.size(0);
+     int64_t T = x_expanded.size(1);
+     int n = x_expanded.size(2);
+     int64_t C = x_expanded.size(3);
+
+     // Use zeros instead of empty so regions the kernel does not write cannot feed stale data into Sinkhorn
+     auto x_layer_in = torch::zeros({B, T, C}, x_expanded.options());
+     auto H_pre = torch::zeros({B, T, n}, h_pre_raw.options());
+     auto H_post = torch::zeros({B, T, n}, h_post_raw.options());
+     auto H_res = torch::zeros({B, T, n, n}, h_res_raw.options());
+
+     mhc::cuda_mhc_pre_op_fwd(
+         (nv_bfloat16*)x_layer_in.data_ptr<at::BFloat16>(),
+         H_pre.data_ptr<float>(),
+         H_post.data_ptr<float>(),
+         H_res.data_ptr<float>(),
+         (nv_bfloat16*)x_expanded.contiguous().data_ptr<at::BFloat16>(),
+         h_pre_raw.contiguous().data_ptr<float>(),
+         h_post_raw.contiguous().data_ptr<float>(),
+         h_res_raw.contiguous().data_ptr<float>(),
+         B, T, n, C, sinkhorn_iters, eps,
+         c10::cuda::getCurrentCUDAStream()
+     );
+
+     return {x_layer_in, H_pre, H_post, H_res};
+ }
+
+ // ----------------------------------------------------------------------------
+ // 2. Backward interface: all gradient buffers are zero-initialized so accumulation is safe
+ // ----------------------------------------------------------------------------
+ std::vector<torch::Tensor> mhc_pre_op_backward(
+     torch::Tensor grad_layer_in, torch::Tensor grad_H_post, torch::Tensor grad_H_res,
+     torch::Tensor x_expanded, torch::Tensor H_pre, torch::Tensor H_post,
+     torch::Tensor H_res_out, torch::Tensor h_res_raw,
+     int sinkhorn_iters, float eps)
+ {
+     int64_t B = x_expanded.size(0);
+     int64_t T = x_expanded.size(1);
+     int n = x_expanded.size(2);
+     int64_t C = x_expanded.size(3);
+
+     // Gradient tensors must be zeroed because the kernel may use atomic adds or write back from specific threads
+     auto d_x_expanded = torch::zeros_like(x_expanded);
+     auto d_h_pre_raw = torch::zeros_like(H_pre);
+     auto d_h_post_raw = torch::zeros_like(H_post);
+     auto d_h_res_raw = torch::zeros({B, T, n * n}, h_res_raw.options());
+
+     mhc::cuda_mhc_pre_op_bwd(
+         (nv_bfloat16*)d_x_expanded.data_ptr<at::BFloat16>(),
+         d_h_pre_raw.data_ptr<float>(),
+         d_h_post_raw.data_ptr<float>(),
+         d_h_res_raw.data_ptr<float>(),
+         (nv_bfloat16*)grad_layer_in.contiguous().data_ptr<at::BFloat16>(),
+         grad_H_post.contiguous().data_ptr<float>(),
+         grad_H_res.contiguous().data_ptr<float>(),
+         (nv_bfloat16*)x_expanded.contiguous().data_ptr<at::BFloat16>(),
+         H_pre.contiguous().data_ptr<float>(),
+         H_post.contiguous().data_ptr<float>(),
+         H_res_out.contiguous().data_ptr<float>(),
+         h_res_raw.contiguous().data_ptr<float>(),
+         B, T, n, C, sinkhorn_iters, eps,
+         c10::cuda::getCurrentCUDAStream()
+     );
+
+     return {d_x_expanded, d_h_pre_raw, d_h_post_raw, d_h_res_raw};
+ }
+
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+     m.def("sinkhorn_fwd", &sinkhorn_forward);
+     m.def("sinkhorn_bwd", &sinkhorn_backward);
+     m.def("rmsnorm_fwd", &rmsnorm_forward);
+     m.def("rmsnorm_bwd", &rmsnorm_backward);
+     m.def("stream_mix_fwd", &stream_mix_fwd);
+     m.def("stream_mix_backward", &stream_mix_backward);
+     m.def("stream_aggregate_fwd", &stream_aggregate_fwd);
+     m.def("stream_aggregate_bwd", &stream_aggregate_bwd);
+     m.def("stream_distribute_fwd", &stream_distribute_fwd, "Stream Distribute Forward");
+     m.def("stream_distribute_bwd", &stream_distribute_backward, "Stream Distribute Backward");
+     m.def("mhc_post_op_fwd", &mhc_post_op_forward);
+     m.def("mhc_post_op_bwd", &mhc_post_op_backward);
+     m.def("mhc_pre_op_bwd", &mhc_pre_op_backward);
+     m.def("mhc_pre_op_fwd", &mhc_pre_op_forward);
+ }
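The sinkhorn_fwd/sinkhorn_bwd bindings above wrap the CUDA kernel in common_kernel/kernels/sinkhorn_knopp.cuh. For orientation only, here is a minimal pure-PyTorch sketch of the classic Sinkhorn-Knopp iteration (alternating row and column normalization of exp(x)); the name sinkhorn_knopp_reference and the exact stabilization details are illustrative assumptions, not the kernel's actual implementation.

import torch

def sinkhorn_knopp_reference(x, num_iters=20, eps=1e-8):
    # Alternately normalize rows and columns of exp(x) so the result becomes
    # approximately doubly stochastic. Reference sketch only; the fused CUDA
    # kernel's iteration order and numerical stabilization may differ.
    p = torch.exp(x)
    for _ in range(num_iters):
        p = p / (p.sum(dim=-1, keepdim=True) + eps)  # row normalization
        p = p / (p.sum(dim=-2, keepdim=True) + eps)  # column normalization
    return p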
rwkv_ops/mhc_kernel/torch_kernel/mhc_torch.py
@@ -0,0 +1,306 @@
+ import os
+ import torch
+ from torch.utils.cpp_extension import load
+
+ # Path configuration
+ current_dir = os.path.dirname(os.path.abspath(__file__))
+ common_inc = os.path.abspath(os.path.join(current_dir, "../common_kernel/include"))
+ common_ker = os.path.abspath(os.path.join(current_dir, "../common_kernel/kernels"))
+
+ mhc_lib = load(
+     name="mhc_cuda_kernel",
+     sources=[
+         os.path.join(current_dir, "mhc_op.cpp"),
+         os.path.join(current_dir, "mhc_cuda.cu"),
+     ],
+     extra_include_paths=[common_inc, common_ker],
+     extra_cuda_cflags=[
+         "-O3",
+         "--use_fast_math",
+         "-std=c++17",
+         "-D__CUDA_NO_BFLOAT16_CONVERSIONS__",
+     ],
+     verbose=True,
+ )
+
+
+ class SinkhornKnoppFunction(torch.autograd.Function):
+     @staticmethod
+     def forward(ctx, inp, num_iters=20, eps=1e-8):
+         x = inp.float().contiguous()
+         x_max = torch.amax(x, dim=(-1, -2), keepdim=True)
+         out = mhc_lib.sinkhorn_fwd(x - x_max, num_iters, eps)
+         ctx.save_for_backward(out, x - x_max)
+         ctx.num_iters, ctx.eps = num_iters, eps
+         return out.to(inp.dtype)
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         out, x_stabilized = ctx.saved_tensors
+         d_inp = mhc_lib.sinkhorn_bwd(
+             grad_output.float().contiguous(), out, x_stabilized, ctx.num_iters, ctx.eps
+         )
+         return d_inp.to(grad_output.dtype), None, None
+
+
+ class RMSNormFunction(torch.autograd.Function):
+     @staticmethod
+     def forward(ctx, inp, eps=1e-5):
+         inp = inp.to(torch.bfloat16).contiguous()
+         out = mhc_lib.rmsnorm_fwd(inp, eps)
+         ctx.save_for_backward(inp)
+         ctx.eps = eps
+         return out
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         (inp,) = ctx.saved_tensors
+         dx = mhc_lib.rmsnorm_bwd(
+             grad_output.to(torch.bfloat16).contiguous(), inp, ctx.eps
+         )
+         return dx, None
+
+
+ class StreamMixFunction(torch.autograd.Function):
+     @staticmethod
+     def forward(ctx, inp, M):
+         inp = inp.to(torch.bfloat16).contiguous()
+         M = M.float().contiguous()
+         out = mhc_lib.stream_mix_fwd(inp, M)
+         ctx.save_for_backward(inp, M)
+         return out
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         inp, M = ctx.saved_tensors
+         grad_output_fp32 = grad_output.float().contiguous()
+         d_inp, d_M = mhc_lib.stream_mix_backward(grad_output_fp32, inp, M)
+         return d_inp, d_M
+
+
+ # --- New: Stream Aggregate ---
+ class StreamAggregateFunction(torch.autograd.Function):
+     @staticmethod
+     def forward(ctx, inp, H_pre):
+         inp = inp.to(torch.bfloat16).contiguous()
+         H_pre = H_pre.float().contiguous()
+
+         # Determine the weighting mode
+         per_token = H_pre.dim() == 3
+
+         out = mhc_lib.stream_aggregate_fwd(inp, H_pre, per_token)
+         ctx.save_for_backward(inp, H_pre)
+         ctx.per_token = per_token
+         return out
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         inp, H_pre = ctx.saved_tensors
+         # Precision-critical: force the gradient to float32 before the kernel reduction
+         grad_output_fp32 = grad_output.float().contiguous()
+         d_inp, d_H_pre = mhc_lib.stream_aggregate_bwd(
+             grad_output_fp32, inp, H_pre, ctx.per_token
+         )
+         return d_inp, d_H_pre
+
+
+ def stream_aggregate(inp, H_pre):
+     inp = inp.to(torch.bfloat16)
+     H_pre = H_pre.to(torch.float32)
+     return StreamAggregateFunction.apply(inp, H_pre)
+
+
+ class StreamDistributeFunction(torch.autograd.Function):
+     @staticmethod
+     def forward(ctx, inp, H_post):
+         """
+         inp: [B, T, C] (usually bf16)
+         H_post: [B, T, n] (usually fp32)
+         returns: [B, T, n, C] (bf16)
+         """
+         # 1. Force contiguity to match the CUDA kernel
+         ctx.inp_dtype = inp.dtype
+         ctx.H_post_dtype = H_post.dtype
+         inp = inp.bfloat16().contiguous()
+         H_post = H_post.float().contiguous()
+
+         B, T, C = inp.shape
+         n = H_post.shape[-1]
+
+         out = mhc_lib.stream_distribute_fwd(inp, H_post)
+
+         ctx.save_for_backward(inp, H_post)
+         return out
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         """
+         grad_output: [B, T, n, C] (incoming gradient)
+         returns: d_inp, d_H_post
+         """
+         inp, H_post = ctx.saved_tensors
+         grad_output = grad_output.contiguous()
+
+         # Call the C++-bound backward kernel.
+         # Internally the kernel computes:
+         #   d_inp = sum_i(grad_output[..., i, :] * H_post[..., i])
+         #   d_H_post = sum_c(grad_output[..., :, c] * inp[..., c])
+         d_inp, d_H_post = mhc_lib.stream_distribute_bwd(grad_output, inp, H_post)
+
+         # Matches the forward argument order: inp, H_post
+         return d_inp.to(ctx.inp_dtype), d_H_post.to(ctx.H_post_dtype)
+
+
+ class MHCPostOpFunction(torch.autograd.Function):
+     @staticmethod
+     def forward(ctx, layer_out, x_expanded, H_post, H_res):
+         # Force contiguity
+         layer_out = layer_out.contiguous()
+         x_expanded = x_expanded.contiguous()
+         H_post = H_post.contiguous()
+         H_res = H_res.contiguous()
+
+         # Save tensors needed for the backward pass
+         ctx.save_for_backward(layer_out, x_expanded, H_post, H_res)
+
+         # Call the fused forward kernel
+         x_next = mhc_lib.mhc_post_op_fwd(layer_out, x_expanded, H_post, H_res)
+         return x_next
+
+     @staticmethod
+     def backward(ctx, grad_next):
+         # Retrieve the saved tensors
+         layer_out, x_expanded, H_post, H_res = ctx.saved_tensors
+         grad_next = grad_next.contiguous()
+
+         # Call the fully fused backward kernel.
+         # Returns: [d_layer_out, d_x_expanded, d_H_post, d_H_res]
+         grads = mhc_lib.mhc_post_op_bwd(grad_next, layer_out, x_expanded, H_post, H_res)
+
+         # Return 4 gradients, one per forward input
+         return grads[0], grads[1], grads[2], grads[3]
+
+
+ def mhc_post_op(layer_out, x_expanded, H_post, H_res):
+     """
+     Fused mHC post-processing operator.
+     layer_out: [B, T, C]
+     x_expanded: [B, T, n, C]
+     H_post: [B, T, n]
+     H_res: [B, T, n, n]
+     """
+     layer_out = layer_out.to(torch.bfloat16)
+     x_expanded = x_expanded.to(torch.bfloat16)
+     H_post = H_post.to(torch.float32)
+     H_res = H_res.to(torch.float32)
+     return MHCPostOpFunction.apply(layer_out, x_expanded, H_post, H_res)
+
+
+ def stream_distribute(inp, H_post):
+     """
+     mHC distribute operator (1 -> n): distributes a single-stream signal into n parallel streams by weight.
+     """
+     inp = inp.to(torch.bfloat16)
+     H_post = H_post.to(torch.float32)
+     return StreamDistributeFunction.apply(inp, H_post)
+
+
+ # Helper interfaces
+ def sinkhorn_knopp(inp, num_iters=20, eps=1e-8):
+     inp = inp.to(torch.float32)
+     return SinkhornKnoppFunction.apply(inp, num_iters, eps)
+
+
+ def rmsnorm(inp, eps=1e-5):
+     inp = inp.to(torch.bfloat16)
+     return RMSNormFunction.apply(inp, eps)
+
+
+ def stream_mix(inp, M):
+     inp = inp.to(torch.bfloat16)
+     M = M.to(torch.float32)
+     return StreamMixFunction.apply(inp, M)
+
+
+ class MHCPreOpFunction(torch.autograd.Function):
+     @staticmethod
+     def forward(
+         ctx, x_expanded, h_pre_raw, h_post_raw, h_res_raw, num_iters=20, eps=1e-8
+     ):
+         # 1. Save the original dtypes
+         ctx.x_dtype = x_expanded.dtype
+         ctx.h_dtype = h_pre_raw.dtype  # usually fp32, but record it anyway
+
+         # 2. Force dtype checks and conversions (to match the C++ interface)
+         # x_expanded must be bfloat16 (maps to nv_bfloat16*)
+         x_expanded = x_expanded.to(dtype=torch.bfloat16).contiguous()
+         # parameter tensors must be float32 (map to float*)
+         h_pre_raw = h_pre_raw.to(dtype=torch.float32).contiguous()
+         h_post_raw = h_post_raw.to(dtype=torch.float32).contiguous()
+         h_res_raw = h_res_raw.to(dtype=torch.float32).contiguous()
+
+         # 3. Call the CUDA interface (returns: x_layer_in [bf16], H_pre [f32], H_post [f32], H_res [f32])
+         x_layer_in, H_pre, H_post, H_res = mhc_lib.mhc_pre_op_fwd(
+             x_expanded, h_pre_raw, h_post_raw, h_res_raw, num_iters, eps
+         )
+
+         # 4. Save the intermediates needed for backward
+         ctx.save_for_backward(x_expanded, H_pre, H_post, H_res, h_res_raw)
+         ctx.num_iters = num_iters
+         ctx.eps = eps
+
+         # 5. Cast the main output back to the original dtype (usually bf16)
+         return x_layer_in.to(dtype=ctx.x_dtype), H_post, H_res
+
+     @staticmethod
+     def backward(ctx, grad_layer_in, grad_H_post, grad_H_res):
+         x_expanded, H_pre, H_post, H_res, h_res_raw = ctx.saved_tensors
+
+         # 1. Force gradient dtypes to match the C++ backward interface
+         grad_layer_in = grad_layer_in.to(dtype=torch.bfloat16).contiguous()
+         grad_H_post = grad_H_post.to(dtype=torch.float32).contiguous()
+         grad_H_res = grad_H_res.to(dtype=torch.float32).contiguous()
+
+         # 2. Call the CUDA backward kernel.
+         # Returns grads: [d_x_expanded, d_h_pre_raw, d_h_post_raw, d_h_res_raw]
+         grads = mhc_lib.mhc_pre_op_bwd(
+             grad_layer_in,
+             grad_H_post,
+             grad_H_res,
+             x_expanded,
+             H_pre,
+             H_post,
+             H_res,
+             h_res_raw,
+             ctx.num_iters,
+             ctx.eps,
+         )
+
+         # 3. Restore dtypes: cast the computed gradients back to the inputs' original dtypes,
+         # so downstream optimizers (e.g. Adam) neither error out nor pay extra cast overhead on dtype mismatches
+         dx = grads[0].to(dtype=ctx.x_dtype)
+         d_h_pre = grads[1].to(dtype=ctx.h_dtype)
+         d_h_post = grads[2].to(dtype=ctx.h_dtype)
+         d_h_res = grads[3].reshape(h_res_raw.shape).to(dtype=ctx.h_dtype)
+
+         # Return gradients for the 4 tensor inputs; num_iters/eps get None
+         return dx, d_h_pre, d_h_post, d_h_res, None, None
+
+
+ def mhc_pre_op(x_expanded, h_pre_raw, h_post_raw, h_res_raw, num_iters=20, eps=1e-8):
+     """
+     Fused mHC pre-processing operator interface.
+     """
+     x_expanded = x_expanded.to(torch.bfloat16)
+     h_pre_raw = h_pre_raw.to(torch.float32)
+     h_post_raw = h_post_raw.to(torch.float32)
+     h_res_raw = h_res_raw.to(torch.float32)
+     # Pre-processing: h_res_raw may be [B, T, n, n] or [B, T, n*n]
+     if h_res_raw.dim() == 4:
+         h_res_raw_flat = h_res_raw.reshape(h_res_raw.shape[0], h_res_raw.shape[1], -1)
+     else:
+         h_res_raw_flat = h_res_raw
+
+     return MHCPreOpFunction.apply(
+         x_expanded, h_pre_raw, h_post_raw, h_res_raw_flat, num_iters, eps
+     )