rwkv_ops-0.6.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. rwkv_ops/__init__.py +45 -0
  2. rwkv_ops/mhc_kernel/__init__.py +50 -0
  3. rwkv_ops/mhc_kernel/common_kernel/include/mhc_types.h +66 -0
  4. rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_post_op.cuh +197 -0
  5. rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_pre_op.cuh +212 -0
  6. rwkv_ops/mhc_kernel/common_kernel/kernels/rmsnorm.cuh +152 -0
  7. rwkv_ops/mhc_kernel/common_kernel/kernels/sinkhorn_knopp.cuh +158 -0
  8. rwkv_ops/mhc_kernel/common_kernel/kernels/stream_aggregate.cuh +141 -0
  9. rwkv_ops/mhc_kernel/common_kernel/kernels/stream_distribute.cuh +111 -0
  10. rwkv_ops/mhc_kernel/common_kernel/kernels/stream_mix.cuh +164 -0
  11. rwkv_ops/mhc_kernel/common_kernel/kernels/type_conversions.cuh +52 -0
  12. rwkv_ops/mhc_kernel/jax_kernel/CMakeLists.txt +47 -0
  13. rwkv_ops/mhc_kernel/jax_kernel/mhu_ffi.cu +652 -0
  14. rwkv_ops/mhc_kernel/jax_kernel/mhu_jax.py +939 -0
  15. rwkv_ops/mhc_kernel/native_keras_op.py +193 -0
  16. rwkv_ops/mhc_kernel/torch_kernel/mhc_cuda.cu +207 -0
  17. rwkv_ops/mhc_kernel/torch_kernel/mhc_op.cpp +296 -0
  18. rwkv_ops/mhc_kernel/torch_kernel/mhc_torch.py +306 -0
  19. rwkv_ops/rwkv6_kernel/__init__.py +120 -0
  20. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/gpu_ops.cpp +44 -0
  21. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernel_helpers.h +64 -0
  22. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernels.h +56 -0
  23. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/pybind11_kernel_helpers.h +41 -0
  24. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/rwkv_kernels.cu +512 -0
  25. rwkv_ops/rwkv6_kernel/jax_kernel_hip/gpu_ops.cpp +44 -0
  26. rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernel_helpers.h +64 -0
  27. rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernels.h +56 -0
  28. rwkv_ops/rwkv6_kernel/jax_kernel_hip/pybind11_kernel_helpers.h +41 -0
  29. rwkv_ops/rwkv6_kernel/jax_kernel_hip/rwkv_kernels.hip +514 -0
  30. rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py +722 -0
  31. rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py +90 -0
  32. rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_cuda.cu +397 -0
  33. rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_op.cpp +93 -0
  34. rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py +305 -0
  35. rwkv_ops/rwkv7_kernel/__init__.py +113 -0
  36. rwkv_ops/rwkv7_kernel/get_jax_devices_info.py +220 -0
  37. rwkv_ops/rwkv7_kernel/get_torch_devices_info.py +250 -0
  38. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/CMakeLists.txt +42 -0
  39. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_ffi.cu +399 -0
  40. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_jax.py +311 -0
  41. rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/CMakeLists.txt +42 -0
  42. rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_ffi.cu +172 -0
  43. rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_jax.py +190 -0
  44. rwkv_ops/rwkv7_kernel/jax_kernel/__init__.py +9 -0
  45. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_bwd.py +95 -0
  46. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_fwd.py +60 -0
  47. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_bwd.py +78 -0
  48. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_fwd.py +80 -0
  49. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_bwd.py +150 -0
  50. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_fwd.py +45 -0
  51. rwkv_ops/rwkv7_kernel/jax_kernel/cumsum.py +34 -0
  52. rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_bwd.py +61 -0
  53. rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_fwd.py +86 -0
  54. rwkv_ops/rwkv7_kernel/jax_op.py +382 -0
  55. rwkv_ops/rwkv7_kernel/mlx_op.py +118 -0
  56. rwkv_ops/rwkv7_kernel/native_keras_op.py +108 -0
  57. rwkv_ops/rwkv7_kernel/tf_eager_kernel.py +155 -0
  58. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_cuda.cu +235 -0
  59. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_op.cpp +63 -0
  60. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_torch.py +233 -0
  61. rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_cuda.cu +101 -0
  62. rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_op.cpp +56 -0
  63. rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_torch.py +112 -0
  64. rwkv_ops/rwkv7_kernel/torch_kernel/__init__.py +13 -0
  65. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_bwd.py +96 -0
  66. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_fwd.py +64 -0
  67. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_bwd.py +74 -0
  68. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_fwd.py +75 -0
  69. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_bwd.py +148 -0
  70. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_fwd.py +44 -0
  71. rwkv_ops/rwkv7_kernel/torch_kernel/cumsum.py +31 -0
  72. rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_bwd.py +63 -0
  73. rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_fwd.py +79 -0
  74. rwkv_ops/rwkv7_kernel/torch_op.py +504 -0
  75. rwkv_ops/rwkv7_kernel/triton_kernel/__init__.py +34 -0
  76. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_bwd.py +328 -0
  77. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_fwd.py +186 -0
  78. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_bwd.py +157 -0
  79. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_fwd.py +160 -0
  80. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_bwd.py +382 -0
  81. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_fwd.py +137 -0
  82. rwkv_ops/rwkv7_kernel/triton_kernel/cumsum.py +86 -0
  83. rwkv_ops/rwkv7_kernel/triton_kernel/utils.py +20 -0
  84. rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_bwd.py +193 -0
  85. rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_fwd.py +326 -0
  86. rwkv_ops-0.6.1.dist-info/METADATA +495 -0
  87. rwkv_ops-0.6.1.dist-info/RECORD +89 -0
  88. rwkv_ops-0.6.1.dist-info/WHEEL +4 -0
  89. rwkv_ops-0.6.1.dist-info/licenses/LICENSE.txt +201 -0
rwkv_ops/rwkv7_kernel/native_keras_op.py
@@ -0,0 +1,108 @@
+ import keras
+ from keras import ops
+
+
+ def transpose_head(x, head_first):
+     """
+     Transpose the input tensor.
+
+     Args:
+         x: input tensor.
+         head_first: boolean deciding whether to transpose.
+
+     Returns:
+         The transposed tensor if head_first is True, otherwise the original tensor.
+     """
+     x = ops.cast(x, "float32")
+     if head_first:
+         return ops.transpose(x, (0, 2, 1, 3))
+     else:
+         return x
+
+
+ def generalized_delta_rule(
+     r,
+     w,
+     k,
+     v,
+     a,
+     b,
+     initial_state=None,
+     output_final_state: bool = True,
+     head_first: bool = False,
+ ):
+     """
+     Implements the generalized delta rule.
+
+     Args:
+         r: input tensor.
+         w: weight tensor.
+         k, v, a, b: other input tensors.
+         initial_state: initial state tensor.
+         output_final_state: whether to also return the final state.
+         head_first: whether the head dimension comes before the time dimension.
+
+     Returns:
+         The output, plus the final state if output_final_state is True.
+     """
+     DTYPE = r.dtype
+     r = transpose_head(r, head_first)
+     k = transpose_head(k, head_first)
+     v = transpose_head(v, head_first)
+     a = transpose_head(a, head_first)
+     b = transpose_head(b, head_first)
+     w = transpose_head(w, head_first)
+     # shape is read after transpose_head, so (B, T, H, N) is correct for head_first=True as well
+     B, T, H, N = ops.shape(r)
+     w = ops.exp(-ops.exp(w))
+
+     if initial_state is not None:
+         state = initial_state
+         if ops.shape(state)[0] == 1:
+             state = ops.broadcast_to(state, (B, H, N, N))
+     else:
+         state = ops.zeros((B, H, N, N))
+     state = ops.cast(state, "float32")
+
+     keras_backend = keras.config.backend()
+
+     def step(t, inputs):
+         """
+         Run a single time step.
+
+         Args:
+             t: current time step.
+             inputs: list containing the current state and the output buffer.
+
+         Returns:
+             The updated state and output buffer.
+         """
+         state, out = inputs
+         kk = ops.reshape(k[:, t, :], (B, H, 1, N))
+         rr = ops.reshape(r[:, t, :], (B, H, N, 1))
+         vv = ops.reshape(v[:, t, :], (B, H, N, 1))
+         aa = ops.reshape(a[:, t, :], (B, H, N, 1))
+         bb = ops.reshape(b[:, t, :], (B, H, 1, N))
+         state = state * w[:, t, :, None, :] + state @ aa @ bb + vv @ kk
+         o = ops.cast((state @ rr), out.dtype)
+         if keras_backend == "tensorflow":
+             out = out.write(t, ops.reshape(o, (B, H, N)))
+         elif keras_backend == "torch":
+             out[:, t : t + 1] = ops.reshape(o, (B, 1, H, N))
+         else:
+             out = ops.slice_update(out, [0, t, 0, 0], ops.reshape(o, (B, 1, H, N)))
+         return [state, out]
+
+     if keras_backend == "tensorflow":
+         import tensorflow as tf
+
+         out = tf.TensorArray(DTYPE, size=T)
+     else:
+         out = ops.zeros((B, T, H, N), DTYPE)
+     state, out = ops.fori_loop(0, T, step, [state, out])
+     if keras_backend == "tensorflow":
+         out = ops.transpose(out.stack(), [1, 0, 2, 3])
+     if output_final_state:
+         return ops.cast(out, DTYPE), state
+     return ops.cast(out, DTYPE)
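
For reference, a minimal usage sketch of the Keras-native path defined above. The sizes and random inputs are illustrative only; `generalized_delta_rule` refers to the function in this file.

    import numpy as np
    from keras import ops

    # illustrative sizes: batch, sequence length, heads, head size
    B, T, H, N = 1, 8, 2, 64
    r, w, k, v, a, b = (np.random.randn(B, T, H, N).astype("float32") for _ in range(6))

    out, last_state = generalized_delta_rule(r, w, k, v, a, b, output_final_state=True)
    print(ops.shape(out), ops.shape(last_state))  # (1, 8, 2, 64) and (1, 2, 64, 64)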
rwkv_ops/rwkv7_kernel/tf_eager_kernel.py
@@ -0,0 +1,155 @@
+ """
+ TensorFlow version of generalized_delta_rule.
+ The forward pass calls the JAX CUDA kernel through tf.py_function; the backward pass also goes through JAX.
+ Can be compiled with @tf.function and trained with tf.GradientTape.
+ """
+
+ import tensorflow as tf
+ from typing import Optional, Tuple
+ import jax.numpy as jnp
+ from .jax_cuda_kernel.wkv7_jax import get_jax_generalized_delta_rule
+ from .jax_cuda_kernel_single.wkv7_single_step_jax import (
+     get_jax_generalized_delta_rule_single_step,
+ )
+
+
+ def transpose_head(x, head_first: bool):
+     """(B, T, H, K) <-> (B, H, T, K)"""
+     x = tf.cast(x, dtype=tf.float32)
+     if head_first:
+         return tf.transpose(x, (0, 2, 1, 3))
+     return x
+
+
+ def get_tf_generalized_delta_rule(HEAD_SIZE=64):
+     generalized_delta_rule_inference = get_jax_generalized_delta_rule(HEAD_SIZE)[1]
+
+     # ---------- low-level kernel wrapper ----------
+     @tf.py_function(Tout=[tf.bfloat16, tf.float32])
+     def _tf_wkv7_fwd(w, q, k, v, a, b, h0):
+         """tf.py_function wrapper around the JAX forward pass"""
+         y, s = generalized_delta_rule_inference(
+             w=jnp.asarray(w, jnp.bfloat16),
+             r=jnp.asarray(q, jnp.bfloat16),
+             k=jnp.asarray(k, jnp.bfloat16),
+             v=jnp.asarray(v, jnp.bfloat16),
+             a=jnp.asarray(a, jnp.bfloat16),
+             b=jnp.asarray(b, jnp.bfloat16),
+             initial_state=jnp.asarray(h0, jnp.float32),
+         )
+         return (
+             tf.convert_to_tensor(y, tf.bfloat16),
+             tf.convert_to_tensor(s, tf.float32),
+         )
+
+     # ---------- user-facing interface ----------
+     def generalized_delta_rule(
+         r: tf.Tensor,  # (B, T, H, K) or (B, H, T, K)
+         w: tf.Tensor,
+         k: tf.Tensor,
+         v: tf.Tensor,
+         a: tf.Tensor,
+         b: tf.Tensor,
+         initial_state: Optional[tf.Tensor] = None,
+         output_final_state: bool = True,
+         head_first: bool = False,
+         chunk_len: int = 16,
+     ) -> Tuple[tf.Tensor, tf.Tensor]:
+         """
+         Mirrors the JAX interface 1:1 and returns (out, last_state).
+         Can be compiled with @tf.function and trained with tf.GradientTape.
+         """
+         dtype = r.dtype
+
+         r = transpose_head(r, head_first)
+         w = transpose_head(w, head_first)
+         k = transpose_head(k, head_first)
+         v = transpose_head(v, head_first)
+         a = transpose_head(a, head_first)
+         b = transpose_head(b, head_first)
+
+         B, T, H, K = tf.unstack(tf.shape(r), num=4)
+         if T % chunk_len != 0:
+             raise ValueError(f"T={T} must be divisible by chunk_len={chunk_len}")
+
+         if initial_state is None:
+             h0 = tf.zeros([B, H, K, K], dtype=tf.float32)
+         else:
+             h0 = tf.cast(initial_state, tf.float32)
+
+         # forward pass (with gradients)
+         out, last_state = _tf_wkv7_fwd(w, r, k, v, a, b, h0)
+
+         # cast back to the dtype the caller expects
+         out = tf.cast(out, dtype)
+
+         return (out, last_state) if output_final_state else out
+
+     return generalized_delta_rule
+
+
+ def get_tf_generalized_delta_rule_single_step(HEAD_SIZE=64):
+     # fetch the single-step generalized delta rule from the JAX implementation
+     _wkv7_single_step_kernel = get_jax_generalized_delta_rule_single_step(HEAD_SIZE)
+
+     # ---------- low-level kernel wrapper ----------
+     @tf.py_function(Tout=[tf.bfloat16, tf.float32])
+     def _tf_wkv7_single_step_fwd(w, r, k, v, a, b, h0):
+         """tf.py_function wrapper around the JAX single-step forward pass"""
+         y, s = _wkv7_single_step_kernel(
+             w=jnp.asarray(w, jnp.bfloat16),
+             r=jnp.asarray(r, jnp.bfloat16),
+             k=jnp.asarray(k, jnp.bfloat16),
+             v=jnp.asarray(v, jnp.bfloat16),
+             a=jnp.asarray(a, jnp.bfloat16),
+             b=jnp.asarray(b, jnp.bfloat16),
+             initial_state=jnp.asarray(h0, jnp.float32),
+         )
+         return (
+             tf.convert_to_tensor(y, tf.bfloat16),
+             tf.convert_to_tensor(s, tf.float32),
+         )
+
+     # ---------- user-facing interface ----------
+     def generalized_delta_rule_single_step(
+         r: tf.Tensor,  # (B, 1, H, K) or (B, H, 1, K)
+         w: tf.Tensor,
+         k: tf.Tensor,
+         v: tf.Tensor,
+         a: tf.Tensor,
+         b: tf.Tensor,
+         initial_state: Optional[tf.Tensor] = None,
+         output_final_state: bool = True,
+         head_first: bool = False,
+     ) -> Tuple[tf.Tensor, tf.Tensor]:
+         """
+         Single-step generalized delta rule.
+         Mirrors the JAX single-step interface and returns (out, last_state).
+         """
+         dtype = r.dtype
+
+         r = transpose_head(r, head_first)
+         w = transpose_head(w, head_first)
+         k = transpose_head(k, head_first)
+         v = transpose_head(v, head_first)
+         a = transpose_head(a, head_first)
+         b = transpose_head(b, head_first)
+
+         B, T, H, K = tf.unstack(tf.shape(r), num=4)
+         if T != 1:
+             raise ValueError(f"Single-step kernel requires T=1, but got T={T}")
+
+         if initial_state is None:
+             h0 = tf.zeros([B, H, K, K], dtype=tf.float32)
+         else:
+             h0 = tf.cast(initial_state, tf.float32)
+
+         # forward pass
+         y, s = _tf_wkv7_single_step_fwd(w, r, k, v, a, b, h0)
+
+         # cast back to the dtype the caller expects
+         out = tf.cast(y, dtype)
+
+         return (out, s) if output_final_state else out
+
+     return generalized_delta_rule_single_step
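
A sketch of how the factory above could be called. It assumes a GPU-backed JAX install is available, the shapes are illustrative, and T must be divisible by chunk_len.

    import tensorflow as tf

    wkv7 = get_tf_generalized_delta_rule(HEAD_SIZE=64)

    B, T, H, K = 1, 32, 2, 64  # illustrative; K must equal HEAD_SIZE
    r, w, k, v, a, b = (tf.random.normal((B, T, H, K)) for _ in range(6))

    out, last_state = wkv7(r, w, k, v, a, b, chunk_len=16)
    # expected: out has shape (B, T, H, K) cast back to r.dtype, last_state is (B, H, K, K) float32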
rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_cuda.cu
@@ -0,0 +1,235 @@
+ #include <cuda_bf16.h>
+ #include <assert.h>
+ #include <cstdint>
+ // ref link: https://github.com/BlinkDL/RWKV-CUDA/tree/main/rwkv7_fast_fused
+ using bf = __nv_bfloat16;
+
+ __device__ inline float to_float(const bf & u) {
+     return __bfloat162float(u);
+ }
+
+ __device__ inline bf to_bf(const float & u) {
+     return __float2bfloat16_rn(u);
+ }
+ typedef bf * __restrict__ F_;
+
+ /* -------------------- Forward kernel -------------------- */
+ template<int C> __launch_bounds__(C, 2) // [Optimization 1] explicit launch bounds to improve occupancy
+ __global__ void forward_kernel(int T, int H,
+         F_ w_, F_ q_, F_ k_, F_ v_, F_ a_, F_ b_,
+         bf* y_, float* s_, float* sa_, float* h0_) {
+     int bb = blockIdx.y, hh = blockIdx.x, i = threadIdx.x;
+     float state[C] = {0};
+     __shared__ float q[C], k[C], w[C], a[C], b[C];
+
+     int64_t h0_base = ((int64_t)bb*H + hh)*C*C + i*C;
+     #pragma unroll
+     for (int j = 0; j < C; j++) {
+         state[j] = h0_[h0_base + j];
+     }
+
+     for (int t = 0; t < T; t++) {
+         int64_t ind = (int64_t)bb*T*H*C + (int64_t)t*H*C + hh * C + i;
+
+         __syncthreads();
+         q[i] = to_float(q_[ind]);
+         w[i] = __expf(-__expf(to_float(w_[ind])));
+         k[i] = to_float(k_[ind]);
+         a[i] = to_float(a_[ind]);
+         b[i] = to_float(b_[ind]);
+         __syncthreads();
+
+         float sa = 0;
+         #pragma unroll
+         for (int j = 0; j < C; j++) {
+             sa += a[j] * state[j];
+         }
+         sa_[ind] = sa;
+
+         float v_val = to_float(v_[ind]);
+         float y = 0;
+         #pragma unroll
+         for (int j = 0; j < C; j++) {
+             float &s = state[j];
+             s = s * w[j] + sa * b[j] + k[j] * v_val;
+             y += s * q[j];
+         }
+         y_[ind] = to_bf(y);
+
+         if ((t+1)%_CHUNK_LEN_ == 0) {
+             int64_t base = ((int64_t)bb*H+hh)*(T/_CHUNK_LEN_)*C*C + ((int64_t)t/_CHUNK_LEN_)*C*C + i;
+             #pragma unroll
+             for (int j = 0; j < C; j++) {
+                 s_[base + j*C] = state[j];
+             }
+         }
+     }
+ }
+
+ /* -------------------- Backward kernel -------------------- */
+ template<int C> __launch_bounds__(C, 2) // [Optimization 1] explicit launch bounds
+ __global__ void backward_kernel(int T, int H,
+         F_ w_, F_ q_, F_ k_, F_ v_, F_ a_, F_ b_, F_ dy_,
+         float * __restrict__ s_, float * __restrict__ sa_,
+         float * __restrict__ dht_, float * __restrict__ dh0_,
+         bf* dw_, bf* dq_, bf* dk_, bf* dv_, bf* da_, bf* db_) {
+
+     int bb = blockIdx.y, hh = blockIdx.x, i = threadIdx.x;
+     float stateT[C] = {0}, dstate[C] = {0}, dstateT[C] = {0};
+
+     int64_t dht_base = ((int64_t)bb*H + hh)*C*C + i*C;
+     #pragma unroll
+     for (int j = 0; j < C; j++) {
+         dstate[j] = dht_[dht_base + j];
+         dstateT[j] = dht_[dht_base + j];
+     }
+
+     __shared__ float w[C], q[C], k[C], v[C], a[C], b[C], dy[C], sa[C], dSb_shared[C];
+     float qi, wi, ki, ai, bi, dyi;
+
+     for (int t = T-1; t >= 0; t--) {
+         int64_t ind = (int64_t)bb*T*H*C + (int64_t)t*H*C + hh * C + i;
+
+         __syncthreads();
+         q[i] = qi = to_float(q_[ind]);
+         float wi_fac = -__expf(to_float(w_[ind]));
+         w[i] = wi = __expf(wi_fac);
+         k[i] = ki = to_float(k_[ind]);
+         v[i] = to_float(v_[ind]);
+         a[i] = ai = to_float(a_[ind]);
+         b[i] = bi = to_float(b_[ind]);
+         dy[i] = dyi = to_float(dy_[ind]);
+         sa[i] = sa_[ind];
+         __syncthreads();
+
+         if ((t+1)%_CHUNK_LEN_ == 0) {
+             int64_t base = ((int64_t)bb*H+hh)*(T/_CHUNK_LEN_)*C*C + ((int64_t)t/_CHUNK_LEN_)*C*C + i*C;
+
+             // [Optimization 2] float4 vectorized loads: 4 floats per memory transaction
+             const float4* s4 = (const float4*)(s_ + base);
+             #pragma unroll
+             for (int j4 = 0; j4 < C/4; j4++) {
+                 float4 q_vec = s4[j4];
+                 const int j = j4 * 4;
+                 stateT[j+0] = q_vec.x;
+                 stateT[j+1] = q_vec.y;
+                 stateT[j+2] = q_vec.z;
+                 stateT[j+3] = q_vec.w;
+             }
+         }
+
+         float dq_val = 0;
+         #pragma unroll
+         for (int j = 0; j < C; j++) {
+             dq_val += stateT[j] * dy[j];
+         }
+         dq_[ind] = to_bf(dq_val);
+
+         float iwi = 1.0f/(wi + 0.000001f);
+         #pragma unroll
+         for (int j = 0; j < C; j++) {
+             stateT[j] = (stateT[j] - ki*v[j] - bi*sa[j]) * iwi;
+             dstate[j] += dyi * q[j];
+             dstateT[j] += qi * dy[j];
+         }
+
+         float dw = 0, dk = 0, dv = 0, db = 0, dSb = 0;
+         #pragma unroll
+         for (int j = 0; j < C; j++) {
+             dw += dstateT[j] * stateT[j];
+             dk += dstateT[j] * v[j];
+             dv += dstate[j] * k[j];
+             dSb += dstate[j] * b[j];
+             db += dstateT[j] * sa[j];
+         }
+         dw_[ind] = to_bf(dw * wi * wi_fac);
+         dk_[ind] = to_bf(dk);
+         dv_[ind] = to_bf(dv);
+         db_[ind] = to_bf(db);
+
+         __syncthreads();
+         dSb_shared[i] = dSb;
+         __syncthreads();
+
+         float da = 0;
+         #pragma unroll
+         for (int j = 0; j < C; j++) {
+             da += stateT[j] * dSb_shared[j];
+         }
+         da_[ind] = to_bf(da);
+
+         #pragma unroll
+         for (int j = 0; j < C; j++) {
+             dstate[j] = dstate[j] * w[j] + dSb * a[j];
+             dstateT[j] = dstateT[j] * wi + ai * dSb_shared[j];
+             if (t == 0) {
+                 dh0_[dht_base + j] = dstate[j];
+             }
+         }
+     }
+ }
+
+ /* -------------------- Inference-only kernel -------------------- */
+ template<int C> __launch_bounds__(C, 2) // [Optimization 1] same launch-bounds optimization for the inference kernel
+ __global__ void forward_inference_kernel(int T, int H,
+         F_ w_, F_ q_, F_ k_, F_ v_, F_ a_, F_ b_,
+         bf *y_, float *s_, float *h0_) {
+     int bb = blockIdx.y, hh = blockIdx.x, i = threadIdx.x;
+     float state[C] = {0};
+     __shared__ float q[C], k[C], w[C], a[C], b[C];
+
+     int64_t h0_base = ((int64_t)bb * H + hh) * C * C + i * C;
+     #pragma unroll
+     for (int j = 0; j < C; ++j) state[j] = h0_[h0_base + j];
+
+     for (int t = 0; t < T; ++t) {
+         int64_t ind = (int64_t)bb * T * H * C + (int64_t)t * H * C + hh * C + i;
+
+         __syncthreads();
+         q[i] = to_float(q_[ind]);
+         w[i] = __expf(-__expf(to_float(w_[ind])));
+         k[i] = to_float(k_[ind]);
+         a[i] = to_float(a_[ind]);
+         b[i] = to_float(b_[ind]);
+         __syncthreads();
+
+         float sa = 0.f;
+         #pragma unroll
+         for (int j = 0; j < C; ++j) sa += a[j] * state[j];
+
+         float v_val = to_float(v_[ind]);
+         float y = 0.f;
+         #pragma unroll
+         for (int j = 0; j < C; ++j) {
+             float &s = state[j];
+             s = s * w[j] + sa * b[j] + k[j] * v_val;
+             y += s * q[j];
+         }
+         y_[ind] = to_bf(y);
+     }
+
+     // write only the final state
+     int64_t base = ((int64_t)bb * H + hh) * C * C + i * C;
+     #pragma unroll
+     for (int j = 0; j < C; ++j) s_[base + j] = state[j];
+ }
+
+ /* -------------------- C interface functions -------------------- */
+ void cuda_forward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*y, float*s, float*sa, float* h0) {
+     forward_kernel<_C_><<<dim3(H,B), dim3(_C_)>>>(T,H,w,q,k,v,z,a,y,s,sa,h0);
+ }
+
+ void cuda_backward(int B, int T, int H,
+         bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*dy,
+         float*s, float*sa, float*dht, float*dh0,
+         bf*dw, bf*dq, bf*dk, bf*dv, bf*dz, bf*da
+         ) {
+     assert(T%_CHUNK_LEN_ == 0);
+     backward_kernel<_C_><<<dim3(H,B), dim3(_C_)>>>(T,H,w,q,k,v,z,a,dy,s,sa,dht,dh0,dw,dq,dk,dv,dz,da);
+ }
+
+ void cuda_forward_inference(int B, int T, int H,
+         bf* w, bf* q, bf* k, bf* v, bf* a, bf* b,
+         bf* y, float* s, float* h0) {
+     forward_inference_kernel<_C_><<<dim3(H, B), dim3(_C_)>>>(T, H, w, q, k, v, a, b, y, s, h0);
+ }
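
Restated compactly, for each (batch, head) pair the forward and inference kernels above advance a C×C state S_t per token, with thread i owning row i of S; this is the same recurrence the Keras reference above implements:

    S_t = S_{t-1}\,\operatorname{diag}(w_t) + (S_{t-1} a_t)\, b_t^{\top} + v_t k_t^{\top},
    \qquad y_t = S_t\, r_t,
    \qquad w_t = \exp(-\exp(\hat{w}_t)).

The sa_ buffer caches S_{t-1} a_t for reuse in the backward pass, and s_ stores a checkpoint of S every _CHUNK_LEN_ steps; backward_kernel reloads the checkpoint at each chunk boundary and steps the state backwards inside the chunk by inverting the update.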
rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_op.cpp
@@ -0,0 +1,63 @@
+ #include <torch/extension.h>
+ #include <cuda_bf16.h>
+
+ using bf = __nv_bfloat16;
+
+ // ---------- existing function declarations ----------
+ void cuda_forward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*y, float*s, float*sa, float* h0);
+ void cuda_backward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*dy, float*s, float*sa, float*dht, float*dh0, bf*dw, bf*dq, bf*dk, bf*dv, bf*dz, bf*da);
+
+ // ---------- new inference function declaration (required!) ----------
+ void cuda_forward_inference(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*a, bf*b, bf*y, float*s, float* h0);
+
+ // ---------- existing forward function ----------
+ void forward(torch::Tensor &w, torch::Tensor &q, torch::Tensor &k, torch::Tensor &v,
+              torch::Tensor &z, torch::Tensor &a,
+              torch::Tensor &y,
+              torch::Tensor &s, torch::Tensor &sa, torch::Tensor &h0) {
+     int B = w.sizes()[0], T = w.sizes()[1], H = w.sizes()[2];
+     cuda_forward(B, T, H,
+                  (bf*)w.data_ptr(), (bf*)q.data_ptr(), (bf*)k.data_ptr(), (bf*)v.data_ptr(), (bf*)z.data_ptr(), (bf*)a.data_ptr(), (bf*)y.data_ptr(),
+                  (float*)s.data_ptr(), (float*)sa.data_ptr(), (float*)h0.data_ptr());
+ }
+
+ // ---------- existing backward function ----------
+ void backward(torch::Tensor &w, torch::Tensor &q, torch::Tensor &k, torch::Tensor &v, torch::Tensor &z, torch::Tensor &a, torch::Tensor &dy,
+               torch::Tensor &s, torch::Tensor &sa, torch::Tensor &dht, torch::Tensor &dh0,
+               torch::Tensor &dw, torch::Tensor &dq, torch::Tensor &dk, torch::Tensor &dv, torch::Tensor &dz, torch::Tensor &da) {
+     int B = w.sizes()[0], T = w.sizes()[1], H = w.sizes()[2];
+     cuda_backward(B, T, H, (bf*)w.data_ptr(), (bf*)q.data_ptr(), (bf*)k.data_ptr(), (bf*)v.data_ptr(), (bf*)z.data_ptr(), (bf*)a.data_ptr(),
+                   (bf*)dy.data_ptr(),
+                   (float*)s.data_ptr(), (float*)sa.data_ptr(), (float*)dht.data_ptr(), (float*)dh0.data_ptr(),
+                   (bf*)dw.data_ptr(), (bf*)dq.data_ptr(), (bf*)dk.data_ptr(), (bf*)dv.data_ptr(), (bf*)dz.data_ptr(), (bf*)da.data_ptr());
+ }
+
+ // ---------- new inference forward function ----------
+ void forward_inference(torch::Tensor &w, torch::Tensor &q, torch::Tensor &k, torch::Tensor &v,
+                        torch::Tensor &a, torch::Tensor &b,
+                        torch::Tensor &y,
+                        torch::Tensor &s, torch::Tensor &h0) {
+     int B = w.sizes()[0], T = w.sizes()[1], H = w.sizes()[2];
+     cuda_forward_inference(B, T, H,
+                            (bf*)w.data_ptr(), (bf*)q.data_ptr(), (bf*)k.data_ptr(), (bf*)v.data_ptr(),
+                            (bf*)a.data_ptr(), (bf*)b.data_ptr(),
+                            (bf*)y.data_ptr(),
+                            (float*)s.data_ptr(), (float*)h0.data_ptr());
+ }
+
+ // ---------- combined operator registration (do not split into separate blocks!) ----------
+ TORCH_LIBRARY(wind_backstepping, m) {
+     // training ops
+     m.def("forward(Tensor w, Tensor q, Tensor k, Tensor v, Tensor z, Tensor a, Tensor(a!) y, Tensor(b!) s, Tensor(c!) sa, Tensor(f!) h0) -> ()");
+     m.def("backward(Tensor w, Tensor q, Tensor k, Tensor v, Tensor z, Tensor a, Tensor dy, Tensor s, Tensor sa, Tensor dht, Tensor(a!) dh0, Tensor(b!) dw, Tensor(c!) dq, Tensor(d!) dk, Tensor(e!) dv, Tensor(f!) dz, Tensor(g!) da) -> ()");
+
+     // inference op (appended to the same block)
+     m.def("forward_inference(Tensor w, Tensor q, Tensor k, Tensor v, Tensor a, Tensor b, Tensor(a!) y, Tensor(b!) s, Tensor(c!) h0) -> ()");
+ }
+
+ // ---------- combined implementation registration ----------
+ TORCH_LIBRARY_IMPL(wind_backstepping, CUDA, m) {
+     m.impl("forward", &forward);
+     m.impl("backward", &backward);
+     m.impl("forward_inference", &forward_inference);
+ }
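
The package's wkv7_torch.py (listed above but not shown in this diff) builds and wraps these ops. Purely as an illustration of the registered schema, a direct call to the inference op could look like the sketch below; the load() flags, the -D_C_ / -D_CHUNK_LEN_ defines, and the tensor shapes are assumptions inferred from wkv7_cuda.cu, not the package's actual build code.

    import torch
    from torch.utils.cpp_extension import load

    HEAD_SIZE, CHUNK_LEN = 64, 16  # assumed values; the real ones come from the Python wrapper
    load(
        name="wind_backstepping",
        sources=["wkv7_op.cpp", "wkv7_cuda.cu"],
        extra_cuda_cflags=[f"-D_C_={HEAD_SIZE}", f"-D_CHUNK_LEN_={CHUNK_LEN}"],
        is_python_module=False,  # ops are exposed via TORCH_LIBRARY, not pybind11
    )

    B, T, H, C = 1, 32, 2, HEAD_SIZE
    dev = "cuda"
    w, q, k, v, a, b = (torch.randn(B, T, H, C, dtype=torch.bfloat16, device=dev) for _ in range(6))
    y = torch.empty_like(q)                                        # output, written in place
    s = torch.empty(B, H, C, C, dtype=torch.float32, device=dev)   # final state
    h0 = torch.zeros(B, H, C, C, dtype=torch.float32, device=dev)  # initial state
    torch.ops.wind_backstepping.forward_inference(w, q, k, v, a, b, y, s, h0)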