rwkv_ops-0.6.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. rwkv_ops/__init__.py +45 -0
  2. rwkv_ops/mhc_kernel/__init__.py +50 -0
  3. rwkv_ops/mhc_kernel/common_kernel/include/mhc_types.h +66 -0
  4. rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_post_op.cuh +197 -0
  5. rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_pre_op.cuh +212 -0
  6. rwkv_ops/mhc_kernel/common_kernel/kernels/rmsnorm.cuh +152 -0
  7. rwkv_ops/mhc_kernel/common_kernel/kernels/sinkhorn_knopp.cuh +158 -0
  8. rwkv_ops/mhc_kernel/common_kernel/kernels/stream_aggregate.cuh +141 -0
  9. rwkv_ops/mhc_kernel/common_kernel/kernels/stream_distribute.cuh +111 -0
  10. rwkv_ops/mhc_kernel/common_kernel/kernels/stream_mix.cuh +164 -0
  11. rwkv_ops/mhc_kernel/common_kernel/kernels/type_conversions.cuh +52 -0
  12. rwkv_ops/mhc_kernel/jax_kernel/CMakeLists.txt +47 -0
  13. rwkv_ops/mhc_kernel/jax_kernel/mhu_ffi.cu +652 -0
  14. rwkv_ops/mhc_kernel/jax_kernel/mhu_jax.py +939 -0
  15. rwkv_ops/mhc_kernel/native_keras_op.py +193 -0
  16. rwkv_ops/mhc_kernel/torch_kernel/mhc_cuda.cu +207 -0
  17. rwkv_ops/mhc_kernel/torch_kernel/mhc_op.cpp +296 -0
  18. rwkv_ops/mhc_kernel/torch_kernel/mhc_torch.py +306 -0
  19. rwkv_ops/rwkv6_kernel/__init__.py +120 -0
  20. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/gpu_ops.cpp +44 -0
  21. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernel_helpers.h +64 -0
  22. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernels.h +56 -0
  23. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/pybind11_kernel_helpers.h +41 -0
  24. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/rwkv_kernels.cu +512 -0
  25. rwkv_ops/rwkv6_kernel/jax_kernel_hip/gpu_ops.cpp +44 -0
  26. rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernel_helpers.h +64 -0
  27. rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernels.h +56 -0
  28. rwkv_ops/rwkv6_kernel/jax_kernel_hip/pybind11_kernel_helpers.h +41 -0
  29. rwkv_ops/rwkv6_kernel/jax_kernel_hip/rwkv_kernels.hip +514 -0
  30. rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py +722 -0
  31. rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py +90 -0
  32. rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_cuda.cu +397 -0
  33. rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_op.cpp +93 -0
  34. rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py +305 -0
  35. rwkv_ops/rwkv7_kernel/__init__.py +113 -0
  36. rwkv_ops/rwkv7_kernel/get_jax_devices_info.py +220 -0
  37. rwkv_ops/rwkv7_kernel/get_torch_devices_info.py +250 -0
  38. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/CMakeLists.txt +42 -0
  39. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_ffi.cu +399 -0
  40. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_jax.py +311 -0
  41. rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/CMakeLists.txt +42 -0
  42. rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_ffi.cu +172 -0
  43. rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_jax.py +190 -0
  44. rwkv_ops/rwkv7_kernel/jax_kernel/__init__.py +9 -0
  45. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_bwd.py +95 -0
  46. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_fwd.py +60 -0
  47. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_bwd.py +78 -0
  48. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_fwd.py +80 -0
  49. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_bwd.py +150 -0
  50. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_fwd.py +45 -0
  51. rwkv_ops/rwkv7_kernel/jax_kernel/cumsum.py +34 -0
  52. rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_bwd.py +61 -0
  53. rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_fwd.py +86 -0
  54. rwkv_ops/rwkv7_kernel/jax_op.py +382 -0
  55. rwkv_ops/rwkv7_kernel/mlx_op.py +118 -0
  56. rwkv_ops/rwkv7_kernel/native_keras_op.py +108 -0
  57. rwkv_ops/rwkv7_kernel/tf_eager_kernel.py +155 -0
  58. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_cuda.cu +235 -0
  59. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_op.cpp +63 -0
  60. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_torch.py +233 -0
  61. rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_cuda.cu +101 -0
  62. rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_op.cpp +56 -0
  63. rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_torch.py +112 -0
  64. rwkv_ops/rwkv7_kernel/torch_kernel/__init__.py +13 -0
  65. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_bwd.py +96 -0
  66. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_fwd.py +64 -0
  67. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_bwd.py +74 -0
  68. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_fwd.py +75 -0
  69. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_bwd.py +148 -0
  70. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_fwd.py +44 -0
  71. rwkv_ops/rwkv7_kernel/torch_kernel/cumsum.py +31 -0
  72. rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_bwd.py +63 -0
  73. rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_fwd.py +79 -0
  74. rwkv_ops/rwkv7_kernel/torch_op.py +504 -0
  75. rwkv_ops/rwkv7_kernel/triton_kernel/__init__.py +34 -0
  76. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_bwd.py +328 -0
  77. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_fwd.py +186 -0
  78. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_bwd.py +157 -0
  79. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_fwd.py +160 -0
  80. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_bwd.py +382 -0
  81. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_fwd.py +137 -0
  82. rwkv_ops/rwkv7_kernel/triton_kernel/cumsum.py +86 -0
  83. rwkv_ops/rwkv7_kernel/triton_kernel/utils.py +20 -0
  84. rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_bwd.py +193 -0
  85. rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_fwd.py +326 -0
  86. rwkv_ops-0.6.1.dist-info/METADATA +495 -0
  87. rwkv_ops-0.6.1.dist-info/RECORD +89 -0
  88. rwkv_ops-0.6.1.dist-info/WHEEL +4 -0
  89. rwkv_ops-0.6.1.dist-info/licenses/LICENSE.txt +201 -0
rwkv_ops/__init__.py ADDED
@@ -0,0 +1,45 @@
+ __version__ = "0.6.0"
+ import os
+
+ KERNEL_TYPE = os.environ.get("KERNEL_TYPE", "cuda").lower()
+ KERAS_BACKEND = os.environ.get("KERAS_BACKEND")
+ BACKEND = os.environ.get("KERNEL_BACKEND")
+
+
+ if KERAS_BACKEND is not None:
+     BACKEND = KERAS_BACKEND.lower()
+ elif BACKEND is not None:
+     os.environ["KERAS_BACKEND"] = BACKEND.lower()
+ else:
+     import keras
+
+     BACKEND = "torch"
+     os.environ["KERAS_BACKEND"] = BACKEND
+     keras.config.set_backend("torch")
+ assert KERNEL_TYPE in ["triton", "cuda", "native"]
+ assert BACKEND in ["torch", "jax", "numpy", "tensorflow"]
+ from .rwkv7_kernel import get_generalized_delta_rule, get_rnn_generalized_delta_rule
+ from .rwkv6_kernel import get_rwkv6_kernel
+ from .mhc_kernel import get_mhu_kernel
+
+ (
+     mhc_sinkhorn_knopp,
+     mhc_rmsnorm,
+     mhc_stream_aggregate,
+     mhc_stream_distribute,
+     mhc_stream_mix,
+     mhc_post_op,
+     mhc_pre_op,
+ ) = get_mhu_kernel(KERNEL_TYPE)
+
+ generalized_delta_rule, generalized_delta_rule_inference, RWKV7_USE_TRITON_KERNEL = (
+     get_generalized_delta_rule(KERNEL_TYPE=KERNEL_TYPE)
+ )
+ rwkv7_op = generalized_delta_rule
+ rwkv7_op_inference = generalized_delta_rule_inference
+
+ rnn_generalized_delta_rule = get_rnn_generalized_delta_rule(KERNEL_TYPE=KERNEL_TYPE)
+ rwkv7_op_rnn = rnn_generalized_delta_rule
+
+
+ RWKV6_OP = get_rwkv6_kernel(KERNEL_TYPE=KERNEL_TYPE)
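The module picks its backend at import time from KERAS_BACKEND (or KERNEL_BACKEND), defaulting to torch, selects the kernel flavour from KERNEL_TYPE, and then re-exports rwkv7_op, rwkv7_op_inference, rwkv7_op_rnn, RWKV6_OP and the mhc_* helpers. A minimal usage sketch (not taken from the package docs; only the environment variables and re-exported names shown above are assumed), with the environment configured before the first import:

    import os
    os.environ["KERAS_BACKEND"] = "torch"   # or "jax" / "tensorflow" / "numpy"
    os.environ["KERNEL_TYPE"] = "cuda"      # or "triton" / "native"

    import rwkv_ops

    # Names re-exported by rwkv_ops/__init__.py
    rwkv7 = rwkv_ops.rwkv7_op        # alias of generalized_delta_rule
    rwkv6 = rwkv_ops.RWKV6_OP        # kernel returned by get_rwkv6_kernel
    rms = rwkv_ops.mhc_rmsnorm       # one of the mhc_* helpers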
rwkv_ops/mhc_kernel/__init__.py ADDED
@@ -0,0 +1,50 @@
+ import keras
+
+
+ def get_mhu_kernel(KERNEL_TYPE):
+     from .native_keras_op import (
+         sinkhorn_knopp,
+         rmsnorm,
+         stream_aggregate,
+         stream_distribute,
+         stream_mix,
+         mhc_post_op,
+         mhc_pre_op,
+     )
+
+     if KERNEL_TYPE == "cuda":
+         if keras.config.backend() == "torch":
+             import torch
+
+             if torch.cuda.is_available():
+                 from .torch_kernel.mhc_torch import (
+                     sinkhorn_knopp,
+                     rmsnorm,
+                     stream_aggregate,
+                     stream_distribute,
+                     stream_mix,
+                     mhc_post_op,
+                     mhc_pre_op,
+                 )
+         elif keras.config.backend() == "jax":
+             import jax
+
+             if jax.devices()[0].platform == "gpu":
+                 from .jax_kernel.mhu_jax import (
+                     sinkhorn_knopp,
+                     rmsnorm,
+                     stream_aggregate,
+                     stream_distribute,
+                     stream_mix,
+                     mhc_post_op,
+                     mhc_pre_op,
+                 )
+     return (
+         sinkhorn_knopp,
+         rmsnorm,
+         stream_aggregate,
+         stream_distribute,
+         stream_mix,
+         mhc_post_op,
+         mhc_pre_op,
+     )
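get_mhu_kernel always imports the native Keras implementations first and only swaps in the CUDA versions when the backend is torch or jax with a visible GPU, so callers get the same seven-tuple in every configuration. A minimal sketch of calling the selector directly (the "native" argument simply skips the CUDA imports):

    from rwkv_ops.mhc_kernel import get_mhu_kernel

    # Same unpacking order as rwkv_ops/__init__.py
    (sinkhorn_knopp, rmsnorm, stream_aggregate,
     stream_distribute, stream_mix, mhc_post_op, mhc_pre_op) = get_mhu_kernel("native")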
rwkv_ops/mhc_kernel/common_kernel/include/mhc_types.h ADDED
@@ -0,0 +1,66 @@
+ #pragma once
+
+ #include <cuda_runtime.h>
+ #include <cuda_bf16.h>
+ #include <cuda_fp16.h>
+ #include <cublasLt.h>
+ #include <assert.h>
+ #include <cstdint>
+
+ namespace mhc {
+
+ using floatX = __nv_bfloat16;
+ using floatN = float;
+
+ // Unified conversion helpers shared by all .cuh and .cu files
+ __device__ inline float to_float(const floatX& u) {
+     return __bfloat162float(u);
+ }
+
+ __device__ inline floatX to_bf(const float& u) {
+ #if __CUDA_ARCH__ >= 800
+     return __float2bfloat16(u);
+ #else
+     // Compatibility with older architectures / forced round-to-nearest
+     return __float2bfloat16_rn(u);
+ #endif
+ }
+
+
+ struct MHCConfig {
+     int sinkhorn_iters;
+     int nC;
+     float eps;
+     bool use_pdl;
+ };
+
+ struct RMSNormParams {
+     int n;
+     float eps;
+ };
+
+ inline void check_cuda(cudaError_t err, const char* file, int line) {
+     if (err != cudaSuccess) {
+         fprintf(stderr, "CUDA error at %s:%d: %s\n", file, line, cudaGetErrorString(err));
+         exit(EXIT_FAILURE);
+     }
+ }
+
+ inline void check_cublas(cublasStatus_t status, const char* file, int line) {
+     if (status != CUBLAS_STATUS_SUCCESS) {
+         fprintf(stderr, "cuBLAS error at %s:%d: %d\n", file, line, (int)status);
+         exit(EXIT_FAILURE);
+     }
+ }
+ // Error-checking macros
+ #define CHECK_CUDA(call) \
+     do { \
+         cudaError_t err = call; \
+         if (err != cudaSuccess) { \
+             printf("CUDA Error at %s:%d - %s\n", __FILE__, __LINE__, cudaGetErrorString(err)); \
+         } \
+     } while (0)
+
+ #define CHECK_CUDA(call) mhc::check_cuda((call), __FILE__, __LINE__)
+ #define CHECK_CUBLAS(call) mhc::check_cublas((call), __FILE__, __LINE__)
+ } // namespace mhc
rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_post_op.cuh ADDED
@@ -0,0 +1,197 @@
+ #ifndef MHC_POST_OP_CUH
+ #define MHC_POST_OP_CUH
+
+ #include <cuda_runtime.h>
+ #include <cuda_bf16.h>
+ #include <cooperative_groups.h>
+ #include <cooperative_groups/reduce.h>
+ #include "../include/mhc_types.h"
+ #include "type_conversions.cuh"
+
+ namespace cg = cooperative_groups;
+
+ namespace mhc {
+
+ /**
+  * 1. Fused Forward Kernel
+  * Formula: x_next[b,t,i,c] = sum_j(H_res[b,t,i,j] * x_expanded[b,t,j,c]) + layer_out[b,t,c] * H_post[b,t,i]
+  * Precision policy: all intermediate accumulation is done in FP32
+  */
+ template<int MAX_N = 8>
+ __global__ void mhc_post_op_fwd_kernel(
+     floatX* __restrict__ x_next,            // [B, T, n, C]
+     const floatX* __restrict__ layer_out,   // [B, T, C]
+     const floatX* __restrict__ x_expanded,  // [B, T, n, C]
+     const float* __restrict__ H_post,       // [B, T, n]
+     const float* __restrict__ H_res,        // [B, T, n, n]
+     int64_t B, int64_t T, int n, int64_t C)
+ {
+     int64_t btc = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;
+     int i = blockIdx.y; // output stream index
+
+     if (btc < B * T * C && i < n) {
+         int64_t bt = btc / C;
+         int64_t c = btc % C;
+         // [changed]: added (int64_t) cast on n
+         int64_t bt_offset_n_c = bt * (int64_t)n * C;
+
+         // --- Stream Mix part ---
+         float mixed_val = 0.0f;
+         // [changed]: added (int64_t) casts on n and i so i*n cannot overflow in 32-bit
+         int64_t res_base = bt * (int64_t)n * n + (int64_t)i * n;
+         #pragma unroll
+         for (int j = 0; j < MAX_N; j++) {
+             if (j < n) {
+                 float w_res = H_res[res_base + j];
+                 float val_x = to_float(x_expanded[bt_offset_n_c + (int64_t)j * C + c]);
+                 mixed_val += w_res * val_x;
+             }
+         }
+
+         // --- Stream Distribute part ---
+         float l_val = to_float(layer_out[btc]);
+         // [changed]: added (int64_t) cast on n
+         float w_post = H_post[bt * (int64_t)n + i];
+         float dist_val = l_val * w_post;
+
+         x_next[bt_offset_n_c + (int64_t)i * C + c] = to_bf(mixed_val + dist_val);
+     }
+ }
+
+ /**
+  * 2. Fused Backward Full Kernel
+  * Computes:
+  *   dl = sum_i(grad_next_i * H_post_i)
+  *   dx_j = sum_i(grad_next_i * H_res_ij)
+  *   dH_post_i = sum_c(grad_next_i * layer_out)
+  *   dH_res_ij = sum_c(grad_next_i * x_expanded_j)
+  */
+ template<int BLOCK_SIZE, int MAX_N = 8>
+ __global__ void mhc_post_op_bwd_full_kernel(
+     floatX* __restrict__ d_layer_out,        // [B, T, C]
+     floatX* __restrict__ d_x_expanded,       // [B, T, n, C]
+     float* __restrict__ d_H_post,            // [B, T, n]
+     float* __restrict__ d_H_res,             // [B, T, n, n]
+     const floatX* __restrict__ grad_next,    // [B, T, n, C]
+     const floatX* __restrict__ layer_out,    // [B, T, C]
+     const floatX* __restrict__ x_expanded,   // [B, T, n, C]
+     const float* __restrict__ H_post,
+     const float* __restrict__ H_res,
+     int64_t B, int64_t T, int n, int64_t C)
+ {
+     cg::thread_block block = cg::this_thread_block();
+     cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
+
+     int64_t bt = blockIdx.x;
+     if (bt >= B * T) return;
+
+     // [changed]: added (int64_t) cast on n
+     int64_t bt_offset_n_c = bt * (int64_t)n * C;
+     int64_t bt_offset_c = bt * C;
+
+     // Thread-local storage for the parameter-gradient reduction
+     // dw_post: [n], dw_res: [n, n]
+     float thread_dw_post[MAX_N] = {0.0f};
+     float thread_dw_res[MAX_N][MAX_N] = {0.0f};
+
+     // --- 1. Data gradients (dl, dx) ---
+     // Each thread handles one channel c
+     for (int64_t c = threadIdx.x; c < C; c += BLOCK_SIZE) {
+         float l_val = to_float(layer_out[bt_offset_c + c]);
+
+         // First read every stream's gradient for this channel
+         float g_vals[MAX_N];
+         #pragma unroll
+         for(int i=0; i<MAX_N; i++) {
+             if(i < n) g_vals[i] = to_float(grad_next[bt_offset_n_c + (int64_t)i * C + c]);
+         }
+
+         // dl (data gradient)
+         float dl_sum = 0.0f;
+         #pragma unroll
+         for(int i=0; i<MAX_N; i++) {
+             // [changed]: added (int64_t) cast on n
+             if(i < n) dl_sum += g_vals[i] * H_post[bt * (int64_t)n + i];
+         }
+         d_layer_out[bt_offset_c + c] = to_bf(dl_sum);
+
+         // dx (data gradient) plus the thread-local partial sums of the parameter gradients
+         #pragma unroll
+         for(int j=0; j<MAX_N; j++) {
+             if (j < n) {
+                 float dx_j = 0.0f;
+                 float xj_val = to_float(x_expanded[bt_offset_n_c + (int64_t)j * C + c]);
+
+                 #pragma unroll
+                 for(int i=0; i<MAX_N; i++) {
+                     if (i < n) {
+                         // [changed]: added (int64_t) casts on n and i
+                         dx_j += g_vals[i] * H_res[bt * (int64_t)n * n + (int64_t)i * n + j];
+                         // Also accumulate the thread-local part of dH_res
+                         thread_dw_res[i][j] += g_vals[i] * xj_val;
+                     }
+                 }
+                 d_x_expanded[bt_offset_n_c + (int64_t)j * C + c] = to_bf(dx_j);
+             }
+         }
+
+         // Thread-local part of dH_post
+         #pragma unroll
+         for(int i=0; i<MAX_N; i++) {
+             if(i < n) thread_dw_post[i] += g_vals[i] * l_val;
+         }
+     }
+
+     // --- 2. Parameter-gradient reduction (over the C dimension) ---
+     // Warp-shuffle reduce, then write back
+     #pragma unroll
+     for(int i=0; i<MAX_N; i++) {
+         if(i < n) {
+             float sum_p = cg::reduce(warp, thread_dw_post[i], cg::plus<float>());
+             // [changed]: added (int64_t) cast on n
+             if (warp.thread_rank() == 0) atomicAdd(&d_H_post[bt * (int64_t)n + i], sum_p);
+
+             #pragma unroll
+             for(int j=0; j<MAX_N; j++) {
+                 if(j < n) {
+                     float sum_r = cg::reduce(warp, thread_dw_res[i][j], cg::plus<float>());
+                     // [changed]: added (int64_t) casts on n and i
+                     if (warp.thread_rank() == 0) atomicAdd(&d_H_res[bt * (int64_t)n * n + (int64_t)i * n + j], sum_r);
+                 }
+             }
+         }
+     }
+ }
+
+ /* -------------------- API wrapper functions -------------------- */
+
+ inline void mhc_post_op_forward(
+     floatX* x_next, const floatX* layer_out, const floatX* x_expanded,
+     const float* H_post, const float* H_res,
+     int64_t B, int64_t T, int n, int64_t C, cudaStream_t stream)
+ {
+     int64_t total_btc = B * T * C;
+     dim3 threads(256);
+     dim3 blocks((total_btc + 255) / 256, n);
+     mhc_post_op_fwd_kernel<8><<<blocks, threads, 0, stream>>>(
+         x_next, layer_out, x_expanded, H_post, H_res, B, T, n, C);
+ }
+
+ inline void mhc_post_op_backward_full(
+     floatX* d_layer_out, floatX* d_x_expanded, float* d_H_post, float* d_H_res,
+     const floatX* grad_next, const floatX* layer_out, const floatX* x_expanded,
+     const float* H_post, const float* H_res,
+     int64_t B, int64_t T, int n, int64_t C, cudaStream_t stream)
+ {
+     // Each block reduces all channels of one token (bt)
+     const int BLOCK_SIZE = 256;
+     dim3 threads(BLOCK_SIZE);
+     dim3 blocks(B * T);
+     mhc_post_op_bwd_full_kernel<BLOCK_SIZE, 8><<<blocks, threads, 0, stream>>>(
+         d_layer_out, d_x_expanded, d_H_post, d_H_res,
+         grad_next, layer_out, x_expanded, H_post, H_res, B, T, n, C);
+ }
+
+ } // namespace mhc
+
+ #endif
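The forward kernel above fuses the stream-mix and stream-distribute steps spelled out in its header comment. A minimal NumPy sketch of the same formula (illustrative only, not shipped in the wheel), using the shapes from the kernel signature, x_expanded [B, T, n, C], layer_out [B, T, C], H_post [B, T, n], H_res [B, T, n, n]:

    import numpy as np

    # x_next[b,t,i,c] = sum_j H_res[b,t,i,j] * x_expanded[b,t,j,c] + layer_out[b,t,c] * H_post[b,t,i]
    def mhc_post_op_fwd_reference(layer_out, x_expanded, H_post, H_res):
        mixed = np.einsum("btij,btjc->btic", H_res, x_expanded)   # stream mix
        dist = layer_out[:, :, None, :] * H_post[..., None]       # stream distribute
        return mixed + dist                                       # [B, T, n, C]

The backward kernel produces the corresponding gradients (dl, dx, dH_post, dH_res) in a single pass, accumulating the parameter gradients via warp-level reductions and atomicAdd, so d_H_post and d_H_res are expected to be zero-initialized by the caller.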
rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_pre_op.cuh ADDED
@@ -0,0 +1,212 @@
+ #ifndef MHC_PRE_OP_CUH
+ #define MHC_PRE_OP_CUH
+
+ #include <cuda_runtime.h>
+ #include <cuda_bf16.h>
+ #include <cooperative_groups.h>
+ #include <cooperative_groups/reduce.h>
+ #include "../include/mhc_types.h"
+ #include "type_conversions.cuh"
+ #include "sinkhorn_knopp.cuh"
+
+ namespace cg = cooperative_groups;
+
+ namespace mhc {
+
+ /**
+  * 1. Fused Pre-Op Forward Kernel
+  * Fix: every index stride uses int64_t to avoid overflow with large models / long sequences
+  */
+ template<int MAX_N = 8>
+ __global__ void mhc_pre_op_fwd_kernel(
+     floatX* __restrict__ x_layer_in,
+     float* __restrict__ H_pre_out,
+     float* __restrict__ H_post_out,
+     const floatX* __restrict__ x_expanded,
+     const float* __restrict__ h_pre_raw,
+     const float* __restrict__ h_post_raw,
+     int64_t B, int64_t T, int n, int64_t C)
+ {
+     int64_t btc = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;
+     if (btc >= B * T * C) return;
+
+     int64_t bt = btc / C;
+     int64_t c = btc % C;
+     int64_t bt_offset_n = bt * (int64_t)n;
+
+     float H_pre[MAX_N];
+     #pragma unroll
+     for (int i = 0; i < MAX_N; i++) {
+         if (i < n) {
+             float val_pre = h_pre_raw[bt_offset_n + (int64_t)i];
+             H_pre[i] = 1.0f / (1.0f + __expf(-val_pre));
+
+             if (c == 0) {
+                 H_pre_out[bt_offset_n + (int64_t)i] = H_pre[i];
+                 float val_post = h_post_raw[bt_offset_n + (int64_t)i];
+                 // The 2.0 * sigmoid logic is unchanged
+                 H_post_out[bt_offset_n + (int64_t)i] = 2.0f * (1.0f / (1.0f + __expf(-val_post)));
+             }
+         }
+     }
+
+     float sum_val = 0.0f;
+     int64_t bt_offset_n_c = bt * (int64_t)n * C;
+     #pragma unroll
+     for (int i = 0; i < MAX_N; i++) {
+         if (i < n) {
+             float xi = to_float(x_expanded[bt_offset_n_c + (int64_t)i * C + c]);
+             sum_val += H_pre[i] * xi;
+         }
+     }
+
+     x_layer_in[btc] = to_bf(sum_val);
+ }
+
+ /**
+  * 2. Fused Pre-Op Backward Kernel
+  * Key fixes:
+  * 1. Hardened reduction logic: a block-level reduction guarantees an accurate sum_grad_x.
+  * 2. The earlier atomicAdd was logically correct, but fails if the input tensor is not zeroed on the Python/C++ side.
+  * 3. MAX_N here bounds the number of streams that can be processed in parallel.
+  */
+ template<int BLOCK_SIZE, int MAX_N = 8>
+ __global__ void mhc_pre_op_bwd_kernel(
+     floatX* __restrict__ d_x_expanded,
+     float* __restrict__ d_h_pre_raw,
+     float* __restrict__ d_h_post_raw,
+     const floatX* __restrict__ grad_layer_in,
+     const float* __restrict__ grad_H_post,
+     const floatX* __restrict__ x_expanded,
+     const float* __restrict__ H_pre,
+     const float* __restrict__ H_post,
+     int64_t B, int64_t T, int n, int64_t C)
+ {
+     // Shared memory for the block-level reduction (size BLOCK_SIZE * n)
+     // Assumes MAX_N is small (e.g. 8): 256 * 8 * 4 bytes = 8 KB, well under the hardware limit
+     __shared__ float s_reduce[BLOCK_SIZE][MAX_N];
+
+     int64_t bt = blockIdx.x;
+     if (bt >= B * T) return;
+
+     int tid = threadIdx.x;
+     int64_t bt_offset_n = bt * (int64_t)n;
+     int64_t bt_offset_c = bt * C;
+     int64_t bt_offset_n_c = bt * (int64_t)n * C;
+
+     // Initialize per-thread accumulators
+     float thread_dh_pre_sum[MAX_N];
+     #pragma unroll
+     for(int i=0; i<MAX_N; ++i) thread_dh_pre_sum[i] = 0.0f;
+
+     // 1. Compute dx and gather per-thread partial sums
+     for (int64_t c = (int64_t)tid; c < C; c += (int64_t)BLOCK_SIZE) {
+         float g_in = to_float(grad_layer_in[bt_offset_c + c]);
+
+         #pragma unroll
+         for (int i = 0; i < MAX_N; i++) {
+             if (i < n) {
+                 float h_pre_i = H_pre[bt_offset_n + (int64_t)i];
+                 d_x_expanded[bt_offset_n_c + (int64_t)i * C + c] = to_bf(g_in * h_pre_i);
+
+                 float xi = to_float(x_expanded[bt_offset_n_c + (int64_t)i * C + c]);
+                 thread_dh_pre_sum[i] += g_in * xi;
+             }
+         }
+     }
+
+     // 2. Stage the partial sums in shared memory for the reduction
+     #pragma unroll
+     for (int i = 0; i < MAX_N; i++) {
+         s_reduce[tid][i] = thread_dh_pre_sum[i];
+     }
+     __syncthreads();
+
+     // 3. Tree reduction
+     for (int stride = BLOCK_SIZE / 2; stride > 0; stride >>= 1) {
+         if (tid < stride) {
+             #pragma unroll
+             for (int i = 0; i < MAX_N; i++) {
+                 if (i < n) {
+                     s_reduce[tid][i] += s_reduce[tid + stride][i];
+                 }
+             }
+         }
+         __syncthreads();
+     }
+
+     // 4. Write back the results
+     if (tid == 0) {
+         #pragma unroll
+         for (int i = 0; i < MAX_N; i++) {
+             if (i < n) {
+                 int64_t idx = bt_offset_n + (int64_t)i;
+                 float sum_grad_x = s_reduce[0][i];
+
+                 // Gradient for d_h_pre_raw
+                 float s_pre = H_pre[idx];
+                 d_h_pre_raw[idx] = sum_grad_x * (s_pre * (1.0f - s_pre));
+
+                 // Gradient for d_h_post_raw
+                 float s_post = H_post[idx] * 0.5f;
+                 d_h_post_raw[idx] = grad_H_post[idx] * 2.0f * (s_post * (1.0f - s_post));
+             }
+         }
+     }
+ }
+
+ /* -------------------- API wrappers -------------------- */
+
+ inline void mhc_pre_op_forward(
+     floatX* x_layer_in, float* H_pre, float* H_post, float* H_res,
+     const floatX* x_expanded, const float* h_pre_raw, const float* h_post_raw, const float* h_res_raw,
+     int64_t B, int64_t T, int n, int64_t C, int sinkhorn_iters, float eps, cudaStream_t stream)
+ {
+     int64_t total_elements = B * T * C;
+     dim3 threads(256);
+     dim3 blocks((unsigned int)((total_elements + 255) / 256));
+
+     mhc_pre_op_fwd_kernel<8><<<blocks, threads, 0, stream>>>(
+         x_layer_in, H_pre, H_post, x_expanded, h_pre_raw, h_post_raw, B, T, n, C);
+
+     // Sinkhorn projection
+     for (int64_t i = 0; i < B * T; i++) {
+         sinkhorn_knopp_forward(
+             H_res + i * (int64_t)n * n,
+             h_res_raw + i * (int64_t)n * n,
+             n, n, sinkhorn_iters, eps, stream
+         );
+     }
+ }
+
+ inline void mhc_pre_op_backward(
+     floatX* d_x_expanded, float* d_h_pre_raw, float* d_h_post_raw, float* d_h_res_raw,
+     const floatX* grad_layer_in, const float* grad_H_post, const float* grad_H_res,
+     const floatX* x_expanded, const float* H_pre, const float* H_post,
+     const float* H_res_out, const float* H_res_in_raw,
+     int64_t B, int64_t T, int n, int64_t C, int sinkhorn_iters, float eps, cudaStream_t stream)
+ {
+     const int BLOCK_SIZE = 256;
+     dim3 threads(BLOCK_SIZE);
+     dim3 blocks((unsigned int)(B * T));
+
+     // Launch the backward kernel
+     mhc_pre_op_bwd_kernel<BLOCK_SIZE, 8><<<blocks, threads, 0, stream>>>(
+         d_x_expanded, d_h_pre_raw, d_h_post_raw,
+         grad_layer_in, grad_H_post, x_expanded, H_pre, H_post, B, T, n, C);
+
+     // Sinkhorn gradient
+     for (int64_t i = 0; i < B * T; i++) {
+         sinkhorn_knopp_backward(
+             d_h_res_raw + i * (int64_t)n * n,
+             grad_H_res + i * (int64_t)n * n,
+             H_res_out + i * (int64_t)n * n,
+             H_res_in_raw + i * (int64_t)n * n,
+             n, sinkhorn_iters, eps, stream
+         );
+     }
+ }
+
+ } // namespace mhc
+
+ #endif
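As with the post-op, here is a minimal NumPy sketch of what the pre-op forward kernel computes (illustrative only; shapes follow the kernel signature, x_expanded [B, T, n, C], h_pre_raw and h_post_raw [B, T, n]). The Sinkhorn projection of h_res_raw into H_res is handled separately by sinkhorn_knopp_forward and is omitted here:

    import numpy as np

    def sigmoid(x):
        return 1.0 / (1.0 + np.exp(-x))

    # x_layer_in[b,t,c] = sum_i sigmoid(h_pre_raw[b,t,i]) * x_expanded[b,t,i,c]
    def mhc_pre_op_fwd_reference(x_expanded, h_pre_raw, h_post_raw):
        H_pre = sigmoid(h_pre_raw)                                  # [B, T, n]
        H_post = 2.0 * sigmoid(h_post_raw)                          # [B, T, n]
        x_layer_in = np.einsum("btn,btnc->btc", H_pre, x_expanded)  # weighted stream aggregate
        return x_layer_in, H_pre, H_post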