rwkv_ops-0.6.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. rwkv_ops/__init__.py +45 -0
  2. rwkv_ops/mhc_kernel/__init__.py +50 -0
  3. rwkv_ops/mhc_kernel/common_kernel/include/mhc_types.h +66 -0
  4. rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_post_op.cuh +197 -0
  5. rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_pre_op.cuh +212 -0
  6. rwkv_ops/mhc_kernel/common_kernel/kernels/rmsnorm.cuh +152 -0
  7. rwkv_ops/mhc_kernel/common_kernel/kernels/sinkhorn_knopp.cuh +158 -0
  8. rwkv_ops/mhc_kernel/common_kernel/kernels/stream_aggregate.cuh +141 -0
  9. rwkv_ops/mhc_kernel/common_kernel/kernels/stream_distribute.cuh +111 -0
  10. rwkv_ops/mhc_kernel/common_kernel/kernels/stream_mix.cuh +164 -0
  11. rwkv_ops/mhc_kernel/common_kernel/kernels/type_conversions.cuh +52 -0
  12. rwkv_ops/mhc_kernel/jax_kernel/CMakeLists.txt +47 -0
  13. rwkv_ops/mhc_kernel/jax_kernel/mhu_ffi.cu +652 -0
  14. rwkv_ops/mhc_kernel/jax_kernel/mhu_jax.py +939 -0
  15. rwkv_ops/mhc_kernel/native_keras_op.py +193 -0
  16. rwkv_ops/mhc_kernel/torch_kernel/mhc_cuda.cu +207 -0
  17. rwkv_ops/mhc_kernel/torch_kernel/mhc_op.cpp +296 -0
  18. rwkv_ops/mhc_kernel/torch_kernel/mhc_torch.py +306 -0
  19. rwkv_ops/rwkv6_kernel/__init__.py +120 -0
  20. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/gpu_ops.cpp +44 -0
  21. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernel_helpers.h +64 -0
  22. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernels.h +56 -0
  23. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/pybind11_kernel_helpers.h +41 -0
  24. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/rwkv_kernels.cu +512 -0
  25. rwkv_ops/rwkv6_kernel/jax_kernel_hip/gpu_ops.cpp +44 -0
  26. rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernel_helpers.h +64 -0
  27. rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernels.h +56 -0
  28. rwkv_ops/rwkv6_kernel/jax_kernel_hip/pybind11_kernel_helpers.h +41 -0
  29. rwkv_ops/rwkv6_kernel/jax_kernel_hip/rwkv_kernels.hip +514 -0
  30. rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py +722 -0
  31. rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py +90 -0
  32. rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_cuda.cu +397 -0
  33. rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_op.cpp +93 -0
  34. rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py +305 -0
  35. rwkv_ops/rwkv7_kernel/__init__.py +113 -0
  36. rwkv_ops/rwkv7_kernel/get_jax_devices_info.py +220 -0
  37. rwkv_ops/rwkv7_kernel/get_torch_devices_info.py +250 -0
  38. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/CMakeLists.txt +42 -0
  39. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_ffi.cu +399 -0
  40. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_jax.py +311 -0
  41. rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/CMakeLists.txt +42 -0
  42. rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_ffi.cu +172 -0
  43. rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_jax.py +190 -0
  44. rwkv_ops/rwkv7_kernel/jax_kernel/__init__.py +9 -0
  45. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_bwd.py +95 -0
  46. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_fwd.py +60 -0
  47. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_bwd.py +78 -0
  48. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_fwd.py +80 -0
  49. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_bwd.py +150 -0
  50. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_fwd.py +45 -0
  51. rwkv_ops/rwkv7_kernel/jax_kernel/cumsum.py +34 -0
  52. rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_bwd.py +61 -0
  53. rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_fwd.py +86 -0
  54. rwkv_ops/rwkv7_kernel/jax_op.py +382 -0
  55. rwkv_ops/rwkv7_kernel/mlx_op.py +118 -0
  56. rwkv_ops/rwkv7_kernel/native_keras_op.py +108 -0
  57. rwkv_ops/rwkv7_kernel/tf_eager_kernel.py +155 -0
  58. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_cuda.cu +235 -0
  59. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_op.cpp +63 -0
  60. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_torch.py +233 -0
  61. rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_cuda.cu +101 -0
  62. rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_op.cpp +56 -0
  63. rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_torch.py +112 -0
  64. rwkv_ops/rwkv7_kernel/torch_kernel/__init__.py +13 -0
  65. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_bwd.py +96 -0
  66. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_fwd.py +64 -0
  67. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_bwd.py +74 -0
  68. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_fwd.py +75 -0
  69. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_bwd.py +148 -0
  70. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_fwd.py +44 -0
  71. rwkv_ops/rwkv7_kernel/torch_kernel/cumsum.py +31 -0
  72. rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_bwd.py +63 -0
  73. rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_fwd.py +79 -0
  74. rwkv_ops/rwkv7_kernel/torch_op.py +504 -0
  75. rwkv_ops/rwkv7_kernel/triton_kernel/__init__.py +34 -0
  76. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_bwd.py +328 -0
  77. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_fwd.py +186 -0
  78. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_bwd.py +157 -0
  79. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_fwd.py +160 -0
  80. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_bwd.py +382 -0
  81. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_fwd.py +137 -0
  82. rwkv_ops/rwkv7_kernel/triton_kernel/cumsum.py +86 -0
  83. rwkv_ops/rwkv7_kernel/triton_kernel/utils.py +20 -0
  84. rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_bwd.py +193 -0
  85. rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_fwd.py +326 -0
  86. rwkv_ops-0.6.1.dist-info/METADATA +495 -0
  87. rwkv_ops-0.6.1.dist-info/RECORD +89 -0
  88. rwkv_ops-0.6.1.dist-info/WHEEL +4 -0
  89. rwkv_ops-0.6.1.dist-info/licenses/LICENSE.txt +201 -0
rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py
@@ -0,0 +1,90 @@
+ from keras import ops
+
+
+ class RWKVKernelOperator:
+     def __init__(self, head_size, max_sequence_length):
+         self.head_size = head_size
+         self.max_sequence_length = max_sequence_length
+
+     def __call__(
+         self, r, k, v, w, u, with_state=False, init_state=None, state_map=None
+     ):
+         B, T, C = ops.shape(r)
+         assert C % self.head_size == 0
+         H = C // self.head_size
+         w = ops.reshape(w, [B, T, H, self.head_size, 1])
+         k = ops.reshape(k, [B, T, H, self.head_size, 1])
+
+         v = ops.reshape(v, [B, T, H, 1, self.head_size])
+         r = ops.reshape(r, [B, T, H, 1, self.head_size])
+         u = ops.reshape(u, [1, H, self.head_size, 1])
+
+         if init_state is not None:
+             assert len(init_state.shape) in [
+                 3,
+                 4,
+             ], "init_state must have shape (state_kinds, num_heads, head_size, head_size)"
+             if len(init_state.shape) == 3:
+                 assert init_state.shape == (
+                     H,
+                     self.head_size,
+                     self.head_size,
+                 ), "a 3-D init_state must have shape (num_heads, head_size, head_size)"
+                 init_state = init_state[None, :]
+             else:
+                 assert init_state.shape[1:] == (
+                     H,
+                     self.head_size,
+                     self.head_size,
+                 ), "a 4-D init_state must have shape (state_kinds, num_heads, head_size, head_size)"
+             state_kinds = init_state.shape[0]
+             if state_map is None:
+                 state_kinds = init_state.shape[0]
+                 if state_kinds == 1:
+                     state_map = ops.zeros(shape=(B,), dtype="int32")
+                 elif state_kinds == B:
+                     state_map = ops.convert_to_tensor(
+                         [i for i in range(B)], dtype="int32"
+                     )
+                 else:
+                     raise ValueError(
+                         "state_map cannot be inferred; please pass state_map explicitly"
+                     )
+
+             else:
+                 if isinstance(state_map, list):
+                     state_map = ops.convert_to_tensor(state_map, dtype="int32")
+                 state_map = ops.cast(state_map, "int32")
+                 assert (state_map >= 0).all() and (state_map < state_kinds).all(), (
+                     f"state_map values must lie in [0, {state_kinds})"
+                 )
+             s = ops.take(init_state, state_map, axis=0)
+
+         else:
+             assert state_map is None
+             s = ops.zeros((B, H, self.head_size, self.head_size), dtype=u.dtype)
+
+         w = ops.exp(-ops.exp(w))
+
+         def cond(i, k, v, w, r, s, y):
+             return i < T
+
+         def body(i, k, v, w, r, s, y):
+             k_t = ops.take(k, i, 1)
+             v_t = ops.take(v, i, 1)
+             kv_t = k_t @ v_t
+             w_t = ops.take(w, i, 1)
+
+             r_t = ops.take(r, i, 1)
+             y_t = r_t @ (u * kv_t + s)
+             y_t = ops.reshape(y_t, (B, 1, C))
+             s = kv_t + w_t * s
+
+             y = ops.slice_update(y, [0, i, 0], y_t)
+             return i + 1, k, v, w, r, s, y
+
+         y = ops.zeros([B, T, C], r.dtype)
+         i, k, v, w, r, s, y = ops.while_loop(cond, body, (0, k, v, w, r, s, y), T)
+         if with_state:
+             return y, s
+         return y, None
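
For orientation, here is a minimal sketch of driving the pure-Keras fallback defined above. The argument shapes follow the reshapes inside __call__ (r, k, v, w of shape (B, T, C) and u of shape (C,)); the batch size, head size, and sequence length are made-up example values, not defaults taken from the package.

import numpy as np
from keras import ops

# hypothetical example dimensions
batch, seq_len, head_size, n_heads = 2, 16, 64, 4
channels = head_size * n_heads

op = RWKVKernelOperator(head_size=head_size, max_sequence_length=seq_len)

# random inputs just to exercise the operator; w is the log-decay,
# which __call__ maps through exp(-exp(w)) before running the recurrence
r, k, v, w = (
    ops.convert_to_tensor(np.random.randn(batch, seq_len, channels).astype("float32"))
    for _ in range(4)
)
u = ops.convert_to_tensor(np.random.randn(channels).astype("float32"))

# with_state=True also returns the final state of shape (batch, n_heads, head_size, head_size)
y, final_state = op(r, k, v, w, u, with_state=True)
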
rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_cuda.cu
@@ -0,0 +1,397 @@
+ #include <stdio.h>
+ #include <assert.h>
+ #include "ATen/ATen.h"
+ typedef at::BFloat16 bf16;
+ typedef float fp32;
+ typedef at::Half fp16;
+
+
+
+
+
+ template <typename F_in, typename F_out>
+ __device__ void kernel_forward_core(const int B, const int T, const int C, const int H, const int b, const int h, const int i, float* state,
+                                     const F_in *__restrict__ const _r, const F_in *__restrict__ const _k, const F_in *__restrict__ const _v, const F_in *__restrict__ _w, const F_in *__restrict__ _u,
+                                     F_out *__restrict__ const _y)
+ {
+     _u += h*_N_;
+
+     __shared__ float r[_N_], k[_N_], u[_N_], w[_N_];
+     //float state[_N_] = {0};
+
+     __syncthreads();
+     u[i] = float(_u[i]);
+     __syncthreads();
+
+     for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C)
+     {
+         __syncthreads();
+         w[i] = __expf(-__expf(float(_w[t])));
+         r[i] = float(_r[t]);
+         k[i] = float(_k[t]);
+         __syncthreads();
+
+         const float v = float(_v[t]);
+         float y = 0;
+
+         #pragma unroll
+         for (int j = 0; j < _N_; j += 4)
+         {
+             const float4& r_ = (float4&)(r[j]);
+             const float4& k_ = (float4&)(k[j]);
+             const float4& w_ = (float4&)(w[j]);
+             const float4& u_ = (float4&)(u[j]);
+             float4& s = (float4&)(state[j]);
+             float4 x;
+
+             x.x = k_.x * v;
+             x.y = k_.y * v;
+             x.z = k_.z * v;
+             x.w = k_.w * v;
+
+             y += r_.x * (u_.x * x.x + s.x);
+             y += r_.y * (u_.y * x.y + s.y);
+             y += r_.z * (u_.z * x.z + s.z);
+             y += r_.w * (u_.w * x.w + s.w);
+
+             s.x = s.x * w_.x + x.x;
+             s.y = s.y * w_.y + x.y;
+             s.z = s.z * w_.z + x.z;
+             s.w = s.w * w_.w + x.w;
+         }
+         _y[t] = F_out(y);
+     }
+ }
+
+
+
+
+ template <typename F_in, typename F_out>
+ __global__ void kernel_forward_state(const int B, const int T, const int C, const int H, const bool is_custom_state, const int64_t* state_map,
+                                      const F_in *__restrict__ const _r, const F_in *__restrict__ const _k, const F_in *__restrict__ const _v, const F_in *__restrict__ _w, const F_in *__restrict__ _u,
+                                      const F_out *__restrict__ _s, F_out *__restrict__ const _y, F_out *__restrict__ const _ys)
+ {
+     const int b = blockIdx.x / H;
+     const int h = blockIdx.x % H;
+     const int i = threadIdx.x;
+     float state[_N_] = {0};
+
+     if (is_custom_state) {
+         //printf("init\n");
+         assert(state_map[b] >= 0 && state_map[b] < B);
+
+         const int64_t input_state_offset = state_map[b] * H * _N_ * _N_ + h * _N_ * _N_ + i;
+
+         for (int j = 0; j < _N_; j++) {
+             state[j] = float(_s[j * _N_ + input_state_offset]);
+         }
+     }
+
+
+     const int64_t current_state_offset = b * H * _N_ * _N_ + h * _N_ * _N_ + i;
+
+     kernel_forward_core(B, T, C, H, b, h, i, state, _r, _k, _v, _w, _u, _y);
+     for (int j = 0; j < _N_; j++) {
+         _ys[j * _N_ + current_state_offset] = F_out(state[j]);
+     }
+ }
+
+
+ template <typename F_in, typename F_out>
+ __global__ void kernel_forward(const int B, const int T, const int C, const int H,
+                                const F_in *__restrict__ const _r, const F_in *__restrict__ const _k, const F_in *__restrict__ const _v, const F_in *__restrict__ _w, const F_in *__restrict__ _u,
+                                F_out *__restrict__ const _y)
+ {
+     const int b = blockIdx.x / H;
+     const int h = blockIdx.x % H;
+     const int i = threadIdx.x;
+     float state[_N_] = {0};
+     kernel_forward_core(B, T, C, H, b, h, i, state, _r, _k, _v, _w, _u, _y);
+ }
+
+
+
+
+ template <typename F_in, typename F_out>
+ __global__ void kernel_backward_101(const int B, const int T, const int C, const int H,
+                                     const F_in *__restrict__ const _r, const F_in *__restrict__ const _k, const F_in *__restrict__ const _v, const F_in *__restrict__ _w,
+                                     const F_in *__restrict__ _u, const F_out *__restrict__ const _gy,
+                                     F_out *__restrict__ const _gr, F_out *__restrict__ const _gu)
+ {
+     const int b = blockIdx.x / H;
+     const int h = blockIdx.x % H;
+     const int i = threadIdx.x;
+
+     __shared__ float v[_N_], gy[_N_];
+
+     const float u = float(_u[h*_N_ + i]);
+
+     float state[_N_] = {0};
+
+     const int t_0 = b*T*C + h*_N_ + i;
+     const int t_T = t_0 + T*C;
+
+     float gu = 0;
+     for (int t = t_0; t < t_T; t += C)
+     {
+         __syncthreads();
+         v[i] = float(_v[t]);
+         gy[i] = float(_gy[t]);
+         __syncthreads();
+
+         const float k = float(_k[t]);
+         const float w = __expf(-__expf(float(_w[t])));
+         float gr = 0, gu_ = 0;
+
+         #pragma unroll
+         for (int j = 0; j < _N_; j++)
+         {
+             float& s = state[j];
+             float x = k * v[j];
+
+             gr += (u * x + s) * gy[j];
+             gu_ += x * gy[j];
+             s = s * w + x;
+         }
+         _gr[t] = F_out(gr);
+         gu += float(_r[t]) * gu_;
+     }
+     _gu[b*C + h*_N_ + i] = F_out(gu);
+ }
+
+ template <typename F_in, typename F_out>
+ __global__ void kernel_backward_102(const int B, const int T, const int C, const int H,
+                                     const F_in *__restrict__ const _r, const F_in *__restrict__ const _k, const F_in *__restrict__ const _v,
+                                     const F_in *__restrict__ _w, const F_in *__restrict__ _u, const F_out *__restrict__ const _gy,
+                                     F_out *__restrict__ const _gk)
+ {
+     const int b = blockIdx.x / H;
+     const int h = blockIdx.x % H;
+     const int i = threadIdx.x;
+
+     __shared__ float v[_N_], gy[_N_];
+
+     const float u = float(_u[h*_N_ + i]);
+
+     float scccc[_N_] = {0};
+
+     const int t_0 = b*T*C + h*_N_ + i;
+     const int t_T_1 = t_0 + (T-1)*C;
+
+     for (int t = t_T_1; t >= t_0; t -= C)
+     {
+         __syncthreads();
+         v[i] = float(_v[t]);
+         gy[i] = float(_gy[t]);
+         __syncthreads();
+
+         const float rr = float(_r[t]);
+         const float w = __expf(-__expf(float(_w[t])));
+         float gk = 0;
+
+         #pragma unroll
+         for (int j = 0; j < _N_; j++)
+         {
+             float& s = scccc[j];
+             float x = rr * gy[j];
+
+             gk += (u * x + s) * v[j];
+             s = x + s * w;
+         }
+         _gk[t] = F_out(gk);
+     }
+ }
+
+ template <typename F_in, typename F_out>
+ __global__ void kernel_backward_103(const int B, const int T, const int C, const int H,
+                                     const F_in *__restrict__ const _r, const F_in *__restrict__ const _k, const F_in *__restrict__ const _v,
+                                     const F_in *__restrict__ _w, const F_in *__restrict__ _u, const F_out *__restrict__ const _gy,
+                                     F_out *__restrict__ const _gv)
+ {
+     const int b = blockIdx.x / H;
+     const int h = blockIdx.x % H;
+     const int i = threadIdx.x;
+     _u += h*_N_;
+
+     __shared__ float u_[_N_], r[_N_], k[_N_], w_[_N_];
+     __syncthreads();
+     u_[i] = float(_u[i]);
+     __syncthreads();
+
+     float sdddd[_N_] = {0};
+
+     const int t_0 = b*T*C + h*_N_ + i;
+     const int t_T_1 = t_0 + (T-1)*C;
+
+     for (int t = t_T_1; t >= t_0; t -= C)
+     {
+         __syncthreads();
+         r[i] = float(_r[t]);
+         k[i] = float(_k[t]);
+         w_[i] = __expf(-__expf(float(_w[t])));
+         __syncthreads();
+
+         const float gyy = float(_gy[t]);
+         float gv = 0;
+
+         #pragma unroll
+         for (int j = 0; j < _N_; j++)
+         {
+             float& s = sdddd[j];
+             float x = gyy * r[j];
+
+             gv += (u_[j] * x + s) * k[j];
+             s = x + s * w_[j];
+         }
+         _gv[t] = F_out(gv);
+     }
+ }
+
+ template <typename F_in, typename F_out>
+ __global__ void kernel_backward_201(const int B, const int T, const int C, const int H,
+                                     const F_in *__restrict__ const _r, const F_in *__restrict__ const _k, const F_in *__restrict__ const _v, const F_in *__restrict__ _w,
+                                     const F_in *__restrict__ _u, const F_out *__restrict__ const _gy,
+                                     F_out *__restrict__ const _gw)
+ {
+     const int b = blockIdx.x / H;
+     const int h = blockIdx.x % H;
+     const int i = threadIdx.x;
+
+     __shared__ float v[_N_], gy[_N_];
+     float saaaa[_N_] = {0}, sbbbb[_T_-2] = {0}, scccc[_N_] = {0};
+
+     const int t_0 = b*T*C + h*_N_ + i;
+     const int t_1 = t_0 + C;
+     const int t_2 = t_0 + 2*C;
+     const int t_T_1 = t_0 + (T-1)*C;
+
+     for (int t = t_T_1; t > t_1; t -= C)
+     {
+         __syncthreads();
+         gy[i] = float(_gy[t]);
+         v[i] = float(_v[t-2*C]);
+         __syncthreads();
+
+         const float r = float(_r[t]);
+         const float w = __expf(-__expf(float(_w[t-C])));
+         float sum = 0.0f;
+
+         #pragma unroll
+         for (int j = 0; j < _N_; j++)
+         {
+             float& s = saaaa[j];
+             float x = r * gy[j];
+             s = (s + x) * w;
+             sum += s * v[j];
+         }
+         sbbbb[(t-t_2)/C] = sum * float(_k[t-2*C]);
+     }
+
+     float sss = sbbbb[0];
+     _gw[t_0] = 0;
+     _gw[t_1] = F_out(sss * -__expf(float(_w[t_1])));
+
+     for (int t = t_2; t < t_T_1; t += C)
+     {
+         __syncthreads();
+         gy[i] = float(_gy[t]);
+         v[i] = float(_v[t-2*C]);
+         __syncthreads();
+
+         const float w = __expf(-__expf(float(_w[t-C])));
+         const float k = float(_k[t-2*C]);
+         float sum = 0.0f;
+
+         #pragma unroll
+         for (int j = 0; j < _N_; j++)
+         {
+             float& s = scccc[j];
+             float x = k * v[j];
+             s = (s + x) * w;
+             sum += s * gy[j];
+         }
+         sss += sbbbb[(t-t_1)/C] - (sum * float(_r[t]));
+         _gw[t] = F_out(sss * -__expf(float(_w[t])));
+     }
+     _gw[t_T_1] = 0;
+ }
+
+
+
+
+
+
+
+
+ template <typename F_in, typename F_out>
+ void cuda_forward(int B, int T, int C, int H, F_in *r, F_in *k, F_in *v, F_in *w, F_in *u, F_out *y)
+ {
+     assert(H*_N_ == C);
+     assert(_N_%4 == 0);
+     kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, u, y);
+ }
+
+ template <typename F_in, typename F_out>
+ void cuda_forward_with_state(int B, int T, int C, int H, bool S, int64_t *map, F_in *r, F_in *k, F_in *v, F_in *w, F_in *u, F_out *s, F_out *y, F_out *ys)
+ {
+     assert(H*_N_ == C);
+     assert(_N_%4 == 0);
+     if (S) {
+         kernel_forward_state<F_in,F_out><<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, S, map, r, k, v, w, u, s, y, ys);
+     } else {
+         kernel_forward_state<F_in,F_out><<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, S, nullptr, r, k, v, w, u, nullptr, y, ys);
+     }
+ }
+
+
+ template <typename F_in, typename F_out>
+ void cuda_backward(int B, int T, int C, int H, F_in *r, F_in *k, F_in *v, F_in *w, F_in *u, F_out *gy, F_out *gr, F_out *gk, F_out *gv, F_out *gw, F_out *gu)
+ {
+     assert(H*_N_ == C);
+     assert(_N_%4 == 0);
+     kernel_backward_101<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, u, gy, gr, gu);
+     kernel_backward_102<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, u, gy, gk);
+     kernel_backward_103<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, u, gy, gv);
+     kernel_backward_201<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, u, gy, gw);
+ }
+
+ void cuda_forward_bf16(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, bf16 *w, bf16 *u, bf16 *y){
+     cuda_forward<bf16,bf16>(B, T, C, H, r, k, v, w, u, y);
+ }
+ void cuda_backward_bf16(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, bf16 *w, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu){
+     cuda_backward<bf16,bf16>(B, T, C, H, r, k, v, w, u, gy, gr, gk, gv, gw, gu);
+ }
+
+ void cuda_forward_fp16(int B, int T, int C, int H, fp16 *r, fp16 *k, fp16 *v, fp16 *w, fp16 *u, fp32 *y){
+     cuda_forward<fp16,fp32>(B, T, C, H, r, k, v, w, u, y);
+ }
+
+ void cuda_backward_fp16(int B, int T, int C, int H, fp16 *r, fp16 *k, fp16 *v, fp16 *w, fp16 *u, fp32 *gy, fp32 *gr, fp32 *gk, fp32 *gv, fp32 *gw, fp32 *gu){
+     cuda_backward<fp16,fp32>(B, T, C, H, r, k, v, w, u, gy, gr, gk, gv, gw, gu);
+ }
+
+
+ void cuda_forward_fp32(int B, int T, int C, int H, fp32 *r, fp32 *k, fp32 *v, fp32 *w, fp32 *u, fp32 *y){
+     cuda_forward<fp32,fp32>(B, T, C, H, r, k, v, w, u, y);
+
+ }
+ void cuda_backward_fp32(int B, int T, int C, int H, fp32 *r, fp32 *k, fp32 *v, fp32 *w, fp32 *u, fp32 *gy, fp32 *gr, fp32 *gk, fp32 *gv, fp32 *gw, fp32 *gu){
+     cuda_backward<fp32,fp32>(B, T, C, H, r, k, v, w, u, gy, gr, gk, gv, gw, gu);
+ }
+
+
+
+
+ void cuda_forward_with_state_bf16(int B, int T, int C, int H, bool S, int64_t *map, bf16 *r, bf16 *k, bf16 *v, bf16 *w, bf16 *u, bf16 *s, bf16 *y, bf16 *ys){
+     cuda_forward_with_state<bf16,bf16>(B, T, C, H, S, map, r, k, v, w, u, s, y, ys);
+ }
+
+
+ void cuda_forward_with_state_fp16(int B, int T, int C, int H, bool S, int64_t *map, fp16 *r, fp16 *k, fp16 *v, fp16 *w, fp16 *u, fp32 *s, fp32 *y, fp32 *ys){
+     cuda_forward_with_state<fp16,fp32>(B, T, C, H, S, map, r, k, v, w, u, s, y, ys);
+ }
+
+ void cuda_forward_with_state_fp32(int B, int T, int C, int H, bool S, int64_t *map, fp32 *r, fp32 *k, fp32 *v, fp32 *w, fp32 *u, fp32 *s, fp32 *y, fp32 *ys){
+     cuda_forward_with_state<fp32,fp32>(B, T, C, H, S, map, r, k, v, w, u, s, y, ys);
+
+ }
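
For reference, the per-head recurrence implemented by kernel_forward_core above is the same one computed by the pure-Keras implementation shown earlier (ops_rwkv_kernel.py). Writing r_t, k_t, v_t, \hat{w}_t, u \in \mathbb{R}^{N} with N = _N_ = head_size, and letting S_t \in \mathbb{R}^{N \times N} be the per-head state (the state / s arrays in the two implementations):

    w_t = \exp(-\exp(\hat{w}_t))
    y_t^{\top} = r_t^{\top} \left( \operatorname{diag}(u)\, k_t v_t^{\top} + S_{t-1} \right)
    S_t = \operatorname{diag}(w_t)\, S_{t-1} + k_t v_t^{\top}

Each CUDA thread owns one output channel i, i.e. one column of S, which is why state is a length-_N_ register array indexed by the key channel j.
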
rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_op.cpp
@@ -0,0 +1,93 @@
+ #include <torch/extension.h>
+ #include "ATen/ATen.h"
+ typedef at::BFloat16 bf16;
+ typedef float fp32;
+ typedef int64_t i64;
+ typedef at::Half fp16;
+
+ void cuda_forward_bf16(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, bf16 *w, bf16 *u, bf16 *y);
+ void cuda_forward_fp16(int B, int T, int C, int H, fp16 *r, fp16 *k, fp16 *v, fp16 *w, fp16 *u, fp32 *y);
+ void cuda_forward_fp32(int B, int T, int C, int H, fp32 *r, fp32 *k, fp32 *v, fp32 *w, fp32 *u, fp32 *y);
+
+ void cuda_backward_bf16(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, bf16 *w, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu);
+ void cuda_backward_fp16(int B, int T, int C, int H, fp16 *r, fp16 *k, fp16 *v, fp16 *w, fp16 *u, fp32 *gy, fp32 *gr, fp32 *gk, fp32 *gv, fp32 *gw, fp32 *gu);
+ void cuda_backward_fp32(int B, int T, int C, int H, fp32 *r, fp32 *k, fp32 *v, fp32 *w, fp32 *u, fp32 *gy, fp32 *gr, fp32 *gk, fp32 *gv, fp32 *gw, fp32 *gu);
+
+
+ void cuda_forward_with_state_bf16(int B, int T, int C, int H, bool S, int64_t *map, bf16 *r, bf16 *k, bf16 *v, bf16 *w, bf16 *u, bf16 *s, bf16 *y, bf16 *ys);
+ void cuda_forward_with_state_fp16(int B, int T, int C, int H, bool S, int64_t *map, fp16 *r, fp16 *k, fp16 *v, fp16 *w, fp16 *u, fp32 *s, fp32 *y, fp32 *ys);
+ void cuda_forward_with_state_fp32(int B, int T, int C, int H, bool S, int64_t *map, fp32 *r, fp32 *k, fp32 *v, fp32 *w, fp32 *u, fp32 *s, fp32 *y, fp32 *ys);
+
+
+ void forward_bf16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
+     cuda_forward_bf16(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<bf16>(), u.data_ptr<bf16>(), y.data_ptr<bf16>());
+ }
+ void forward_fp16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
+     cuda_forward_fp16(B, T, C, H, r.data_ptr<fp16>(), k.data_ptr<fp16>(), v.data_ptr<fp16>(), w.data_ptr<fp16>(), u.data_ptr<fp16>(), y.data_ptr<fp32>());
+ }
+ void forward_fp32(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
+     cuda_forward_fp32(B, T, C, H, r.data_ptr<fp32>(), k.data_ptr<fp32>(), v.data_ptr<fp32>(), w.data_ptr<fp32>(), u.data_ptr<fp32>(), y.data_ptr<fp32>());
+ }
+
+
+ void backward_bf16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu) {
+     cuda_backward_bf16(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<bf16>(), u.data_ptr<bf16>(), gy.data_ptr<bf16>(), gr.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>());
+ }
+ void backward_fp16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu) {
+     cuda_backward_fp16(B, T, C, H, r.data_ptr<fp16>(), k.data_ptr<fp16>(), v.data_ptr<fp16>(), w.data_ptr<fp16>(), u.data_ptr<fp16>(), gy.data_ptr<fp32>(), gr.data_ptr<fp32>(), gk.data_ptr<fp32>(), gv.data_ptr<fp32>(), gw.data_ptr<fp32>(), gu.data_ptr<fp32>());
+ }
+ void backward_fp32(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu) {
+     cuda_backward_fp32(B, T, C, H, r.data_ptr<fp32>(), k.data_ptr<fp32>(), v.data_ptr<fp32>(), w.data_ptr<fp32>(), u.data_ptr<fp32>(), gy.data_ptr<fp32>(), gr.data_ptr<fp32>(), gk.data_ptr<fp32>(), gv.data_ptr<fp32>(), gw.data_ptr<fp32>(), gu.data_ptr<fp32>());
+ }
+
+
+ void forward_with_state_bf16(int64_t B, int64_t T, int64_t C, int64_t H, bool S, torch::Tensor &s_map, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &s, torch::Tensor &y, torch::Tensor &ys) {
+     cuda_forward_with_state_bf16(B, T, C, H, S, s_map.data_ptr<i64>(), r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<bf16>(), u.data_ptr<bf16>(), s.data_ptr<bf16>(), y.data_ptr<bf16>(), ys.data_ptr<bf16>());
+ }
+ void forward_with_state_fp16(int64_t B, int64_t T, int64_t C, int64_t H, bool S, torch::Tensor &s_map, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &s, torch::Tensor &y, torch::Tensor &ys) {
+     cuda_forward_with_state_fp16(B, T, C, H, S, s_map.data_ptr<i64>(), r.data_ptr<fp16>(), k.data_ptr<fp16>(), v.data_ptr<fp16>(), w.data_ptr<fp16>(), u.data_ptr<fp16>(), s.data_ptr<fp32>(), y.data_ptr<fp32>(), ys.data_ptr<fp32>());
+ }
+ void forward_with_state_fp32(int64_t B, int64_t T, int64_t C, int64_t H, bool S, torch::Tensor &s_map, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &s, torch::Tensor &y, torch::Tensor &ys) {
+     cuda_forward_with_state_fp32(B, T, C, H, S, s_map.data_ptr<i64>(), r.data_ptr<fp32>(), k.data_ptr<fp32>(), v.data_ptr<fp32>(), w.data_ptr<fp32>(), u.data_ptr<fp32>(), s.data_ptr<fp32>(), y.data_ptr<fp32>(), ys.data_ptr<fp32>());
+ }
+
+
+
+
+
+
+
+
+
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+     m.def("forward_bf16", &forward_bf16, "wkv6 forward bf16");
+     m.def("forward_fp16", &forward_fp16, "wkv6 forward fp16");
+     m.def("forward_fp32", &forward_fp32, "wkv6 forward fp32");
+
+
+     m.def("backward_bf16", &backward_bf16, "wkv6 backward bf16");
+     m.def("backward_fp16", &backward_fp16, "wkv6 backward fp16");
+     m.def("backward_fp32", &backward_fp32, "wkv6 backward fp32");
+
+
+     m.def("forward_with_state_bf16", &forward_with_state_bf16, "wkv6 forward with state bf16");
+     m.def("forward_with_state_fp16", &forward_with_state_fp16, "wkv6 forward with state fp16");
+     m.def("forward_with_state_fp32", &forward_with_state_fp32, "wkv6 forward with state fp32");
+ }
+
+ TORCH_LIBRARY(wkv6, m) {
+     m.def("forward_bf16", forward_bf16);
+     m.def("forward_fp16", forward_fp16);
+     m.def("forward_fp32", forward_fp32);
+
+
+     m.def("backward_bf16", backward_bf16);
+     m.def("backward_fp16", backward_fp16);
+     m.def("backward_fp32", backward_fp32);
+
+
+     m.def("forward_with_state_bf16", forward_with_state_bf16);
+     m.def("forward_with_state_fp16", forward_with_state_fp16);
+     m.def("forward_with_state_fp32", forward_with_state_fp32);
+
+ }
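
The two torch_kernel files above are built as a standard PyTorch CUDA extension; rwkv_ops ships its own loader (torch_rwkv_kernel.py in the file list), so the snippet below is only an illustrative sketch using torch.utils.cpp_extension.load. The _N_ (head size) and _T_ (maximum sequence length) macro values, the source paths, and the example tensors are assumptions for illustration, not the package's actual build configuration.

import torch
from torch.utils.cpp_extension import load

HEAD_SIZE = 64       # assumed value for the _N_ compile-time constant
MAX_SEQ_LEN = 4096   # assumed value for the _T_ compile-time constant

# compile the binding (.cpp) and kernel (.cu) shown in this diff; paths are assumed
wkv6 = load(
    name="wkv6",
    sources=[
        "rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_op.cpp",
        "rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_cuda.cu",
    ],
    extra_cuda_cflags=[f"-D_N_={HEAD_SIZE}", f"-D_T_={MAX_SEQ_LEN}"],
    verbose=True,
)

B, T, H = 1, 128, 8
C = H * HEAD_SIZE
r, k, v, w = (
    torch.randn(B, T, C, device="cuda", dtype=torch.bfloat16).contiguous()
    for _ in range(4)
)
u = torch.randn(C, device="cuda", dtype=torch.bfloat16).contiguous()
y = torch.empty(B, T, C, device="cuda", dtype=torch.bfloat16)

# forward_bf16 writes the WKV6 output into y in place
wkv6.forward_bf16(B, T, C, H, r, k, v, w, u, y)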