rwkv-ops 0.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. rwkv_ops/__init__.py +45 -0
  2. rwkv_ops/mhc_kernel/__init__.py +50 -0
  3. rwkv_ops/mhc_kernel/common_kernel/include/mhc_types.h +66 -0
  4. rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_post_op.cuh +197 -0
  5. rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_pre_op.cuh +212 -0
  6. rwkv_ops/mhc_kernel/common_kernel/kernels/rmsnorm.cuh +152 -0
  7. rwkv_ops/mhc_kernel/common_kernel/kernels/sinkhorn_knopp.cuh +158 -0
  8. rwkv_ops/mhc_kernel/common_kernel/kernels/stream_aggregate.cuh +141 -0
  9. rwkv_ops/mhc_kernel/common_kernel/kernels/stream_distribute.cuh +111 -0
  10. rwkv_ops/mhc_kernel/common_kernel/kernels/stream_mix.cuh +164 -0
  11. rwkv_ops/mhc_kernel/common_kernel/kernels/type_conversions.cuh +52 -0
  12. rwkv_ops/mhc_kernel/jax_kernel/CMakeLists.txt +47 -0
  13. rwkv_ops/mhc_kernel/jax_kernel/mhu_ffi.cu +652 -0
  14. rwkv_ops/mhc_kernel/jax_kernel/mhu_jax.py +939 -0
  15. rwkv_ops/mhc_kernel/native_keras_op.py +193 -0
  16. rwkv_ops/mhc_kernel/torch_kernel/mhc_cuda.cu +207 -0
  17. rwkv_ops/mhc_kernel/torch_kernel/mhc_op.cpp +296 -0
  18. rwkv_ops/mhc_kernel/torch_kernel/mhc_torch.py +306 -0
  19. rwkv_ops/rwkv6_kernel/__init__.py +120 -0
  20. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/gpu_ops.cpp +44 -0
  21. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernel_helpers.h +64 -0
  22. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernels.h +56 -0
  23. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/pybind11_kernel_helpers.h +41 -0
  24. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/rwkv_kernels.cu +512 -0
  25. rwkv_ops/rwkv6_kernel/jax_kernel_hip/gpu_ops.cpp +44 -0
  26. rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernel_helpers.h +64 -0
  27. rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernels.h +56 -0
  28. rwkv_ops/rwkv6_kernel/jax_kernel_hip/pybind11_kernel_helpers.h +41 -0
  29. rwkv_ops/rwkv6_kernel/jax_kernel_hip/rwkv_kernels.hip +514 -0
  30. rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py +722 -0
  31. rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py +90 -0
  32. rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_cuda.cu +397 -0
  33. rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_op.cpp +93 -0
  34. rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py +305 -0
  35. rwkv_ops/rwkv7_kernel/__init__.py +113 -0
  36. rwkv_ops/rwkv7_kernel/get_jax_devices_info.py +220 -0
  37. rwkv_ops/rwkv7_kernel/get_torch_devices_info.py +250 -0
  38. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/CMakeLists.txt +42 -0
  39. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_ffi.cu +399 -0
  40. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_jax.py +311 -0
  41. rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/CMakeLists.txt +42 -0
  42. rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_ffi.cu +172 -0
  43. rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_jax.py +190 -0
  44. rwkv_ops/rwkv7_kernel/jax_kernel/__init__.py +9 -0
  45. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_bwd.py +95 -0
  46. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_fwd.py +60 -0
  47. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_bwd.py +78 -0
  48. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_fwd.py +80 -0
  49. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_bwd.py +150 -0
  50. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_fwd.py +45 -0
  51. rwkv_ops/rwkv7_kernel/jax_kernel/cumsum.py +34 -0
  52. rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_bwd.py +61 -0
  53. rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_fwd.py +86 -0
  54. rwkv_ops/rwkv7_kernel/jax_op.py +382 -0
  55. rwkv_ops/rwkv7_kernel/mlx_op.py +118 -0
  56. rwkv_ops/rwkv7_kernel/native_keras_op.py +108 -0
  57. rwkv_ops/rwkv7_kernel/tf_eager_kernel.py +155 -0
  58. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_cuda.cu +235 -0
  59. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_op.cpp +63 -0
  60. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_torch.py +233 -0
  61. rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_cuda.cu +101 -0
  62. rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_op.cpp +56 -0
  63. rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_torch.py +112 -0
  64. rwkv_ops/rwkv7_kernel/torch_kernel/__init__.py +13 -0
  65. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_bwd.py +96 -0
  66. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_fwd.py +64 -0
  67. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_bwd.py +74 -0
  68. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_fwd.py +75 -0
  69. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_bwd.py +148 -0
  70. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_fwd.py +44 -0
  71. rwkv_ops/rwkv7_kernel/torch_kernel/cumsum.py +31 -0
  72. rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_bwd.py +63 -0
  73. rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_fwd.py +79 -0
  74. rwkv_ops/rwkv7_kernel/torch_op.py +504 -0
  75. rwkv_ops/rwkv7_kernel/triton_kernel/__init__.py +34 -0
  76. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_bwd.py +328 -0
  77. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_fwd.py +186 -0
  78. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_bwd.py +157 -0
  79. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_fwd.py +160 -0
  80. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_bwd.py +382 -0
  81. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_fwd.py +137 -0
  82. rwkv_ops/rwkv7_kernel/triton_kernel/cumsum.py +86 -0
  83. rwkv_ops/rwkv7_kernel/triton_kernel/utils.py +20 -0
  84. rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_bwd.py +193 -0
  85. rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_fwd.py +326 -0
  86. rwkv_ops-0.6.1.dist-info/METADATA +495 -0
  87. rwkv_ops-0.6.1.dist-info/RECORD +89 -0
  88. rwkv_ops-0.6.1.dist-info/WHEEL +4 -0
  89. rwkv_ops-0.6.1.dist-info/licenses/LICENSE.txt +201 -0
rwkv_ops/mhc_kernel/common_kernel/kernels/stream_mix.cuh
@@ -0,0 +1,164 @@
+ #ifndef MHC_STREAM_MIX_CUH
+ #define MHC_STREAM_MIX_CUH
+
+ #include <cuda_runtime.h>
+ #include <cuda_bf16.h>
+ #include <cooperative_groups.h>
+ #include <cooperative_groups/reduce.h>
+ #include "../include/mhc_types.h"
+
+ namespace cg = cooperative_groups;
+
+ namespace mhc {
+
+ /**
+  * 1. Forward pass: Out = M @ Inp
+  * Shapes: M [B, T, n, n] (FP32), Inp [B, T, n, C] (BF16) -> Out [B, T, n, C] (BF16)
+  * Formula: out[b, t, i, c] = \sum_{j=0}^{n-1} M[b, t, i, j] * inp[b, t, j, c]
+  */
+ __global__ void stream_mix_fwd_kernel(
+     floatX* __restrict__ out,
+     const floatX* __restrict__ inp,
+     const float* __restrict__ M,
+     int64_t B, int64_t T, int n, int64_t C) {
+
+     int64_t btc = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;
+     int i = blockIdx.y; // target stream index (row of M)
+
+     if (btc < B * T * C && i < n) {
+         int64_t b_t = btc / C;
+         int64_t c = btc % C;
+
+         float sum = 0.0f;
+         #pragma unroll
+         for (int j = 0; j < 8; j++) { // assumes n is at most 8; adjust the bound or switch to a plain loop if needed
+             if (j < n) {
+                 float m_val = M[b_t * n * n + (int64_t)i * n + j];
+                 float x_val = to_float(inp[b_t * n * C + (int64_t)j * C + c]);
+                 sum += m_val * x_val;
+             }
+         }
+         out[b_t * n * C + (int64_t)i * C + c] = to_bf(sum);
+     }
+ }
+
+ /**
+  * 2. Backward pass for dx: dx = M^T @ grad
+  * Formula: dx[b, t, j, c] = \sum_{i=0}^{n-1} grad[b, t, i, c] * M[b, t, i, j]
+  */
+ __global__ void stream_mix_bwd_dx_kernel(
+     floatX* __restrict__ dx,
+     const float* __restrict__ grad, // FP32 gradients to preserve precision
+     const float* __restrict__ M,
+     int64_t B, int64_t T, int n, int64_t C) {
+
+     int64_t btc = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;
+     int j = blockIdx.y; // input stream index (column of M)
+
+     if (btc < B * T * C && j < n) {
+         int64_t b_t = btc / C;
+         int64_t c = btc % C;
+
+         float sum = 0.0f;
+         #pragma unroll
+         for (int i = 0; i < 8; i++) {
+             if (i < n) {
+                 float m_val = M[b_t * n * n + (int64_t)i * n + j]; // note: M is indexed as [i, j] here
+                 float g_val = grad[b_t * n * C + (int64_t)i * C + c];
+                 sum += m_val * g_val;
+             }
+         }
+         dx[b_t * n * C + (int64_t)j * C + c] = to_bf(sum);
+     }
+ }
+
+ /**
+  * 3. Backward pass for dM (optimized): dM = grad @ Inp^T
+  * Formula: dM[b, t, i, j] = \sum_{c=0}^{C-1} grad[b, t, i, c] * inp[b, t, j, c]
+  * Each block computes one element of dM, using shared memory for a parallel reduction.
+  */
+ template<int BLOCK_SIZE>
+ __global__ void stream_mix_bwd_dm_optimized_kernel(
+     float* __restrict__ dm,
+     const float* __restrict__ grad,
+     const floatX* __restrict__ inp,
+     int64_t B, int64_t T, int n, int64_t C) {
+
+     cg::thread_block block = cg::this_thread_block();
+     cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
+
+     // blockIdx.x maps to the sequence dimension (B*T)
+     // blockIdx.y maps to row i of M
+     // blockIdx.z maps to column j of M
+     int64_t bt = blockIdx.x;
+     int i = blockIdx.y;
+     int j = blockIdx.z;
+
+     if (bt >= B * T || i >= n || j >= n) return;
+
+     extern __shared__ float s_reduce[];
+
+     float thread_sum = 0.0f;
+     int64_t grad_offset = bt * n * C + (int64_t)i * C;
+     int64_t inp_offset = bt * n * C + (int64_t)j * C;
+
+     // 1. Per-thread partial sum
+     for (int64_t c = threadIdx.x; c < C; c += BLOCK_SIZE) {
+         float g_val = grad[grad_offset + c];
+         float x_val = to_float(inp[inp_offset + c]);
+         thread_sum += g_val * x_val;
+     }
+
+     // 2. Warp-level reduction
+     float warp_sum = cg::reduce(warp, thread_sum, cg::plus<float>());
+
+     int warp_id = threadIdx.x / 32;
+     int lane_id = threadIdx.x % 32;
+
+     if (lane_id == 0) {
+         s_reduce[warp_id] = warp_sum;
+     }
+     block.sync();
+
+     // 3. Block-level reduction (performed by the first warp)
+     if (warp_id == 0) {
+         float val = (lane_id < (BLOCK_SIZE / 32)) ? s_reduce[lane_id] : 0.0f;
+         float block_sum = cg::reduce(warp, val, cg::plus<float>());
+         if (lane_id == 0) {
+             dm[bt * n * n + (int64_t)i * n + j] = block_sum;
+         }
+     }
+ }
+
+ /* -------------------- API wrapper functions -------------------- */
+
+ inline void stream_mix_forward(
+     floatX* out, const floatX* inp, const float* M,
+     int64_t B, int64_t T, int n, int64_t C, cudaStream_t stream) {
+
+     dim3 threads(256);
+     // x covers all B*T*C elements, y indexes the matrix row
+     dim3 blocks((B * T * C + 255) / 256, n);
+     stream_mix_fwd_kernel<<<blocks, threads, 0, stream>>>(out, inp, M, B, T, n, C);
+ }
+
+ inline void stream_mix_backward(
+     floatX* dx, float* dm, const float* grad, const floatX* inp, const float* M,
+     int64_t B, int64_t T, int n, int64_t C, cudaStream_t stream) {
+
+     // 1. Compute dx
+     dim3 threads_dx(256);
+     dim3 blocks_dx((B * T * C + 255) / 256, n);
+     stream_mix_bwd_dx_kernel<<<blocks_dx, threads_dx, 0, stream>>>(dx, grad, M, B, T, n, C);
+
+     // 2. Compute dm (one block per element to parallelize the reduction over the C axis)
+     constexpr int DM_BLOCK_SIZE = 256;
+     dim3 grid_dm(B * T, n, n);
+     size_t smem_size = (DM_BLOCK_SIZE / 32) * sizeof(float);
+     stream_mix_bwd_dm_optimized_kernel<DM_BLOCK_SIZE>
+         <<<grid_dm, DM_BLOCK_SIZE, smem_size, stream>>>(dm, grad, inp, B, T, n, C);
+ }
+
+ } // namespace mhc
+
+ #endif
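
The forward kernel above is a batched n x n mixing matrix applied along the stream axis, so it can be sanity-checked against a plain host-side loop implementing the same contraction out[b, t, i, c] = sum_j M[b, t, i, j] * inp[b, t, j, c]. The sketch below is not part of the package: it is a minimal CPU reference with a hypothetical helper name, kept entirely in FP32 for simplicity (the kernel reads and writes BF16), useful for comparing against stream_mix_forward within BF16 tolerance. The dx and dM references follow the same pattern using the formulas given in the comments.

    #include <cstdint>
    #include <vector>

    // Hypothetical CPU reference for stream_mix forward: out = M @ inp per (b, t).
    // M: [B, T, n, n], inp/out: [B, T, n, C], all row-major and flattened.
    void stream_mix_forward_reference(std::vector<float>& out,
                                      const std::vector<float>& inp,
                                      const std::vector<float>& M,
                                      int64_t B, int64_t T, int n, int64_t C) {
        for (int64_t bt = 0; bt < B * T; ++bt) {
            for (int i = 0; i < n; ++i) {
                for (int64_t c = 0; c < C; ++c) {
                    float sum = 0.0f;
                    for (int j = 0; j < n; ++j) {
                        sum += M[bt * n * n + (int64_t)i * n + j]
                             * inp[bt * n * C + (int64_t)j * C + c];
                    }
                    out[bt * n * C + (int64_t)i * C + c] = sum;
                }
            }
        }
    }
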
rwkv_ops/mhc_kernel/common_kernel/kernels/type_conversions.cuh
@@ -0,0 +1,52 @@
+ #pragma once
+
+ #include <cuda_runtime.h>
+ #include <cuda_bf16.h>
+ #include "../include/mhc_types.h"
+
+ namespace mhc {
+
+ template<int BLOCK_SIZE>
+ // [Change]: size parameter widened to int64_t
+ __global__ void float_to_bf16_kernel(floatX* __restrict__ out, const float* __restrict__ inp, int64_t size) {
+     // [Change]: idx is int64_t, and blockIdx.x is cast to avoid 32-bit multiplication overflow
+     int64_t idx = (int64_t)blockIdx.x * BLOCK_SIZE + threadIdx.x;
+     if (idx < size) {
+         out[idx] = to_bf(inp[idx]); // use the newly defined conversion helper
+     }
+ }
+
+ template<int BLOCK_SIZE>
+ // [Change]: size parameter widened to int64_t
+ __global__ void bf16_to_float_kernel(float* __restrict__ out, const floatX* __restrict__ inp, int64_t size) {
+     // [Change]: idx is int64_t, and blockIdx.x is cast
+     int64_t idx = (int64_t)blockIdx.x * BLOCK_SIZE + threadIdx.x;
+     if (idx < size) {
+         out[idx] = to_float(inp[idx]); // use the newly defined conversion helper
+     }
+ }
+
+ // [Change]: size parameter widened to int64_t
+ inline void float_to_bf16(floatX* out, const float* inp, int64_t size, cudaStream_t stream = nullptr) {
+     constexpr int BLOCK_SIZE = 256;
+     // num_blocks itself rarely overflows int (only if size exceeds ~500 billion), but the arithmetic must be 64-bit
+     int num_blocks = (size + BLOCK_SIZE - 1) / BLOCK_SIZE;
+     float_to_bf16_kernel<BLOCK_SIZE><<<num_blocks, BLOCK_SIZE, 0, stream>>>(out, inp, size);
+ }
+
+ // [Change]: size parameter widened to int64_t
+ inline void bf16_to_float(float* out, const floatX* inp, int64_t size, cudaStream_t stream = nullptr) {
+     constexpr int BLOCK_SIZE = 256;
+     int num_blocks = (size + BLOCK_SIZE - 1) / BLOCK_SIZE;
+     bf16_to_float_kernel<BLOCK_SIZE><<<num_blocks, BLOCK_SIZE, 0, stream>>>(out, inp, size);
+ }
+
+ __device__ __forceinline__ float fast_exp(float x) {
+     return __expf(x);
+ }
+
+ __device__ __forceinline__ float fast_sigmoid(float x) {
+     return __frcp_rn(1.0f + __expf(-x));
+ }
+
+ } // namespace mhc
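
The two wrappers above launch one thread per element on the given stream (or the default stream when none is passed). A hedged host-side usage sketch, assuming the header is on the include path and that floatX from mhc_types.h is a BF16 storage type such as __nv_bfloat16 (not confirmed here), might look like the following; it is an illustration only, not code shipped in the wheel:

    // Round-trip an FP32 buffer through the BF16 conversion helpers.
    #include <cstdint>
    #include <cstdio>
    #include <vector>
    #include <cuda_runtime.h>
    #include "type_conversions.cuh"  // assumed include path; also pulls in mhc_types.h

    using namespace mhc;             // resolves floatX and the wrappers wherever they are declared

    int main() {
        const int64_t size = 1 << 20;
        std::vector<float> h_in(size, 1.5f), h_out(size, 0.0f);

        float*  d_in = nullptr;
        float*  d_out = nullptr;
        floatX* d_bf16 = nullptr;
        cudaMalloc(&d_in, size * sizeof(float));
        cudaMalloc(&d_out, size * sizeof(float));
        cudaMalloc(&d_bf16, size * sizeof(floatX));

        cudaMemcpy(d_in, h_in.data(), size * sizeof(float), cudaMemcpyHostToDevice);
        float_to_bf16(d_bf16, d_in, size);   // FP32 -> BF16 on the default stream
        bf16_to_float(d_out, d_bf16, size);  // BF16 -> FP32
        cudaMemcpy(h_out.data(), d_out, size * sizeof(float), cudaMemcpyDeviceToHost);

        printf("round-trip of 1.5f -> %f\n", h_out[0]);  // 1.5 is exactly representable in BF16

        cudaFree(d_in); cudaFree(d_out); cudaFree(d_bf16);
        return 0;
    }
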
rwkv_ops/mhc_kernel/jax_kernel/CMakeLists.txt
@@ -0,0 +1,47 @@
+ cmake_minimum_required(VERSION 3.18)
+ project(mhu_jax LANGUAGES CXX CUDA)
+
+ find_package(CUDAToolkit REQUIRED)
+ find_package(Python3 REQUIRED COMPONENTS Interpreter)
+
+ # Query the XLA header include path from JAX
+ execute_process(
+     COMMAND "${Python3_EXECUTABLE}" -c "from jax import ffi; print(ffi.include_dir())"
+     OUTPUT_VARIABLE XLA_INCLUDE_DIR
+     OUTPUT_STRIP_TRAILING_WHITESPACE
+ )
+ if(NOT XLA_INCLUDE_DIR)
+     message(FATAL_ERROR "Could not get the XLA include path from jax.ffi; make sure JAX >= 0.4.31 is installed")
+ endif()
+ message(STATUS "XLA include directory: ${XLA_INCLUDE_DIR}")
+
+ # Location of the shared common kernel headers
+ set(COMMON_KERNEL_DIR "${CMAKE_SOURCE_DIR}/../common_kernel")
+
+ # Build the shared library (target named mhu_ffi to avoid name clashes)
+ add_library(mhu_ffi SHARED mhu_ffi.cu)
+
+ # Include paths
+ target_include_directories(mhu_ffi PRIVATE
+     ${XLA_INCLUDE_DIR}
+     ${COMMON_KERNEL_DIR}/include
+     ${COMMON_KERNEL_DIR}/kernels
+ )
+
+ # Link the CUDA runtime
+ target_link_libraries(mhu_ffi PRIVATE CUDA::cudart)
+
+ # Compile options
+ target_compile_features(mhu_ffi PUBLIC cxx_std_17)
+ set_target_properties(mhu_ffi PROPERTIES
+     CUDA_STANDARD 17
+     CUDA_SEPARABLE_COMPILATION ON
+     POSITION_INDEPENDENT_CODE ON
+     PREFIX ""          # drop the "lib" prefix
+     OUTPUT_NAME "mhu"  # output file is still mhu.so
+ )
+
+ # Install next to the sources (important: matches the path the Python loader searches)
+ install(TARGETS mhu_ffi
+     LIBRARY DESTINATION "${CMAKE_SOURCE_DIR}"
+     RUNTIME DESTINATION "${CMAKE_SOURCE_DIR}")