rwkv_ops-0.6.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. rwkv_ops/__init__.py +45 -0
  2. rwkv_ops/mhc_kernel/__init__.py +50 -0
  3. rwkv_ops/mhc_kernel/common_kernel/include/mhc_types.h +66 -0
  4. rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_post_op.cuh +197 -0
  5. rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_pre_op.cuh +212 -0
  6. rwkv_ops/mhc_kernel/common_kernel/kernels/rmsnorm.cuh +152 -0
  7. rwkv_ops/mhc_kernel/common_kernel/kernels/sinkhorn_knopp.cuh +158 -0
  8. rwkv_ops/mhc_kernel/common_kernel/kernels/stream_aggregate.cuh +141 -0
  9. rwkv_ops/mhc_kernel/common_kernel/kernels/stream_distribute.cuh +111 -0
  10. rwkv_ops/mhc_kernel/common_kernel/kernels/stream_mix.cuh +164 -0
  11. rwkv_ops/mhc_kernel/common_kernel/kernels/type_conversions.cuh +52 -0
  12. rwkv_ops/mhc_kernel/jax_kernel/CMakeLists.txt +47 -0
  13. rwkv_ops/mhc_kernel/jax_kernel/mhu_ffi.cu +652 -0
  14. rwkv_ops/mhc_kernel/jax_kernel/mhu_jax.py +939 -0
  15. rwkv_ops/mhc_kernel/native_keras_op.py +193 -0
  16. rwkv_ops/mhc_kernel/torch_kernel/mhc_cuda.cu +207 -0
  17. rwkv_ops/mhc_kernel/torch_kernel/mhc_op.cpp +296 -0
  18. rwkv_ops/mhc_kernel/torch_kernel/mhc_torch.py +306 -0
  19. rwkv_ops/rwkv6_kernel/__init__.py +120 -0
  20. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/gpu_ops.cpp +44 -0
  21. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernel_helpers.h +64 -0
  22. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernels.h +56 -0
  23. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/pybind11_kernel_helpers.h +41 -0
  24. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/rwkv_kernels.cu +512 -0
  25. rwkv_ops/rwkv6_kernel/jax_kernel_hip/gpu_ops.cpp +44 -0
  26. rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernel_helpers.h +64 -0
  27. rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernels.h +56 -0
  28. rwkv_ops/rwkv6_kernel/jax_kernel_hip/pybind11_kernel_helpers.h +41 -0
  29. rwkv_ops/rwkv6_kernel/jax_kernel_hip/rwkv_kernels.hip +514 -0
  30. rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py +722 -0
  31. rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py +90 -0
  32. rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_cuda.cu +397 -0
  33. rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_op.cpp +93 -0
  34. rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py +305 -0
  35. rwkv_ops/rwkv7_kernel/__init__.py +113 -0
  36. rwkv_ops/rwkv7_kernel/get_jax_devices_info.py +220 -0
  37. rwkv_ops/rwkv7_kernel/get_torch_devices_info.py +250 -0
  38. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/CMakeLists.txt +42 -0
  39. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_ffi.cu +399 -0
  40. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_jax.py +311 -0
  41. rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/CMakeLists.txt +42 -0
  42. rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_ffi.cu +172 -0
  43. rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_jax.py +190 -0
  44. rwkv_ops/rwkv7_kernel/jax_kernel/__init__.py +9 -0
  45. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_bwd.py +95 -0
  46. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_fwd.py +60 -0
  47. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_bwd.py +78 -0
  48. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_fwd.py +80 -0
  49. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_bwd.py +150 -0
  50. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_fwd.py +45 -0
  51. rwkv_ops/rwkv7_kernel/jax_kernel/cumsum.py +34 -0
  52. rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_bwd.py +61 -0
  53. rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_fwd.py +86 -0
  54. rwkv_ops/rwkv7_kernel/jax_op.py +382 -0
  55. rwkv_ops/rwkv7_kernel/mlx_op.py +118 -0
  56. rwkv_ops/rwkv7_kernel/native_keras_op.py +108 -0
  57. rwkv_ops/rwkv7_kernel/tf_eager_kernel.py +155 -0
  58. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_cuda.cu +235 -0
  59. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_op.cpp +63 -0
  60. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_torch.py +233 -0
  61. rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_cuda.cu +101 -0
  62. rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_op.cpp +56 -0
  63. rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_torch.py +112 -0
  64. rwkv_ops/rwkv7_kernel/torch_kernel/__init__.py +13 -0
  65. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_bwd.py +96 -0
  66. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_fwd.py +64 -0
  67. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_bwd.py +74 -0
  68. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_fwd.py +75 -0
  69. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_bwd.py +148 -0
  70. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_fwd.py +44 -0
  71. rwkv_ops/rwkv7_kernel/torch_kernel/cumsum.py +31 -0
  72. rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_bwd.py +63 -0
  73. rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_fwd.py +79 -0
  74. rwkv_ops/rwkv7_kernel/torch_op.py +504 -0
  75. rwkv_ops/rwkv7_kernel/triton_kernel/__init__.py +34 -0
  76. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_bwd.py +328 -0
  77. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_fwd.py +186 -0
  78. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_bwd.py +157 -0
  79. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_fwd.py +160 -0
  80. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_bwd.py +382 -0
  81. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_fwd.py +137 -0
  82. rwkv_ops/rwkv7_kernel/triton_kernel/cumsum.py +86 -0
  83. rwkv_ops/rwkv7_kernel/triton_kernel/utils.py +20 -0
  84. rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_bwd.py +193 -0
  85. rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_fwd.py +326 -0
  86. rwkv_ops-0.6.1.dist-info/METADATA +495 -0
  87. rwkv_ops-0.6.1.dist-info/RECORD +89 -0
  88. rwkv_ops-0.6.1.dist-info/WHEEL +4 -0
  89. rwkv_ops-0.6.1.dist-info/licenses/LICENSE.txt +201 -0
rwkv_ops/mhc_kernel/jax_kernel/mhu_ffi.cu
@@ -0,0 +1,652 @@
+ #include <cuda_runtime.h>
+ #include <cuda_bf16.h>
+ #include <xla/ffi/api/ffi.h>
+ #include <vector>
+ #include <cstdint>
+
+ // Shared common-kernel headers
+ #include "../common_kernel/include/mhc_types.h"
+ #include "../common_kernel/kernels/sinkhorn_knopp.cuh"
+ #include "../common_kernel/kernels/rmsnorm.cuh"
+ #include "../common_kernel/kernels/stream_mix.cuh"
+ #include "../common_kernel/kernels/stream_aggregate.cuh"
+ #include "../common_kernel/kernels/stream_distribute.cuh"
+ #include "../common_kernel/kernels/mhc_post_op.cuh"
+ #include "../common_kernel/kernels/mhc_pre_op.cuh"
+
+ namespace ffi = xla::ffi;
+
+ /* -------------------- Sinkhorn-Knopp FFI -------------------- */
+
+ // Forward FFI handler
+ static ffi::Error SinkhornFwdHost(
+     cudaStream_t stream,
+     ffi::Buffer<ffi::F32> inp,        // input:  [B, T, N, N]
+     ffi::ResultBuffer<ffi::F32> out,  // output: [B, T, N, N]
+     std::int32_t num_iters,           // explicitly std::int32_t
+     float eps                         // float is already 32-bit
+ ) {
+     // Tensor dimensions
+     auto dims = inp.dimensions();
+     int64_t B = dims[0];
+     int64_t T = dims[1];
+     int64_t N = dims[2];
+
+     const float* inp_ptr = inp.typed_data();
+     float* out_ptr = out->typed_data();
+
+     // Run the Sinkhorn forward pass for every (batch, token) slice
+     for (int64_t b = 0; b < B * T; ++b) {
+         mhc::sinkhorn_knopp_forward(
+             out_ptr + b * N * N,
+             inp_ptr + b * N * N,
+             static_cast<int>(N),
+             static_cast<int>(N),
+             num_iters,  // already int32
+             eps,
+             stream
+         );
+     }
+
+     return ffi::Error::Success();
+ }
+
+ // Backward FFI handler
+ static ffi::Error SinkhornBwdHost(
+     cudaStream_t stream,
+     ffi::Buffer<ffi::F32> grad,         // gradient:       [B, T, N, N]
+     ffi::Buffer<ffi::F32> out_fwd,      // forward output: [B, T, N, N]
+     ffi::Buffer<ffi::F32> inp,          // original input: [B, T, N, N]
+     ffi::ResultBuffer<ffi::F32> d_inp,  // input gradient: [B, T, N, N]
+     std::int32_t num_iters,             // explicitly std::int32_t
+     float eps
+ ) {
+     // Tensor dimensions
+     auto dims = grad.dimensions();
+     int64_t B = dims[0];
+     int64_t T = dims[1];
+     int64_t N = dims[2];
+
+     const float* grad_ptr = grad.typed_data();
+     const float* out_fwd_ptr = out_fwd.typed_data();
+     const float* inp_ptr = inp.typed_data();
+     float* d_inp_ptr = d_inp->typed_data();
+
+     // Run the Sinkhorn backward pass for every (batch, token) slice
+     for (int64_t b = 0; b < B * T; ++b) {
+         mhc::sinkhorn_knopp_backward(
+             d_inp_ptr + b * N * N,
+             grad_ptr + b * N * N,
+             out_fwd_ptr + b * N * N,
+             inp_ptr + b * N * N,
+             static_cast<int>(N),
+             num_iters,  // already int32
+             eps,
+             stream
+         );
+     }
+
+     return ffi::Error::Success();
+ }
+
+ /* -------------------- FFI symbol registration -------------------- */
+
+ // Forward symbol registration
+ XLA_FFI_DEFINE_HANDLER_SYMBOL(
+     SinkhornFwd, SinkhornFwdHost,
+     ffi::Ffi::Bind()
+         .Ctx<ffi::PlatformStream<cudaStream_t>>()
+         .Arg<ffi::Buffer<ffi::F32>>()     // inp
+         .Ret<ffi::Buffer<ffi::F32>>()     // out
+         .Attr<std::int32_t>("num_iters")  // explicitly a 32-bit integer
+         .Attr<float>("eps")               // float is 32-bit
+ );
+
+ // Backward symbol registration
+ XLA_FFI_DEFINE_HANDLER_SYMBOL(
+     SinkhornBwd, SinkhornBwdHost,
+     ffi::Ffi::Bind()
+         .Ctx<ffi::PlatformStream<cudaStream_t>>()
+         .Arg<ffi::Buffer<ffi::F32>>()     // grad
+         .Arg<ffi::Buffer<ffi::F32>>()     // out_fwd
+         .Arg<ffi::Buffer<ffi::F32>>()     // inp
+         .Ret<ffi::Buffer<ffi::F32>>()     // d_inp
+         .Attr<std::int32_t>("num_iters")  // explicitly a 32-bit integer
+         .Attr<float>("eps")               // float is 32-bit
+ );
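
For reference, this is roughly how the two Sinkhorn symbols are consumed from Python. A minimal sketch, assuming the compiled extension is named libmhu_ffi.so and the targets keep these names (the real loading and registration logic lives in mhu_jax.py); note the attribute dtypes must match the Attr declarations above, hence the explicit np.int32/np.float32:

    import ctypes
    import numpy as np
    import jax

    # Hypothetical artifact path; mhu_jax.py resolves the real one.
    lib = ctypes.cdll.LoadLibrary("./libmhu_ffi.so")
    jax.ffi.register_ffi_target(
        "SinkhornFwd", jax.ffi.pycapsule(lib.SinkhornFwd), platform="CUDA"
    )

    def sinkhorn(x, num_iters=20, eps=1e-6):
        # x: [B, T, N, N] float32; the output matches its shape and dtype.
        return jax.ffi.ffi_call(
            "SinkhornFwd", jax.ShapeDtypeStruct(x.shape, x.dtype)
        )(x, num_iters=np.int32(num_iters), eps=np.float32(eps))
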
+
+ /* -------------------- RMSNorm FFI -------------------- */
+
+ // Forward FFI handler
+ static ffi::Error RMSNormFwdHost(
+     cudaStream_t stream,
+     ffi::Buffer<ffi::BF16> inp,        // input:  [N, C]
+     ffi::ResultBuffer<ffi::BF16> out,  // output: [N, C]
+     float eps
+ ) {
+     auto dims = inp.dimensions();
+     int64_t N = dims[0];
+     int64_t C = dims[1];
+
+     const nv_bfloat16* inp_ptr = reinterpret_cast<const nv_bfloat16*>(inp.typed_data());
+     nv_bfloat16* out_ptr = reinterpret_cast<nv_bfloat16*>(out->typed_data());
+
+     // Call the wrapper function
+     mhc::rmsnorm_forward(out_ptr, inp_ptr, N, C, eps, stream);
+
+     return ffi::Error::Success();
+ }
+
+ // Backward FFI handler
+ static ffi::Error RMSNormBwdHost(
+     cudaStream_t stream,
+     ffi::Buffer<ffi::BF16> grad,      // gradient:       [N, C]
+     ffi::Buffer<ffi::BF16> inp,       // original input: [N, C]
+     ffi::ResultBuffer<ffi::BF16> dx,  // input gradient: [N, C]
+     float eps
+ ) {
+     auto dims = grad.dimensions();
+     int64_t N = dims[0];
+     int64_t C = dims[1];
+
+     const nv_bfloat16* grad_ptr = reinterpret_cast<const nv_bfloat16*>(grad.typed_data());
+     const nv_bfloat16* inp_ptr = reinterpret_cast<const nv_bfloat16*>(inp.typed_data());
+     nv_bfloat16* dx_ptr = reinterpret_cast<nv_bfloat16*>(dx->typed_data());
+
+     // Call the wrapper function
+     mhc::rmsnorm_backward(dx_ptr, grad_ptr, inp_ptr, N, C, eps, stream);
+
+     return ffi::Error::Success();
+ }
+ /* -------------------- FFI symbol registration -------------------- */
+
+ // Registrations appended after the handlers
+ XLA_FFI_DEFINE_HANDLER_SYMBOL(
+     RMSNormFwd, RMSNormFwdHost,
+     ffi::Ffi::Bind()
+         .Ctx<ffi::PlatformStream<cudaStream_t>>()
+         .Arg<ffi::Buffer<ffi::BF16>>()  // inp
+         .Ret<ffi::Buffer<ffi::BF16>>()  // out
+         .Attr<float>("eps")             // eps
+ );
+
+ XLA_FFI_DEFINE_HANDLER_SYMBOL(
+     RMSNormBwd, RMSNormBwdHost,
+     ffi::Ffi::Bind()
+         .Ctx<ffi::PlatformStream<cudaStream_t>>()
+         .Arg<ffi::Buffer<ffi::BF16>>()  // grad
+         .Arg<ffi::Buffer<ffi::BF16>>()  // inp
+         .Ret<ffi::Buffer<ffi::BF16>>()  // dx
+         .Attr<float>("eps")             // eps
+ );
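
The binding takes no gain tensor, so this is plain RMS normalization over the last axis. A pure-JAX reference the kernel can be tested against (a sketch; that eps sits inside the rsqrt is an assumption, the actual placement is defined in rmsnorm.cuh):

    import jax
    import jax.numpy as jnp

    def rmsnorm_ref(x, eps=1e-5):
        # x: [N, C] bfloat16; normalize over C in float32, cast back.
        xf = x.astype(jnp.float32)
        scale = jax.lax.rsqrt(jnp.mean(xf * xf, axis=-1, keepdims=True) + eps)
        return (xf * scale).astype(jnp.bfloat16)
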
+
+ /* -------------------- Stream Mix FFI -------------------- */
+
+ // Forward FFI handler
+ static ffi::Error StreamMixFwdHost(
+     cudaStream_t stream,
+     ffi::Buffer<ffi::BF16> inp,       // input:   [B, T, n, C]
+     ffi::Buffer<ffi::F32> M,          // weights: [B, T, n, n]
+     ffi::ResultBuffer<ffi::BF16> out  // output:  [B, T, n, C]
+ ) {
+     auto dims = inp.dimensions();
+     int64_t B = dims[0];
+     int64_t T = dims[1];
+     int64_t n = dims[2];
+     int64_t C = dims[3];
+
+     const nv_bfloat16* inp_ptr = reinterpret_cast<const nv_bfloat16*>(inp.typed_data());
+     const float* M_ptr = M.typed_data();
+     nv_bfloat16* out_ptr = reinterpret_cast<nv_bfloat16*>(out->typed_data());
+
+     // Call the wrapper function
+     mhc::stream_mix_forward(out_ptr, inp_ptr, M_ptr, B, T, static_cast<int>(n), C, stream);
+
+     return ffi::Error::Success();
+ }
+
+ // Backward FFI handler
+ // Note: the incoming gradient is F32 here, not BF16
+ static ffi::Error StreamMixBwdHost(
+     cudaStream_t stream,
+     ffi::Buffer<ffi::F32> grad,  // F32, unlike the BF16 forward output
+     ffi::Buffer<ffi::BF16> inp,
+     ffi::Buffer<ffi::F32> M,
+     ffi::ResultBuffer<ffi::BF16> d_inp,
+     ffi::ResultBuffer<ffi::F32> d_M
+ ) {
+     auto dims = grad.dimensions();  // dimensions taken from grad
+     int64_t B = dims[0];
+     int64_t T = dims[1];
+     int64_t n = dims[2];
+     int64_t C = dims[3];
+
+     const float* grad_ptr = grad.typed_data();  // plain float*
+     const nv_bfloat16* inp_ptr = reinterpret_cast<const nv_bfloat16*>(inp.typed_data());
+     const float* M_ptr = M.typed_data();
+     nv_bfloat16* d_inp_ptr = reinterpret_cast<nv_bfloat16*>(d_inp->typed_data());
+     float* d_M_ptr = d_M->typed_data();
+
+     mhc::stream_mix_backward(d_inp_ptr, d_M_ptr, grad_ptr, inp_ptr, M_ptr,
+                              B, T, static_cast<int>(n), C, stream);
+
+     return ffi::Error::Success();
+ }
+
+ /* -------------------- FFI symbol registration -------------------- */
+
+ XLA_FFI_DEFINE_HANDLER_SYMBOL(
+     StreamMixFwd, StreamMixFwdHost,
+     ffi::Ffi::Bind()
+         .Ctx<ffi::PlatformStream<cudaStream_t>>()
+         .Arg<ffi::Buffer<ffi::BF16>>()  // inp
+         .Arg<ffi::Buffer<ffi::F32>>()   // M
+         .Ret<ffi::Buffer<ffi::BF16>>()  // out
+ );
+
+ XLA_FFI_DEFINE_HANDLER_SYMBOL(
+     StreamMixBwd, StreamMixBwdHost,
+     ffi::Ffi::Bind()
+         .Ctx<ffi::PlatformStream<cudaStream_t>>()
+         .Arg<ffi::Buffer<ffi::F32>>()   // grad: F32
+         .Arg<ffi::Buffer<ffi::BF16>>()  // inp: BF16
+         .Arg<ffi::Buffer<ffi::F32>>()   // M: F32
+         .Ret<ffi::Buffer<ffi::BF16>>()  // d_inp: BF16
+         .Ret<ffi::Buffer<ffi::F32>>()   // d_M: F32
+ );
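
Shape-wise, StreamMixFwd contracts a per-token n-by-n matrix M against the n streams. A reference einsum to test against (a sketch; whether M acts on the row or the column index is not visible in this file and is assumed here):

    import jax.numpy as jnp

    def stream_mix_ref(x, M):
        # x: [B, T, n, C] bfloat16; M: [B, T, n, n] float32.
        y = jnp.einsum("btij,btjc->btic", M, x.astype(jnp.float32))
        return y.astype(jnp.bfloat16)
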
+ /* -------------------- Stream Aggregate FFI -------------------- */
+
+ // Forward FFI handler
+ static ffi::Error StreamAggregateFwdHost(
+     cudaStream_t stream,
+     ffi::Buffer<ffi::BF16> inp,        // input:   [B, T, n, C]
+     ffi::Buffer<ffi::F32> H_pre,       // weights: [B, T, n] or [n]
+     ffi::ResultBuffer<ffi::BF16> out,  // output:  [B, T, C]
+     bool per_token                     // whether the weights are per-token
+ ) {
+     auto dims = inp.dimensions();
+     int64_t B = dims[0];
+     int64_t T = dims[1];
+     int64_t n = dims[2];
+     int64_t C = dims[3];
+
+     const nv_bfloat16* inp_ptr = reinterpret_cast<const nv_bfloat16*>(inp.typed_data());
+     const float* H_pre_ptr = H_pre.typed_data();
+     nv_bfloat16* out_ptr = reinterpret_cast<nv_bfloat16*>(out->typed_data());
+
+     // Call the wrapper function (the per_token logic is handled internally)
+     mhc::stream_aggregate_forward(
+         out_ptr, inp_ptr, H_pre_ptr,
+         B * T, static_cast<int>(n), C, per_token, stream
+     );
+
+     return ffi::Error::Success();
+ }
+
+ // Backward FFI handler
+ static ffi::Error StreamAggregateBwdHost(
+     cudaStream_t stream,
+     ffi::Buffer<ffi::F32> grad,           // gradient:        [B, T, C] (float32)
+     ffi::Buffer<ffi::BF16> inp,           // original input:  [B, T, n, C]
+     ffi::Buffer<ffi::F32> H_pre,          // weights:         [B, T, n] or [n]
+     ffi::ResultBuffer<ffi::BF16> d_inp,   // input gradient:  [B, T, n, C]
+     ffi::ResultBuffer<ffi::F32> d_H_pre,  // weight gradient: [B, T, n] or [n]
+     bool per_token                        // whether the weights are per-token
+ ) {
+     auto dims = inp.dimensions();
+     int64_t B = dims[0];
+     int64_t T = dims[1];
+     int64_t n = dims[2];
+     int64_t C = dims[3];
+
+     const float* grad_ptr = grad.typed_data();
+     const nv_bfloat16* inp_ptr = reinterpret_cast<const nv_bfloat16*>(inp.typed_data());
+     const float* H_pre_ptr = H_pre.typed_data();
+     nv_bfloat16* d_inp_ptr = reinterpret_cast<nv_bfloat16*>(d_inp->typed_data());
+     float* d_H_pre_ptr = d_H_pre->typed_data();
+
+     // Call the wrapper function (handles per_token logic and gradient accumulation internally)
+     mhc::stream_aggregate_backward(
+         d_inp_ptr, d_H_pre_ptr, grad_ptr, inp_ptr, H_pre_ptr,
+         B * T, static_cast<int>(n), C, per_token, stream
+     );
+
+     return ffi::Error::Success();
+ }
+
+ /* -------------------- FFI symbol registration -------------------- */
+
+ XLA_FFI_DEFINE_HANDLER_SYMBOL(
+     StreamAggregateFwd, StreamAggregateFwdHost,
+     ffi::Ffi::Bind()
+         .Ctx<ffi::PlatformStream<cudaStream_t>>()
+         .Arg<ffi::Buffer<ffi::BF16>>()  // inp
+         .Arg<ffi::Buffer<ffi::F32>>()   // H_pre
+         .Ret<ffi::Buffer<ffi::BF16>>()  // out
+         .Attr<bool>("per_token")        // weight mode
+ );
+
+ XLA_FFI_DEFINE_HANDLER_SYMBOL(
+     StreamAggregateBwd, StreamAggregateBwdHost,
+     ffi::Ffi::Bind()
+         .Ctx<ffi::PlatformStream<cudaStream_t>>()
+         .Arg<ffi::Buffer<ffi::F32>>()   // grad
+         .Arg<ffi::Buffer<ffi::BF16>>()  // inp
+         .Arg<ffi::Buffer<ffi::F32>>()   // H_pre
+         .Ret<ffi::Buffer<ffi::BF16>>()  // d_inp
+         .Ret<ffi::Buffer<ffi::F32>>()   // d_H_pre
+         .Attr<bool>("per_token")        // weight mode
+ );
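
StreamAggregateFwd collapses the n streams into one token vector weighted by H_pre, with per_token selecting between a [B, T, n] weight and a shared [n] weight. A reference sketch of that contract, inferred from the shapes and comments above:

    import jax.numpy as jnp

    def stream_aggregate_ref(x, h_pre, per_token):
        # x: [B, T, n, C]; h_pre: [B, T, n] if per_token else [n].
        w = h_pre if per_token else jnp.broadcast_to(h_pre, x.shape[:3])
        y = jnp.einsum("btn,btnc->btc", w, x.astype(jnp.float32))
        return y.astype(jnp.bfloat16)
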
+
+ /* -------------------- Stream Distribute FFI -------------------- */
+
+ // Forward: [B, T, C] (BF16), [B, T, n] (F32) -> [B, T, n, C] (BF16)
+ static ffi::Error StreamDistributeFwdHost(
+     cudaStream_t stream,
+     ffi::Buffer<ffi::BF16> inp,       // [B, T, C]
+     ffi::Buffer<ffi::F32> H_post,     // [B, T, n]
+     ffi::ResultBuffer<ffi::BF16> out  // [B, T, n, C]
+ ) {
+     auto dims_inp = inp.dimensions();
+     auto dims_h = H_post.dimensions();
+
+     int64_t B = dims_inp[0];
+     int64_t T = dims_inp[1];
+     int64_t C = dims_inp[2];
+     int64_t n = dims_h[2];
+
+     // blockIdx.x covers B*T*C, blockIdx.y covers n
+     dim3 threads(256);
+     dim3 blocks((B * T * C + 255) / 256, n);
+
+     mhc::stream_distribute_fwd_kernel<<<blocks, threads, 0, stream>>>(
+         reinterpret_cast<mhc::floatX*>(out->typed_data()),
+         reinterpret_cast<const mhc::floatX*>(inp.typed_data()),
+         H_post.typed_data(),
+         B, T, static_cast<int>(n), C
+     );
+
+     return ffi::Error::Success();
+ }
+
+ // Backward
+ static ffi::Error StreamDistributeBwdHost(
+     cudaStream_t stream,
+     ffi::Buffer<ffi::BF16> grad,          // [B, T, n, C]
+     ffi::Buffer<ffi::BF16> inp,           // [B, T, C]
+     ffi::Buffer<ffi::F32> H_post,         // [B, T, n]
+     ffi::ResultBuffer<ffi::BF16> d_inp,   // [B, T, C]
+     ffi::ResultBuffer<ffi::F32> d_H_post  // [B, T, n]
+ ) {
+     auto dims = grad.dimensions();
+     int64_t B = dims[0];
+     int64_t T = dims[1];
+     int64_t n = dims[2];
+     int64_t C = dims[3];
+
+     // 1. Compute dx: [B, T, C]
+     dim3 threads(256);
+     dim3 blocks_dx((B * T * C + 255) / 256);
+     mhc::stream_distribute_bwd_dx_kernel<<<blocks_dx, threads, 0, stream>>>(
+         reinterpret_cast<mhc::floatX*>(d_inp->typed_data()),
+         reinterpret_cast<const mhc::floatX*>(grad.typed_data()),
+         H_post.typed_data(),
+         B, T, static_cast<int>(n), C
+     );
+
+     // 2. Compute dH: [B, T, n]
+     dim3 blocks_dh(B * T, n);
+     mhc::stream_distribute_bwd_dh_kernel<256><<<blocks_dh, threads, 0, stream>>>(
+         d_H_post->typed_data(),
+         reinterpret_cast<const mhc::floatX*>(grad.typed_data()),
+         reinterpret_cast<const mhc::floatX*>(inp.typed_data()),
+         B, T, static_cast<int>(n), C
+     );
+
+     return ffi::Error::Success();
+ }
+
+ /* -------------------- FFI symbol registration -------------------- */
+
+ XLA_FFI_DEFINE_HANDLER_SYMBOL(
+     StreamDistributeFwd, StreamDistributeFwdHost,
+     ffi::Ffi::Bind()
+         .Ctx<ffi::PlatformStream<cudaStream_t>>()
+         .Arg<ffi::Buffer<ffi::BF16>>()  // inp
+         .Arg<ffi::Buffer<ffi::F32>>()   // H_post
+         .Ret<ffi::Buffer<ffi::BF16>>()  // out
+ );
+
+ XLA_FFI_DEFINE_HANDLER_SYMBOL(
+     StreamDistributeBwd, StreamDistributeBwdHost,
+     ffi::Ffi::Bind()
+         .Ctx<ffi::PlatformStream<cudaStream_t>>()
+         .Arg<ffi::Buffer<ffi::BF16>>()  // grad
+         .Arg<ffi::Buffer<ffi::BF16>>()  // inp
+         .Arg<ffi::Buffer<ffi::F32>>()   // H_post
+         .Ret<ffi::Buffer<ffi::BF16>>()  // d_inp
+         .Ret<ffi::Buffer<ffi::F32>>()   // d_H_post
+ );
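
StreamDistributeFwd is the matching fan-out: every stream receives the token vector scaled by its H_post weight, which is why the backward dx kernel only has to reduce the gradient over n. A reference sketch under that reading:

    import jax.numpy as jnp

    def stream_distribute_ref(x, h_post):
        # x: [B, T, C]; h_post: [B, T, n] -> out: [B, T, n, C].
        out = x.astype(jnp.float32)[:, :, None, :] * h_post[:, :, :, None]
        return out.astype(jnp.bfloat16)
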
+ /* -------------------- MHC Post-Op FFI -------------------- */
+
+ // Forward handler
+ static ffi::Error MhcPostOpFwdHost(
+     cudaStream_t stream,
+     ffi::Buffer<ffi::BF16> layer_out,   // [B, T, C]
+     ffi::Buffer<ffi::BF16> x_expanded,  // [B, T, n, C]
+     ffi::Buffer<ffi::F32> H_post,       // [B, T, n]
+     ffi::Buffer<ffi::F32> H_res,        // [B, T, n, n]
+     ffi::ResultBuffer<ffi::BF16> out    // [B, T, n, C]
+ ) {
+     auto dims = x_expanded.dimensions();
+     int64_t B = dims[0], T = dims[1], n = dims[2], C = dims[3];
+
+     mhc::mhc_post_op_forward(
+         reinterpret_cast<mhc::floatX*>(out->typed_data()),
+         reinterpret_cast<const mhc::floatX*>(layer_out.typed_data()),
+         reinterpret_cast<const mhc::floatX*>(x_expanded.typed_data()),
+         H_post.typed_data(),
+         H_res.typed_data(),
+         B, T, static_cast<int>(n), C, stream
+     );
+     return ffi::Error::Success();
+ }
+
+ // Backward handler
+ static ffi::Error MhcPostOpBwdHost(
+     cudaStream_t stream,
+     ffi::Buffer<ffi::BF16> grad,  // [B, T, n, C]
+     ffi::Buffer<ffi::BF16> layer_out,
+     ffi::Buffer<ffi::BF16> x_expanded,
+     ffi::Buffer<ffi::F32> H_post,
+     ffi::Buffer<ffi::F32> H_res,
+     ffi::ResultBuffer<ffi::BF16> d_layer_out,
+     ffi::ResultBuffer<ffi::BF16> d_x_expanded,
+     ffi::ResultBuffer<ffi::F32> d_H_post,  // <-- must be zeroed
+     ffi::ResultBuffer<ffi::F32> d_H_res    // <-- must be zeroed
+ ) {
+     auto dims = x_expanded.dimensions();
+     int64_t B = dims[0], T = dims[1], n = dims[2], C = dims[3];
+
+     // -----------------------------------------------------------------
+     // Key fix: explicitly zero the accumulation buffers.
+     // The kernel accumulates with atomicAdd, and the device memory
+     // XLA hands us may contain garbage.
+     // -----------------------------------------------------------------
+     size_t size_h_post = B * T * n * sizeof(float);
+     size_t size_h_res = B * T * n * n * sizeof(float);
+
+     cudaMemsetAsync(d_H_post->typed_data(), 0, size_h_post, stream);
+     cudaMemsetAsync(d_H_res->typed_data(), 0, size_h_res, stream);
+
+     // Launch the kernel
+     mhc::mhc_post_op_backward_full(
+         reinterpret_cast<mhc::floatX*>(d_layer_out->typed_data()),
+         reinterpret_cast<mhc::floatX*>(d_x_expanded->typed_data()),
+         d_H_post->typed_data(),
+         d_H_res->typed_data(),
+         reinterpret_cast<const mhc::floatX*>(grad.typed_data()),
+         reinterpret_cast<const mhc::floatX*>(layer_out.typed_data()),
+         reinterpret_cast<const mhc::floatX*>(x_expanded.typed_data()),
+         H_post.typed_data(),
+         H_res.typed_data(),
+         B, T, static_cast<int>(n), C, stream
+     );
+
+     return ffi::Error::Success();
+ }
+
+ /* -------------------- FFI symbol registration -------------------- */
+
+ XLA_FFI_DEFINE_HANDLER_SYMBOL(
+     MhcPostOpFwd, MhcPostOpFwdHost,
+     ffi::Ffi::Bind()
+         .Ctx<ffi::PlatformStream<cudaStream_t>>()
+         .Arg<ffi::Buffer<ffi::BF16>>()  // layer_out
+         .Arg<ffi::Buffer<ffi::BF16>>()  // x_expanded
+         .Arg<ffi::Buffer<ffi::F32>>()   // H_post
+         .Arg<ffi::Buffer<ffi::F32>>()   // H_res
+         .Ret<ffi::Buffer<ffi::BF16>>()  // out
+ );
+
+ XLA_FFI_DEFINE_HANDLER_SYMBOL(
+     MhcPostOpBwd, MhcPostOpBwdHost,
+     ffi::Ffi::Bind()
+         .Ctx<ffi::PlatformStream<cudaStream_t>>()
+         .Arg<ffi::Buffer<ffi::BF16>>()  // grad
+         .Arg<ffi::Buffer<ffi::BF16>>()  // layer_out
+         .Arg<ffi::Buffer<ffi::BF16>>()  // x_expanded
+         .Arg<ffi::Buffer<ffi::F32>>()   // H_post
+         .Arg<ffi::Buffer<ffi::F32>>()   // H_res
+         .Ret<ffi::Buffer<ffi::BF16>>()  // d_layer_out
+         .Ret<ffi::Buffer<ffi::BF16>>()  // d_x_expanded
+         .Ret<ffi::Buffer<ffi::F32>>()   // d_H_post
+         .Ret<ffi::Buffer<ffi::F32>>()   // d_H_res
+ );
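
Judging from the buffer shapes, the post-op plausibly distributes layer_out across the n streams via H_post and adds an H_res-mixed residual of x_expanded. The exact formula lives in mhc_post_op.cuh, so treat this reference as an assumption to verify against the kernel:

    import jax.numpy as jnp

    def mhc_post_op_ref(layer_out, x_expanded, h_post, h_res):
        # layer_out: [B, T, C]; x_expanded: [B, T, n, C];
        # h_post: [B, T, n]; h_res: [B, T, n, n] -> out: [B, T, n, C].
        lo = layer_out.astype(jnp.float32)[:, :, None, :] * h_post[:, :, :, None]
        res = jnp.einsum("btij,btjc->btic", h_res, x_expanded.astype(jnp.float32))
        return (lo + res).astype(jnp.bfloat16)
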
+
+ /* -------------------- MHC Pre-Op FFI -------------------- */
+
+ // Forward handler: fused aggregate + sigmoid + Sinkhorn projection
+ static ffi::Error MhcPreOpFwdHost(
+     cudaStream_t stream,
+     ffi::Buffer<ffi::BF16> x_expanded,        // [B, T, n, C]
+     ffi::Buffer<ffi::F32> h_pre_raw,          // [B, T, n]
+     ffi::Buffer<ffi::F32> h_post_raw,         // [B, T, n]
+     ffi::Buffer<ffi::F32> h_res_raw,          // [B, T, n, n]
+     ffi::ResultBuffer<ffi::BF16> x_layer_in,  // [B, T, C]
+     ffi::ResultBuffer<ffi::F32> H_pre,        // [B, T, n]    (after sigmoid)
+     ffi::ResultBuffer<ffi::F32> H_post,       // [B, T, n]    (after 2*sigmoid)
+     ffi::ResultBuffer<ffi::F32> H_res,        // [B, T, n, n] (after Sinkhorn)
+     std::int32_t sinkhorn_iters,
+     float eps
+ ) {
+     auto dims = x_expanded.dimensions();
+     int64_t B = dims[0];
+     int64_t T = dims[1];
+     int n = static_cast<int>(dims[2]);
+     int64_t C = dims[3];
+
+     // Call the fused forward entry point from the .cuh
+     mhc::mhc_pre_op_forward(
+         reinterpret_cast<mhc::floatX*>(x_layer_in->typed_data()),
+         H_pre->typed_data(),
+         H_post->typed_data(),
+         H_res->typed_data(),
+         reinterpret_cast<const mhc::floatX*>(x_expanded.typed_data()),
+         h_pre_raw.typed_data(),
+         h_post_raw.typed_data(),
+         h_res_raw.typed_data(),
+         B, T, n, C, sinkhorn_iters, eps, stream
+     );
+
+     return ffi::Error::Success();
+ }
+
+ // Backward handler: full gradient propagation (including Sinkhorn backward)
+ static ffi::Error MhcPreOpBwdHost(
+     cudaStream_t stream,
+     ffi::Buffer<ffi::BF16> grad_layer_in,       // [B, T, C]
+     ffi::Buffer<ffi::F32> grad_H_post,          // [B, T, n]
+     ffi::Buffer<ffi::F32> grad_H_res,           // [B, T, n, n]
+     ffi::Buffer<ffi::BF16> x_expanded,          // [B, T, n, C] (forward input)
+     ffi::Buffer<ffi::F32> H_pre,                // [B, T, n]    (forward output)
+     ffi::Buffer<ffi::F32> H_post,               // [B, T, n]    (forward output)
+     ffi::Buffer<ffi::F32> H_res_out,            // [B, T, n, n] (after Sinkhorn)
+     ffi::Buffer<ffi::F32> h_res_raw,            // [B, T, n, n] (raw input)
+     ffi::ResultBuffer<ffi::BF16> d_x_expanded,  // [B, T, n, C]
+     ffi::ResultBuffer<ffi::F32> d_h_pre_raw,    // [B, T, n]
+     ffi::ResultBuffer<ffi::F32> d_h_post_raw,   // [B, T, n]
+     ffi::ResultBuffer<ffi::F32> d_h_res_raw,    // [B, T, n, n]
+     std::int32_t sinkhorn_iters,
+     float eps
+ ) {
+     auto dims = x_expanded.dimensions();
+     int64_t B = dims[0];
+     int64_t T = dims[1];
+     int n = static_cast<int>(dims[2]);
+     int64_t C = dims[3];
+
+     // -----------------------------------------------------------------
+     // Key fix: explicitly zero all output gradient buffers.
+     // The PyTorch version uses torch.zeros_like; on the FFI side we must
+     // memset by hand, both to match the framework behavior and to avoid
+     // numerical errors from uninitialized memory.
+     // -----------------------------------------------------------------
+     size_t size_h_pre = B * T * n * sizeof(float);
+     size_t size_h_post = B * T * n * sizeof(float);
+     size_t size_h_res = B * T * n * n * sizeof(float);
+     // d_x_expanded is written exclusively by each thread; no zeroing needed
+
+     cudaMemsetAsync(d_h_pre_raw->typed_data(), 0, size_h_pre, stream);
+     cudaMemsetAsync(d_h_post_raw->typed_data(), 0, size_h_post, stream);
+     cudaMemsetAsync(d_h_res_raw->typed_data(), 0, size_h_res, stream);
+
+     // Call the fused backward entry point from the .cuh
+     mhc::mhc_pre_op_backward(
+         reinterpret_cast<mhc::floatX*>(d_x_expanded->typed_data()),
+         d_h_pre_raw->typed_data(),
+         d_h_post_raw->typed_data(),
+         d_h_res_raw->typed_data(),
+         reinterpret_cast<const mhc::floatX*>(grad_layer_in.typed_data()),
+         grad_H_post.typed_data(),
+         grad_H_res.typed_data(),
+         reinterpret_cast<const mhc::floatX*>(x_expanded.typed_data()),
+         H_pre.typed_data(),
+         H_post.typed_data(),
+         H_res_out.typed_data(),
+         h_res_raw.typed_data(),
+         B, T, n, C, sinkhorn_iters, eps, stream
+     );
+
+     return ffi::Error::Success();
+ }
+
+ /* -------------------- FFI symbol registration -------------------- */
+
+ XLA_FFI_DEFINE_HANDLER_SYMBOL(
+     MhcPreOpFwd, MhcPreOpFwdHost,
+     ffi::Ffi::Bind()
+         .Ctx<ffi::PlatformStream<cudaStream_t>>()
+         .Arg<ffi::Buffer<ffi::BF16>>()  // x_expanded
+         .Arg<ffi::Buffer<ffi::F32>>()   // h_pre_raw
+         .Arg<ffi::Buffer<ffi::F32>>()   // h_post_raw
+         .Arg<ffi::Buffer<ffi::F32>>()   // h_res_raw
+         .Ret<ffi::Buffer<ffi::BF16>>()  // x_layer_in
+         .Ret<ffi::Buffer<ffi::F32>>()   // H_pre
+         .Ret<ffi::Buffer<ffi::F32>>()   // H_post
+         .Ret<ffi::Buffer<ffi::F32>>()   // H_res
+         .Attr<std::int32_t>("sinkhorn_iters")
+         .Attr<float>("eps")
+ );
+
+ XLA_FFI_DEFINE_HANDLER_SYMBOL(
+     MhcPreOpBwd, MhcPreOpBwdHost,
+     ffi::Ffi::Bind()
+         .Ctx<ffi::PlatformStream<cudaStream_t>>()
+         .Arg<ffi::Buffer<ffi::BF16>>()  // grad_layer_in
+         .Arg<ffi::Buffer<ffi::F32>>()   // grad_H_post
+         .Arg<ffi::Buffer<ffi::F32>>()   // grad_H_res
+         .Arg<ffi::Buffer<ffi::BF16>>()  // x_expanded
+         .Arg<ffi::Buffer<ffi::F32>>()   // H_pre
+         .Arg<ffi::Buffer<ffi::F32>>()   // H_post
+         .Arg<ffi::Buffer<ffi::F32>>()   // H_res_out
+         .Arg<ffi::Buffer<ffi::F32>>()   // h_res_raw
+         .Ret<ffi::Buffer<ffi::BF16>>()  // d_x_expanded
+         .Ret<ffi::Buffer<ffi::F32>>()   // d_h_pre_raw
+         .Ret<ffi::Buffer<ffi::F32>>()   // d_h_post_raw
+         .Ret<ffi::Buffer<ffi::F32>>()   // d_h_res_raw
+         .Attr<std::int32_t>("sinkhorn_iters")
+         .Attr<float>("eps")
+ );
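
To be differentiable under jax.grad, the pre-op pair has to be stitched together with jax.custom_vjp; the real wiring lives in mhu_jax.py, but its shape follows directly from the argument lists above (a sketch: target names and residual choices are assumptions, and note the backward handler takes no cotangent for H_pre):

    import functools
    import numpy as np
    import jax
    import jax.numpy as jnp

    @functools.partial(jax.custom_vjp, nondiff_argnums=(4, 5))
    def mhc_pre_op(x_expanded, h_pre_raw, h_post_raw, h_res_raw,
                   sinkhorn_iters, eps):
        out, _ = _fwd(x_expanded, h_pre_raw, h_post_raw, h_res_raw,
                      sinkhorn_iters, eps)
        return out

    def _fwd(x_expanded, h_pre_raw, h_post_raw, h_res_raw, sinkhorn_iters, eps):
        B, T, n, C = x_expanded.shape
        outs = jax.ffi.ffi_call(
            "MhcPreOpFwd",
            (
                jax.ShapeDtypeStruct((B, T, C), x_expanded.dtype),  # x_layer_in
                jax.ShapeDtypeStruct((B, T, n), jnp.float32),       # H_pre
                jax.ShapeDtypeStruct((B, T, n), jnp.float32),       # H_post
                jax.ShapeDtypeStruct((B, T, n, n), jnp.float32),    # H_res
            ),
        )(x_expanded, h_pre_raw, h_post_raw, h_res_raw,
          sinkhorn_iters=np.int32(sinkhorn_iters), eps=np.float32(eps))
        _x_layer_in, H_pre, H_post, H_res = outs
        return tuple(outs), (x_expanded, H_pre, H_post, H_res, h_res_raw)

    def _bwd(sinkhorn_iters, eps, res, cts):
        x_expanded, H_pre, H_post, H_res, h_res_raw = res
        g_layer_in, _g_H_pre, g_H_post, g_H_res = cts  # H_pre cotangent unused
        B, T, n, C = x_expanded.shape
        return tuple(jax.ffi.ffi_call(
            "MhcPreOpBwd",
            (
                jax.ShapeDtypeStruct((B, T, n, C), x_expanded.dtype),  # d_x_expanded
                jax.ShapeDtypeStruct((B, T, n), jnp.float32),          # d_h_pre_raw
                jax.ShapeDtypeStruct((B, T, n), jnp.float32),          # d_h_post_raw
                jax.ShapeDtypeStruct((B, T, n, n), jnp.float32),       # d_h_res_raw
            ),
        )(g_layer_in.astype(x_expanded.dtype), g_H_post, g_H_res,
          x_expanded, H_pre, H_post, H_res, h_res_raw,
          sinkhorn_iters=np.int32(sinkhorn_iters), eps=np.float32(eps)))

    mhc_pre_op.defvjp(_fwd, _bwd)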