rwkv-ops 0.2.2__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of rwkv-ops might be problematic.
Files changed (31)
  1. rwkv_ops/__init__.py +5 -6
  2. rwkv_ops/rwkv6_kernel/__init__.py +0 -6
  3. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/gpu_ops.cpp +44 -0
  4. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernel_helpers.h +64 -0
  5. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernels.h +56 -0
  6. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/pybind11_kernel_helpers.h +41 -0
  7. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/rwkv_kernels.cu +512 -0
  8. rwkv_ops/rwkv6_kernel/jax_kernel_hip/gpu_ops.cpp +44 -0
  9. rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernel_helpers.h +64 -0
  10. rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernels.h +56 -0
  11. rwkv_ops/rwkv6_kernel/jax_kernel_hip/pybind11_kernel_helpers.h +41 -0
  12. rwkv_ops/rwkv6_kernel/jax_kernel_hip/rwkv_kernels.hip +514 -0
  13. rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py +21 -23
  14. rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py +14 -10
  15. rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_cuda.cu +397 -0
  16. rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_op.cpp +93 -0
  17. rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py +4 -4
  18. rwkv_ops/rwkv7_kernel/__init__.py +77 -29
  19. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/CMakeLists.txt +42 -0
  20. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_ffi.cu +279 -0
  21. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_jax.py +237 -0
  22. rwkv_ops/rwkv7_kernel/jax_op.py +6 -5
  23. rwkv_ops/rwkv7_kernel/native_keras_op.py +5 -6
  24. rwkv_ops/rwkv7_kernel/tf_eager_kernel.py +123 -0
  25. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_cuda.cu +165 -0
  26. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_op.cpp +35 -0
  27. {rwkv_ops-0.2.2.dist-info → rwkv_ops-0.3.0.dist-info}/METADATA +28 -27
  28. {rwkv_ops-0.2.2.dist-info → rwkv_ops-0.3.0.dist-info}/RECORD +30 -13
  29. {rwkv_ops-0.2.2.dist-info → rwkv_ops-0.3.0.dist-info}/WHEEL +1 -2
  30. rwkv_ops-0.2.2.dist-info/top_level.txt +0 -1
  31. {rwkv_ops-0.2.2.dist-info → rwkv_ops-0.3.0.dist-info/licenses}/LICENSE.txt +0 -0
@@ -0,0 +1,41 @@
+ /* Copyright 2024 The JAX Authors.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+
+ // This header extends kernel_helpers.h with the pybind11 specific interface to
+ // serializing descriptors. It also adds a pybind11 function for wrapping our
+ // custom calls in a Python capsule. This is separate from kernel_helpers so
+ // that the CUDA code itself doesn't include pybind11. I don't think that this
+ // is strictly necessary, but they do it in jaxlib, so let's do it here too.
+
+ #ifndef _GPU_OPS_PYBIND11_KERNEL_HELPERS_H_
+ #define _GPU_OPS_PYBIND11_KERNEL_HELPERS_H_
+
+ #include <pybind11/pybind11.h>
+
+ #include "kernel_helpers.h"
+
+ namespace gpu_ops {
+
+ template <typename T> pybind11::bytes PackDescriptor(const T &descriptor) {
+ return pybind11::bytes(PackDescriptorAsString(descriptor));
+ }
+
+ template <typename T> pybind11::capsule EncapsulateFunction(T *fn) {
+ return pybind11::capsule(bit_cast<void *>(fn), "xla._CUSTOM_CALL_TARGET");
+ }
+
+ } // namespace gpu_ops
+
+ #endif
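Note: the PackDescriptor/EncapsulateFunction helpers above follow the usual JAX custom-call pattern: a descriptor struct is serialized to bytes, and each kernel entry point is wrapped in an "xla._CUSTOM_CALL_TARGET" capsule. A minimal sketch of how such capsules are typically consumed on the Python side (the extension module name and registrations() helper are illustrative assumptions, not names taken from this package):

    from jax.lib import xla_client

    # Hypothetical compiled pybind11 extension built from gpu_ops.cpp.
    import gpu_ops_ext

    # Hypothetical helper assumed to return {target_name: PyCapsule}.
    for _name, _capsule in gpu_ops_ext.registrations().items():
        # Register each capsule so XLA can resolve custom_call targets on GPU.
        xla_client.register_custom_call_target(_name, _capsule, platform="gpu")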
@@ -0,0 +1,514 @@
+ /* Copyright 2024 The JAX Authors.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+
+
+ #include <hip/hip_runtime.h>
+ #include "kernel_helpers.h"
+ #include "kernels.h"
+ #include "stdio.h"
+ #include <hip/hip_bf16.h>
+ #include <hip/hip_fp16.h>
+ #include <iostream>
+ #include <assert.h>
+ namespace {
+
+ #define DISPATCH_Vector_TYPES(TYPEIN, TYPEOUT,NAME, ...) \
+ switch (TYPEIN) { \
+ case gpu_ops::ElementType::F32: { \
+ using input_type = float; \
+ switch (TYPEOUT) { \
+ case gpu_ops::ElementType::F32: { \
+ using output_type = float; \
+ __VA_ARGS__; \
+ break; \
+ } \
+ case gpu_ops::ElementType::F16: { \
+ using output_type = __half; \
+ __VA_ARGS__; \
+ break; \
+ } \
+ case gpu_ops::ElementType::BF16: { \
+ using output_type = hip_bfloat16; \
+ __VA_ARGS__; \
+ break; \
+ } \
+ default: \
+ break; \
+ } \
+ break; \
+ } \
+ case gpu_ops::ElementType::F16: { \
+ using input_type = __half; \
+ switch (TYPEOUT) { \
+ case gpu_ops::ElementType::F32: { \
+ using output_type = float; \
+ __VA_ARGS__; \
+ break; \
+ } \
+ case gpu_ops::ElementType::F16: { \
+ using output_type = __half; \
+ __VA_ARGS__; \
+ break; \
+ } \
+ case gpu_ops::ElementType::BF16: { \
+ using output_type = hip_bfloat16; \
+ __VA_ARGS__; \
+ break; \
+ } \
+ default: \
+ break; \
+ } \
+ break; \
+ } \
+ case gpu_ops::ElementType::BF16: { \
+ using input_type = hip_bfloat16; \
+ switch (TYPEOUT) { \
+ case gpu_ops::ElementType::F32: { \
+ using output_type = float; \
+ __VA_ARGS__; \
+ break; \
+ } \
+ case gpu_ops::ElementType::F16: { \
+ using output_type = __half; \
+ __VA_ARGS__; \
+ break; \
+ } \
+ case gpu_ops::ElementType::BF16: { \
+ using output_type = hip_bfloat16; \
+ __VA_ARGS__; \
+ break; \
+ } \
+ default: \
+ break; \
+ } \
+ break; \
+ } \
+ default: \
+ break; \
+ }
+
+
+ static_assert(_N_ % 4 ==0,"the size of head must be the times of 4.");
+
+
+
+
+
+ template <typename F_in,typename F_out>
+ __device__ void kernel_forward_core(const int B, const int T, const int C, const int H,const int b, const int h, const int i, const float* state,
+ const F_in *__restrict__ const _r, const F_in *__restrict__ const _k, const F_in *__restrict__ const _v, const F_in *__restrict__ _w, const F_in *__restrict__ _u,
+ F_out *__restrict__ const _y)
+ {
+
+ _u += h*_N_;
+
+ __shared__ float r[_N_], k[_N_], u[_N_], w[_N_];
+ //float state[_N_] = {0};
+
+ __syncthreads();
+ u[i] = float(_u[i]);
+ __syncthreads();
+
+ for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C)
+ {
+ __syncthreads();
+ w[i] = __expf(-__expf(float(_w[t])));
+ r[i] = float(_r[t]);
+ k[i] = float(_k[t]);
+ __syncthreads();
+
+ const float v = float(_v[t]);
+ float y = 0;
+
+ #pragma unroll
+ for (int j = 0; j < _N_; j+=4)
+ {
+ const float4& r_ = (float4&)(r[j]);
+ const float4& k_ = (float4&)(k[j]);
+ const float4& w_ = (float4&)(w[j]);
+ const float4& u_ = (float4&)(u[j]);
+ float4& s = (float4&)(state[j]);
+ float4 x;
+
+ x.x = k_.x * v;
+ x.y = k_.y * v;
+ x.z = k_.z * v;
+ x.w = k_.w * v;
+
+ y += r_.x * (u_.x * x.x + s.x);
+ y += r_.y * (u_.y * x.y + s.y);
+ y += r_.z * (u_.z * x.z + s.z);
+ y += r_.w * (u_.w * x.w + s.w);
+
+ s.x = s.x * w_.x + x.x;
+ s.y = s.y * w_.y + x.y;
+ s.z = s.z * w_.z + x.z;
+ s.w = s.w * w_.w + x.w;
+ }
+ _y[t] = F_out(y);
+ }
+ }
+
+
+
+
+ template <typename F_in,typename F_out>
+ __global__ void kernel_forward_state(const int B, const int T, const int C, const int H,const bool is_custom_state,const int32_t* map,
+ const F_in *__restrict__ const _r, const F_in *__restrict__ const _k, const F_in *__restrict__ const _v, const F_in *__restrict__ _w, const F_in *__restrict__ _u,
+ const F_out *__restrict__ _s, F_out *__restrict__ const _y, F_out *__restrict__ const _ys)
+ {
+ const int b = blockIdx.x / H;
+ const int h = blockIdx.x % H;
+ const int i = threadIdx.x;
+ float state[_N_] = {0};
+ if(is_custom_state){
+ assert(map[b] >=0 && map[b] < B);
+
+ const int64_t input_state_offset = map[b] * H * _N_ *_N_ + h * _N_ * _N_ + i;
+
+ for(int j= 0; j< _N_; j++){
+ state[j] = float(_s[j*_N_ + input_state_offset]);
+ }
+ }
+
+ const int64_t current_state_offset = b * H * _N_ *_N_ + h * _N_ * _N_ + i;
+
+ kernel_forward_core(B, T, C, H, b, h, i, state, _r, _k, _v, _w, _u, _y);
+ for(int j=0; j< _N_; j++){
+ _ys[j*_N_ + current_state_offset] = F_out(state[j]);
+ }
+ }
+
+
+ template <typename F_in,typename F_out>
+ __global__ void kernel_forward(const int B, const int T, const int C, const int H,
+ const F_in *__restrict__ const _r, const F_in *__restrict__ const _k, const F_in *__restrict__ const _v, const F_in *__restrict__ _w, const F_in *__restrict__ _u,
+ F_out *__restrict__ const _y)
+ {
+ const int b = blockIdx.x / H;
+ const int h = blockIdx.x % H;
+ const int i = threadIdx.x;
+ float state[_N_] = {0};
+ kernel_forward_core(B, T, C, H, b, h, i, state, _r, _k, _v, _w, _u, _y);
+ }
+
+ template <typename F_in, typename F_out>
+ __global__ void kernel_backward_101(const int B, const int T, const int C, const int H,
+ const F_in *__restrict__ const _r, const F_in *__restrict__ const _k, const F_in *__restrict__ const _v, const F_in *__restrict__ _w,
+ const F_in *__restrict__ _u, const F_out *__restrict__ const _gy,
+ F_out *__restrict__ const _gr, F_out *__restrict__ const _gu)
+ {
+ const int b = blockIdx.x / H;
+ const int h = blockIdx.x % H;
+ const int i = threadIdx.x;
+
+ __shared__ float v[_N_], gy[_N_];
+
+ const float u = float(_u[h*_N_ + i]);
+
+ float state[_N_] = {0};
+
+ const int t_0 = b*T*C + h*_N_ + i;
+ const int t_T = t_0 + T*C;
+
+ float gu = 0;
+ for (int t = t_0; t < t_T; t += C)
+ {
+ __syncthreads();
+ v[i] = float(_v[t]);
+ gy[i] = float(_gy[t]);
+ __syncthreads();
+
+ const float k = float(_k[t]);
+ const float w = __expf(-__expf(float(_w[t])));
+ float gr = 0, gu_ = 0;
+
+ #pragma unroll
+ for (int j = 0; j < _N_; j++)
+ {
+ float& s = state[j];
+ float x = k * v[j];
+
+ gr += (u * x + s) * gy[j];
+ gu_ += x * gy[j];
+ s = s * w + x;
+ }
+ _gr[t] = F_out(gr);
+ gu += float(_r[t]) * gu_;
+ }
+ _gu[b*C + h*_N_ + i] = F_out(gu);
+ }
+
+ template <typename F_in, typename F_out>
+ __global__ void kernel_backward_102(const int B, const int T, const int C, const int H,
+ const F_in *__restrict__ const _r, const F_in *__restrict__ const _k, const F_in *__restrict__ const _v,
+ const F_in *__restrict__ _w, const F_in *__restrict__ _u, const F_out *__restrict__ const _gy,
+ F_out *__restrict__ const _gk)
+ {
+ const int b = blockIdx.x / H;
+ const int h = blockIdx.x % H;
+ const int i = threadIdx.x;
+
+ __shared__ float v[_N_], gy[_N_];
+
+ const float u = float(_u[h*_N_ + i]);
+
+ float scccc[_N_] = {0};
+
+ const int t_0 = b*T*C + h*_N_ + i;
+ const int t_T_1 = t_0 + (T-1)*C;
+
+ for (int t = t_T_1; t >= t_0; t -= C)
+ {
+ __syncthreads();
+ v[i] = float(_v[t]);
+ gy[i] = float(_gy[t]);
+ __syncthreads();
+
+ const float rr = float(_r[t]);
+ const float w = __expf(-__expf(float(_w[t])));
+ float gk = 0;
+
+ #pragma unroll
+ for (int j = 0; j < _N_; j++)
+ {
+ float& s = scccc[j];
+ float x = rr * gy[j];
+
+ gk += (u * x + s) * v[j];
+ s = x + s * w;
+ }
+ _gk[t] = F_out(gk);
+ }
+ }
+
+ template <typename F_in, typename F_out>
+ __global__ void kernel_backward_103(const int B, const int T, const int C, const int H,
+ const F_in *__restrict__ const _r, const F_in *__restrict__ const _k, const F_in *__restrict__ const _v,
+ const F_in *__restrict__ _w, const F_in *__restrict__ _u, const F_out *__restrict__ const _gy,
+ F_out *__restrict__ const _gv)
+ {
+ const int b = blockIdx.x / H;
+ const int h = blockIdx.x % H;
+ const int i = threadIdx.x;
+ _u += h*_N_;
+
+ __shared__ float u_[_N_], r[_N_], k[_N_], w_[_N_];
+ __syncthreads();
+ u_[i] = float(_u[i]);
+ __syncthreads();
+
+ float sdddd[_N_] = {0};
+
+ const int t_0 = b*T*C + h*_N_ + i;
+ const int t_T_1 = t_0 + (T-1)*C;
+
+ for (int t = t_T_1; t >= t_0; t -= C)
+ {
+ __syncthreads();
+ r[i] = float(_r[t]);
+ k[i] = float(_k[t]);
+ w_[i] = __expf(-__expf(float(_w[t])));
+ __syncthreads();
+
+ const float gyy = float(_gy[t]);
+ float gv = 0;
+
+ #pragma unroll
+ for (int j = 0; j < _N_; j++)
+ {
+ float& s = sdddd[j];
+ float x = gyy * r[j];
+
+ gv += (u_[j] * x + s) * k[j];
+ s = x + s * w_[j];
+ }
+ _gv[t] = F_out(gv);
+ }
+ }
+
+ template <typename F_in, typename F_out>
+ __global__ void kernel_backward_201(const int B, const int T, const int C, const int H,
+ const F_in *__restrict__ const _r, const F_in *__restrict__ const _k, const F_in *__restrict__ const _v, const F_in *__restrict__ _w,
+ const F_in *__restrict__ _u, const F_out *__restrict__ const _gy,
+ F_out *__restrict__ const _gw)
+ {
+ const int b = blockIdx.x / H;
+ const int h = blockIdx.x % H;
+ const int i = threadIdx.x;
+
+ __shared__ float v[_N_], gy[_N_];
+ float saaaa[_N_] = {0}, sbbbb[_T_-2] = {0}, scccc[_N_] = {0};
+
+ const int t_0 = b*T*C + h*_N_ + i;
+ const int t_1 = t_0 + C;
+ const int t_2 = t_0 + 2*C;
+ const int t_T_1 = t_0 + (T-1)*C;
+
+ for (int t = t_T_1; t > t_1; t -= C)
+ {
+ __syncthreads();
+ gy[i] = float(_gy[t]);
+ v[i] = float(_v[t-2*C]);
+ __syncthreads();
+
+ const float r = float(_r[t]);
+ const float w = __expf(-__expf(float(_w[t-C])));
+ float sum = 0.0f;
+
+ #pragma unroll
+ for (int j = 0; j < _N_; j++)
+ {
+ float& s = saaaa[j];
+ float x = r * gy[j];
+ s = (s + x) * w;
+ sum += s * v[j];
+ }
+ sbbbb[(t-t_2)/C] = sum * float(_k[t-2*C]);
+ }
+
+ float sss = sbbbb[0];
+ _gw[t_0] = 0;
+ _gw[t_1] = F_out(sss * -__expf(float(_w[t_1])));
+
+ for (int t = t_2; t < t_T_1; t += C)
+ {
+ __syncthreads();
+ gy[i] = float(_gy[t]);
+ v[i] = float(_v[t-2*C]);
+ __syncthreads();
+
+ const float w = __expf(-__expf(float(_w[t-C])));
+ const float k = float(_k[t-2*C]);
+ float sum = 0.0f;
+
+ #pragma unroll
+ for (int j = 0; j < _N_; j++)
+ {
+ float& s = scccc[j];
+ float x = k * v[j];
+ s = (s + x) * w;
+ sum += s * gy[j];
+ }
+ sss += sbbbb[(t-t_1)/C] - (sum * float(_r[t]));
+ _gw[t] = F_out(sss * -__expf(float(_w[t])));
+ }
+ _gw[t_T_1] = 0;
+ }
+
+
+
+
+
+ template <typename T_in, typename T_out>
+ void HostApplyRWKVWithState(hipStream_t stream,int B, int T, int C, int H, bool S, const int32_t* state_map,
+ const T_in *input_r,const T_in *input_k,const T_in *input_v,
+ const T_in *input_w,const T_in *input_u,T_out *input_s, T_out *output_y, T_out *output_s) {
+ assert(H*_N_ == C);
+ //assert(_N_%4 == 0);
+ kernel_forward_state<<<dim3(B * H), dim3(_N_), _N_ * 4 * sizeof(float),stream>>>(B, T, C, H, S, state_map, input_r, input_k, input_v, input_w, input_u,input_s, output_y,output_s);
+
+ }
+
+
+
+ template <typename T_in, typename T_out>
+ void HostApplyRWKV(hipStream_t stream,int B, int T, int C, int H,
+ const T_in *input_r,const T_in *input_k,const T_in *input_v,
+ const T_in *input_w,const T_in *input_u,T_out *output_y) {
+ assert(H*_N_ == C);
+ //assert(_N_%4 == 0);
+ kernel_forward<<<dim3(B * H), dim3(_N_), _N_ * 4 * sizeof(float),stream>>>(B, T, C, H, input_r, input_k, input_v, input_w, input_u, output_y);
+
+ }
+ //todo 为kernel设置正确的sharememory大小
+ template <typename T_in, typename T_out>
+ void HostApplyGradient(hipStream_t stream,int B, int T, int C, int H,
+ T_in *r, T_in *k, T_in *v, T_in *w, T_in *u, T_out *gy, T_out *gr, T_out *gk, T_out *gv, T_out *gw, T_out *gu)
+ {
+ assert(H*_N_ == C);
+ kernel_backward_101<<<dim3(B * H), dim3(_N_),_N_ * 2 * sizeof(float),stream >>>(B, T, C, H, r, k, v, w, u, gy, gr, gu);
+ kernel_backward_102<<<dim3(B * H), dim3(_N_),_N_ * 2 * sizeof(float),stream >>>(B, T, C, H, r, k, v, w, u, gy, gk);
+ kernel_backward_103<<<dim3(B * H), dim3(_N_),_N_ * 4 * sizeof(float),stream >>>(B, T, C, H, r, k, v, w, u, gy, gv);
+ kernel_backward_201<<<dim3(B * H), dim3(_N_),_N_ * 2 * sizeof(float),stream >>>(B, T, C, H, r, k, v, w, u, gy, gw);
+ }
+
+ }
+
+
+ namespace gpu_ops {
+
+ void rwkv_forward_fn(hipStream_t stream, void **buffers,
+ const char *opaque,
+ std::size_t opaque_len) {
+ const WKVDescriptor &d = *UnpackDescriptor<WKVDescriptor>(opaque, opaque_len);
+
+ DISPATCH_Vector_TYPES(
+ d.x_type, d.y_type, "rwkv_forward_kernel",
+ HostApplyRWKV<input_type, output_type>(
+ stream, d.B, d.T, d.C, d.H,
+ static_cast<input_type *>(buffers[0]),static_cast<input_type *>(buffers[1]),static_cast<input_type *>(buffers[2]),
+ static_cast<input_type *>(buffers[3]),static_cast<input_type *>(buffers[4]),static_cast<output_type *>(buffers[5])
+ );
+ )
+ }
+
+
+ void rwkv_forward_with_state_fn(hipStream_t stream, void **buffers,
+ const char *opaque,
+ std::size_t opaque_len) {
+ const WKVDescriptor &d = *UnpackDescriptor<WKVDescriptor>(opaque, opaque_len);
+
+ DISPATCH_Vector_TYPES(
+ d.x_type, d.y_type, "rwkv_forward_with_state_kernel",
+ if(d.S){
+ HostApplyRWKVWithState<input_type, output_type>(
+ stream, d.B, d.T, d.C, d.H, true, static_cast<int32_t *>(buffers[0])/*map*/,
+ static_cast<input_type *>(buffers[1])/*r*/,static_cast<input_type *>(buffers[2])/*k*/,static_cast<input_type *>(buffers[3])/*v*/,
+ static_cast<input_type *>(buffers[4])/*w*/,static_cast<input_type *>(buffers[5])/*u*/,static_cast<output_type *>(buffers[6])/*s*/,
+ static_cast<output_type *>(buffers[7])/*y*/,static_cast<output_type *>(buffers[8])/*ys*/
+ );
+ }else{
+ HostApplyRWKVWithState<input_type, output_type>(
+ stream, d.B, d.T, d.C, d.H, false, nullptr,
+ static_cast<input_type *>(buffers[0])/*r*/,static_cast<input_type *>(buffers[1])/*k*/,static_cast<input_type *>(buffers[2])/*v*/,
+ static_cast<input_type *>(buffers[3])/*w*/,static_cast<input_type *>(buffers[4])/*u*/,nullptr/*s*/,
+ static_cast<output_type *>(buffers[5])/*y*/,static_cast<output_type *>(buffers[6])/*ys*/
+ );
+ }
+ )
+ }
+
+
+ void rwkv_backward_fn(hipStream_t stream, void **buffers,
+ const char *opaque,
+ std::size_t opaque_len) {
+ const WKVDescriptor &d = *UnpackDescriptor<WKVDescriptor>(opaque, opaque_len);
+
+ DISPATCH_Vector_TYPES(
+ d.x_type, d.y_type, "rwkv_backward_kernel",
+ HostApplyGradient<input_type, output_type>(
+ stream, d.B, d.T, d.C, d.H,
+ static_cast<input_type *>(buffers[0]),static_cast<input_type *>(buffers[1]),static_cast<input_type *>(buffers[2]),
+ static_cast<input_type *>(buffers[3]),static_cast<input_type *>(buffers[4]),static_cast<output_type *>(buffers[5]),
+ static_cast<output_type *>(buffers[6]),static_cast<output_type *>(buffers[7]),static_cast<output_type *>(buffers[8]),
+ static_cast<output_type *>(buffers[9]),static_cast<output_type *>(buffers[10])
+ );
+ )
+ }
+
+
+ } // namespace gpu_ops
+
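For reference, the per-head recurrence implemented by kernel_forward_core above follows directly from its inner loop (y += r_j * (u_j * k_j * v_i + s_ij); s_ij = s_ij * w_j + k_j * v_i). A minimal NumPy sketch for a single batch element and head, assuming shapes (T, N) and that w has already been mapped to exp(-exp(w)) as in the kernel:

    import numpy as np

    def wkv6_head_reference(r, k, v, w, u, state=None):
        # r, k, v, w: (T, N); u: (N,); state: (N, N) with state[i, j] accumulating k_j * v_i.
        T, N = r.shape
        s = np.zeros((N, N)) if state is None else state.astype(float).copy()
        y = np.zeros((T, N))
        for t in range(T):
            kv = np.outer(v[t], k[t])                      # kv[i, j] = v[t, i] * k[t, j]
            y[t] = ((u[None, :] * kv + s) * r[t][None, :]).sum(axis=1)
            s = s * w[t][None, :] + kv                     # decayed state update
        return y, s

The CUDA/HIP kernels operate on (B, T, H*N) buffers and keep the state in registers per thread; this sketch mirrors only the math, not the layout or performance.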
@@ -1,7 +1,6 @@
  import os
  import pybind11
  import importlib
- import sys
  import sysconfig
  import subprocess
  from functools import partial, reduce
@@ -10,12 +9,9 @@ import jax
  import jax.numpy as jnp
  from jax import core, dtypes
  from jax.core import ShapedArray
- from jax.experimental.custom_partitioning import custom_partitioning
- from jax.experimental.pjit import pjit
- from jax.interpreters import batching, mlir, xla
+ from jax.interpreters import mlir, xla
  from jax.interpreters.mlir import ir
  from jax.lib import xla_client
- from jax.sharding import Mesh, NamedSharding, PartitionSpec
  from jaxlib.hlo_helpers import custom_call
 
 
@@ -367,9 +363,10 @@ class RWKVKernelOperator:
  def _rwkv_fwd_with_state(r, k, v, w, u, init_state=None, state_map=None):
  bz = r.shape[0]
  if init_state is not None: # shape=(B,H,D,D)
- assert len(init_state.shape) in [3, 4], (
- "init_state的shape为(Batch_size,num_heads,head_size,head_size)"
- )
+ assert len(init_state.shape) in [
+ 3,
+ 4,
+ ], "init_state的shape为(Batch_size,num_heads,head_size,head_size)"
  if len(init_state.shape) == 3:
  state_map = jnp.zeros((bz,), dtype=jnp.int32)
 
@@ -390,9 +387,10 @@ class RWKVKernelOperator:
  assert False, "未实现"
  else:
  # assert state_map is not None,"请传入一个state_map,这是一个int32类型的shape为(bz,)的数组,存放的是int_state到每一维度上的映射关系"
- assert state_map.dtype in [jnp.int64, jnp.int32], (
- "state_map的数值类型必须为int32"
- )
+ assert state_map.dtype in [
+ jnp.int64,
+ jnp.int32,
+ ], "state_map的数值类型必须为int32"
  state_map = jnp.astype(state_map, jnp.int32)
  assert jnp.all(state_map >= 0) and jnp.add(state_map < bz), (
  f"state_map内为state的映射下标,因此范围为: [0,{bz})"
@@ -633,11 +631,11 @@ class RWKVKernelOperator:
  assert os.path.exists(cu_src)
  cu_dst = os.path.join(target_dir, "rwkv_kernels.hip.o")
  kernel_cmd = (
- f"hipcc -O3 --hipstdpar -xhip -fopenmp -ffast-math"
- + f" -munsafe-fp-atomics -enable-vectorize-compares"
+ "hipcc -O3 --hipstdpar -xhip -fopenmp -ffast-math"
+ + " -munsafe-fp-atomics -enable-vectorize-compares"
  + f" -I{cuda_lib_dir} -I{pybind11.get_include()}"
- + f" -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2"
- + f" --gpu-max-threads-per-block=120"
+ + " -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2"
+ + " --gpu-max-threads-per-block=120"
  + f" -c {cu_src} -o {cu_dst} -D _N_={head_size} -D _T_={max_sequence_length}"
  )
  else:
@@ -645,8 +643,8 @@ class RWKVKernelOperator:
  assert os.path.exists(cu_src)
  cu_dst = os.path.join(target_dir, "rwkv_kernels.cu.o")
  kernel_cmd = (
- f"nvcc --threads 4 -Xcompiler -Wall -ldl --expt-relaxed-constexpr -O3 -DNDEBUG -Xcompiler -O3"
- + f" --generate-code=arch=compute_70,code=[compute_70,sm_70] --generate-code=arch=compute_75,code=[compute_75,sm_75] --generate-code=arch=compute_80,code=[compute_80,sm_80] --generate-code=arch=compute_86,code=[compute_86,sm_86]"
+ "nvcc --threads 4 -Xcompiler -Wall -ldl --expt-relaxed-constexpr -O3 -DNDEBUG -Xcompiler -O3"
+ + " --generate-code=arch=compute_70,code=[compute_70,sm_70] --generate-code=arch=compute_75,code=[compute_75,sm_75] --generate-code=arch=compute_80,code=[compute_80,sm_80] --generate-code=arch=compute_86,code=[compute_86,sm_86]"
  + f" -Xcompiler=-fPIC -Xcompiler=-fvisibility=hidden -x cu -c {cu_src} -o {cu_dst} -D _N_={head_size} -D _T_={max_sequence_length}"
  )
  build_cmds.append(kernel_cmd)
@@ -660,14 +658,14 @@ class RWKVKernelOperator:
  if use_rocm:
  cpp_cmd = (
  f"c++ -I{cuda_lib_dir} -I{pybind11.get_include()} {get_cflags()}"
- + f" -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2"
- + f" -O3 -DNDEBUG -O3 -fPIC -fvisibility=hidden -flto -fno-fat-lto-objects"
+ + " -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2"
+ + " -O3 -DNDEBUG -O3 -fPIC -fvisibility=hidden -flto -fno-fat-lto-objects"
  + f" -o {cpp_dst} -c {cpp_src}"
  )
  else:
  cpp_cmd = (
  f"c++ -I{cuda_lib_dir} -I{pybind11.get_include()} {get_cflags()}"
- + f" -O3 -DNDEBUG -O3 -fPIC -fvisibility=hidden -flto -fno-fat-lto-objects"
+ + " -O3 -DNDEBUG -O3 -fPIC -fvisibility=hidden -flto -fno-fat-lto-objects"
  + f" -o {cpp_dst} -c {cpp_src}"
  )
  build_cmds.append(cpp_cmd)
@@ -677,13 +675,13 @@ class RWKVKernelOperator:
  assembly_cmd = (
  f"c++ -fPIC -O3 -DNDEBUG -O3 -flto -shared -o {so_dst} {cpp_dst} {cu_dst}"
  + f" -fPIC -I{cuda_lib_dir} -I{pybind11.get_include()} {get_cflags()}"
- + f" -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2"
- + f" -L/opt/rocm/lib -lamdhip64 -lpthread -ldl"
+ + " -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2"
+ + " -L/opt/rocm/lib -lamdhip64 -lpthread -ldl"
  )
  else:
  assembly_cmd = (
  f"c++ -fPIC -O3 -DNDEBUG -O3 -flto -shared -o {so_dst} {cpp_dst} {cu_dst}"
- + f" -L/usr/local/cuda/lib64 -lcudadevrt -lcudart_static -lrt -lpthread -ldl"
+ + " -L/usr/local/cuda/lib64 -lcudadevrt -lcudart_static -lrt -lpthread -ldl"
  )
  build_cmds.append(assembly_cmd)