rwkv-ops 0.2.2__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of rwkv-ops might be problematic.

Files changed (31)
  1. rwkv_ops/__init__.py +5 -6
  2. rwkv_ops/rwkv6_kernel/__init__.py +0 -6
  3. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/gpu_ops.cpp +44 -0
  4. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernel_helpers.h +64 -0
  5. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernels.h +56 -0
  6. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/pybind11_kernel_helpers.h +41 -0
  7. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/rwkv_kernels.cu +512 -0
  8. rwkv_ops/rwkv6_kernel/jax_kernel_hip/gpu_ops.cpp +44 -0
  9. rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernel_helpers.h +64 -0
  10. rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernels.h +56 -0
  11. rwkv_ops/rwkv6_kernel/jax_kernel_hip/pybind11_kernel_helpers.h +41 -0
  12. rwkv_ops/rwkv6_kernel/jax_kernel_hip/rwkv_kernels.hip +514 -0
  13. rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py +21 -23
  14. rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py +14 -10
  15. rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_cuda.cu +397 -0
  16. rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_op.cpp +93 -0
  17. rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py +4 -4
  18. rwkv_ops/rwkv7_kernel/__init__.py +77 -29
  19. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/CMakeLists.txt +42 -0
  20. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_ffi.cu +279 -0
  21. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_jax.py +237 -0
  22. rwkv_ops/rwkv7_kernel/jax_op.py +6 -5
  23. rwkv_ops/rwkv7_kernel/native_keras_op.py +5 -6
  24. rwkv_ops/rwkv7_kernel/tf_eager_kernel.py +123 -0
  25. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_cuda.cu +165 -0
  26. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_op.cpp +35 -0
  27. {rwkv_ops-0.2.2.dist-info → rwkv_ops-0.3.0.dist-info}/METADATA +28 -27
  28. {rwkv_ops-0.2.2.dist-info → rwkv_ops-0.3.0.dist-info}/RECORD +30 -13
  29. {rwkv_ops-0.2.2.dist-info → rwkv_ops-0.3.0.dist-info}/WHEEL +1 -2
  30. rwkv_ops-0.2.2.dist-info/top_level.txt +0 -1
  31. {rwkv_ops-0.2.2.dist-info → rwkv_ops-0.3.0.dist-info/licenses}/LICENSE.txt +0 -0
rwkv_ops/rwkv6_kernel/jax_kernel_cuda/rwkv_kernels.cu
@@ -0,0 +1,512 @@
+ /* Copyright 2024 The JAX Authors.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+
+ #include "kernel_helpers.h"
+ #include "kernels.h"
+ #include "stdio.h"
+ #include <cuda_bf16.h>
+ #include <cuda_fp16.h>
+ #include <iostream>
+ #include <assert.h>
+ namespace {
+
+ #define DISPATCH_Vector_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
+   switch (TYPEIN) { \
+   case gpu_ops::ElementType::F32: { \
+     using input_type = float; \
+     switch (TYPEOUT) { \
+     case gpu_ops::ElementType::F32: { \
+       using output_type = float; \
+       __VA_ARGS__; \
+       break; \
+     } \
+     case gpu_ops::ElementType::F16: { \
+       using output_type = __half; \
+       __VA_ARGS__; \
+       break; \
+     } \
+     case gpu_ops::ElementType::BF16: { \
+       using output_type = __nv_bfloat16; \
+       __VA_ARGS__; \
+       break; \
+     } \
+     default: \
+       break; \
+     } \
+     break; \
+   } \
+   case gpu_ops::ElementType::F16: { \
+     using input_type = __half; \
+     switch (TYPEOUT) { \
+     case gpu_ops::ElementType::F32: { \
+       using output_type = float; \
+       __VA_ARGS__; \
+       break; \
+     } \
+     case gpu_ops::ElementType::F16: { \
+       using output_type = __half; \
+       __VA_ARGS__; \
+       break; \
+     } \
+     case gpu_ops::ElementType::BF16: { \
+       using output_type = __nv_bfloat16; \
+       __VA_ARGS__; \
+       break; \
+     } \
+     default: \
+       break; \
+     } \
+     break; \
+   } \
+   case gpu_ops::ElementType::BF16: { \
+     using input_type = __nv_bfloat16; \
+     switch (TYPEOUT) { \
+     case gpu_ops::ElementType::F32: { \
+       using output_type = float; \
+       __VA_ARGS__; \
+       break; \
+     } \
+     case gpu_ops::ElementType::F16: { \
+       using output_type = __half; \
+       __VA_ARGS__; \
+       break; \
+     } \
+     case gpu_ops::ElementType::BF16: { \
+       using output_type = __nv_bfloat16; \
+       __VA_ARGS__; \
+       break; \
+     } \
+     default: \
+       break; \
+     } \
+     break; \
+   } \
+   default: \
+     break; \
+   }
+
+
+ static_assert(_N_ % 4 == 0, "the head size (_N_) must be a multiple of 4.");
+
+
+
+
+
+ template <typename F_in, typename F_out>
+ __device__ void kernel_forward_core(const int B, const int T, const int C, const int H, const int b, const int h, const int i, const float* state,
+         const F_in *__restrict__ const _r, const F_in *__restrict__ const _k, const F_in *__restrict__ const _v, const F_in *__restrict__ _w, const F_in *__restrict__ _u,
+         F_out *__restrict__ const _y)
+ {
+
+     _u += h*_N_;
+
+     __shared__ float r[_N_], k[_N_], u[_N_], w[_N_];
+     //float state[_N_] = {0};
+
+     __syncthreads();
+     u[i] = float(_u[i]);
+     __syncthreads();
+
+     for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C)
+     {
+         __syncthreads();
+         w[i] = __expf(-__expf(float(_w[t])));
+         r[i] = float(_r[t]);
+         k[i] = float(_k[t]);
+         __syncthreads();
+
+         const float v = float(_v[t]);
+         float y = 0;
+
+         #pragma unroll
+         for (int j = 0; j < _N_; j += 4)
+         {
+             const float4& r_ = (float4&)(r[j]);
+             const float4& k_ = (float4&)(k[j]);
+             const float4& w_ = (float4&)(w[j]);
+             const float4& u_ = (float4&)(u[j]);
+             float4& s = (float4&)(state[j]);
+             float4 x;
+
+             x.x = k_.x * v;
+             x.y = k_.y * v;
+             x.z = k_.z * v;
+             x.w = k_.w * v;
+
+             y += r_.x * (u_.x * x.x + s.x);
+             y += r_.y * (u_.y * x.y + s.y);
+             y += r_.z * (u_.z * x.z + s.z);
+             y += r_.w * (u_.w * x.w + s.w);
+
+             s.x = s.x * w_.x + x.x;
+             s.y = s.y * w_.y + x.y;
+             s.z = s.z * w_.z + x.z;
+             s.w = s.w * w_.w + x.w;
+         }
+         _y[t] = F_out(y);
+     }
+ }
+
+
+
+
+ template <typename F_in, typename F_out>
+ __global__ void kernel_forward_state(const int B, const int T, const int C, const int H, const bool is_custom_state, const int32_t* map,
+         const F_in *__restrict__ const _r, const F_in *__restrict__ const _k, const F_in *__restrict__ const _v, const F_in *__restrict__ _w, const F_in *__restrict__ _u,
+         const F_out *__restrict__ _s, F_out *__restrict__ const _y, F_out *__restrict__ const _ys)
+ {
+     const int b = blockIdx.x / H;
+     const int h = blockIdx.x % H;
+     const int i = threadIdx.x;
+     float state[_N_] = {0};
+     if (is_custom_state) {
+         assert(map[b] >= 0 && map[b] < B);
+
+         const int64_t input_state_offset = map[b] * H * _N_ *_N_ + h * _N_ * _N_ + i;
+
+         for (int j = 0; j < _N_; j++) {
+             state[j] = float(_s[j*_N_ + input_state_offset]);
+         }
+     }
+
+     const int64_t current_state_offset = b * H * _N_ *_N_ + h * _N_ * _N_ + i;
+
+     kernel_forward_core(B, T, C, H, b, h, i, state, _r, _k, _v, _w, _u, _y);
+     for (int j = 0; j < _N_; j++) {
+         _ys[j*_N_ + current_state_offset] = F_out(state[j]);
+     }
+ }
+
+
+ template <typename F_in, typename F_out>
+ __global__ void kernel_forward(const int B, const int T, const int C, const int H,
+         const F_in *__restrict__ const _r, const F_in *__restrict__ const _k, const F_in *__restrict__ const _v, const F_in *__restrict__ _w, const F_in *__restrict__ _u,
+         F_out *__restrict__ const _y)
+ {
+     const int b = blockIdx.x / H;
+     const int h = blockIdx.x % H;
+     const int i = threadIdx.x;
+     float state[_N_] = {0};
+     kernel_forward_core(B, T, C, H, b, h, i, state, _r, _k, _v, _w, _u, _y);
+ }
+
+ template <typename F_in, typename F_out>
+ __global__ void kernel_backward_101(const int B, const int T, const int C, const int H,
+         const F_in *__restrict__ const _r, const F_in *__restrict__ const _k, const F_in *__restrict__ const _v, const F_in *__restrict__ _w,
+         const F_in *__restrict__ _u, const F_out *__restrict__ const _gy,
+         F_out *__restrict__ const _gr, F_out *__restrict__ const _gu)
+ {
+     const int b = blockIdx.x / H;
+     const int h = blockIdx.x % H;
+     const int i = threadIdx.x;
+
+     __shared__ float v[_N_], gy[_N_];
+
+     const float u = float(_u[h*_N_ + i]);
+
+     float state[_N_] = {0};
+
+     const int t_0 = b*T*C + h*_N_ + i;
+     const int t_T = t_0 + T*C;
+
+     float gu = 0;
+     for (int t = t_0; t < t_T; t += C)
+     {
+         __syncthreads();
+         v[i] = float(_v[t]);
+         gy[i] = float(_gy[t]);
+         __syncthreads();
+
+         const float k = float(_k[t]);
+         const float w = __expf(-__expf(float(_w[t])));
+         float gr = 0, gu_ = 0;
+
+         #pragma unroll
+         for (int j = 0; j < _N_; j++)
+         {
+             float& s = state[j];
+             float x = k * v[j];
+
+             gr += (u * x + s) * gy[j];
+             gu_ += x * gy[j];
+             s = s * w + x;
+         }
+         _gr[t] = F_out(gr);
+         gu += float(_r[t]) * gu_;
+     }
+     _gu[b*C + h*_N_ + i] = F_out(gu);
+ }
+
+ template <typename F_in, typename F_out>
+ __global__ void kernel_backward_102(const int B, const int T, const int C, const int H,
+         const F_in *__restrict__ const _r, const F_in *__restrict__ const _k, const F_in *__restrict__ const _v,
+         const F_in *__restrict__ _w, const F_in *__restrict__ _u, const F_out *__restrict__ const _gy,
+         F_out *__restrict__ const _gk)
+ {
+     const int b = blockIdx.x / H;
+     const int h = blockIdx.x % H;
+     const int i = threadIdx.x;
+
+     __shared__ float v[_N_], gy[_N_];
+
+     const float u = float(_u[h*_N_ + i]);
+
+     float scccc[_N_] = {0};
+
+     const int t_0 = b*T*C + h*_N_ + i;
+     const int t_T_1 = t_0 + (T-1)*C;
+
+     for (int t = t_T_1; t >= t_0; t -= C)
+     {
+         __syncthreads();
+         v[i] = float(_v[t]);
+         gy[i] = float(_gy[t]);
+         __syncthreads();
+
+         const float rr = float(_r[t]);
+         const float w = __expf(-__expf(float(_w[t])));
+         float gk = 0;
+
+         #pragma unroll
+         for (int j = 0; j < _N_; j++)
+         {
+             float& s = scccc[j];
+             float x = rr * gy[j];
+
+             gk += (u * x + s) * v[j];
+             s = x + s * w;
+         }
+         _gk[t] = F_out(gk);
+     }
+ }
+
+ template <typename F_in, typename F_out>
+ __global__ void kernel_backward_103(const int B, const int T, const int C, const int H,
+         const F_in *__restrict__ const _r, const F_in *__restrict__ const _k, const F_in *__restrict__ const _v,
+         const F_in *__restrict__ _w, const F_in *__restrict__ _u, const F_out *__restrict__ const _gy,
+         F_out *__restrict__ const _gv)
+ {
+     const int b = blockIdx.x / H;
+     const int h = blockIdx.x % H;
+     const int i = threadIdx.x;
+     _u += h*_N_;
+
+     __shared__ float u_[_N_], r[_N_], k[_N_], w_[_N_];
+     __syncthreads();
+     u_[i] = float(_u[i]);
+     __syncthreads();
+
+     float sdddd[_N_] = {0};
+
+     const int t_0 = b*T*C + h*_N_ + i;
+     const int t_T_1 = t_0 + (T-1)*C;
+
+     for (int t = t_T_1; t >= t_0; t -= C)
+     {
+         __syncthreads();
+         r[i] = float(_r[t]);
+         k[i] = float(_k[t]);
+         w_[i] = __expf(-__expf(float(_w[t])));
+         __syncthreads();
+
+         const float gyy = float(_gy[t]);
+         float gv = 0;
+
+         #pragma unroll
+         for (int j = 0; j < _N_; j++)
+         {
+             float& s = sdddd[j];
+             float x = gyy * r[j];
+
+             gv += (u_[j] * x + s) * k[j];
+             s = x + s * w_[j];
+         }
+         _gv[t] = F_out(gv);
+     }
+ }
+
+ template <typename F_in, typename F_out>
+ __global__ void kernel_backward_201(const int B, const int T, const int C, const int H,
+         const F_in *__restrict__ const _r, const F_in *__restrict__ const _k, const F_in *__restrict__ const _v, const F_in *__restrict__ _w,
+         const F_in *__restrict__ _u, const F_out *__restrict__ const _gy,
+         F_out *__restrict__ const _gw)
+ {
+     const int b = blockIdx.x / H;
+     const int h = blockIdx.x % H;
+     const int i = threadIdx.x;
+
+     __shared__ float v[_N_], gy[_N_];
+     float saaaa[_N_] = {0}, sbbbb[_T_-2] = {0}, scccc[_N_] = {0};
+
+     const int t_0 = b*T*C + h*_N_ + i;
+     const int t_1 = t_0 + C;
+     const int t_2 = t_0 + 2*C;
+     const int t_T_1 = t_0 + (T-1)*C;
+
+     for (int t = t_T_1; t > t_1; t -= C)
+     {
+         __syncthreads();
+         gy[i] = float(_gy[t]);
+         v[i] = float(_v[t-2*C]);
+         __syncthreads();
+
+         const float r = float(_r[t]);
+         const float w = __expf(-__expf(float(_w[t-C])));
+         float sum = 0.0f;
+
+         #pragma unroll
+         for (int j = 0; j < _N_; j++)
+         {
+             float& s = saaaa[j];
+             float x = r * gy[j];
+             s = (s + x) * w;
+             sum += s * v[j];
+         }
+         sbbbb[(t-t_2)/C] = sum * float(_k[t-2*C]);
+     }
+
+     float sss = sbbbb[0];
+     _gw[t_0] = 0;
+     _gw[t_1] = F_out(sss * -__expf(float(_w[t_1])));
+
+     for (int t = t_2; t < t_T_1; t += C)
+     {
+         __syncthreads();
+         gy[i] = float(_gy[t]);
+         v[i] = float(_v[t-2*C]);
+         __syncthreads();
+
+         const float w = __expf(-__expf(float(_w[t-C])));
+         const float k = float(_k[t-2*C]);
+         float sum = 0.0f;
+
+         #pragma unroll
+         for (int j = 0; j < _N_; j++)
+         {
+             float& s = scccc[j];
+             float x = k * v[j];
+             s = (s + x) * w;
+             sum += s * gy[j];
+         }
+         sss += sbbbb[(t-t_1)/C] - (sum * float(_r[t]));
+         _gw[t] = F_out(sss * -__expf(float(_w[t])));
+     }
+     _gw[t_T_1] = 0;
+ }
+
+
+
+
+
+ template <typename T_in, typename T_out>
+ void HostApplyRWKVWithState(cudaStream_t stream, int B, int T, int C, int H, bool S, const int32_t* state_map,
+         const T_in *input_r, const T_in *input_k, const T_in *input_v,
+         const T_in *input_w, const T_in *input_u, T_out *input_s, T_out *output_y, T_out *output_s) {
+     assert(H*_N_ == C);
+     //assert(_N_%4 == 0);
+     kernel_forward_state<<<dim3(B * H), dim3(_N_), _N_ * 4 * sizeof(float), stream>>>(B, T, C, H, S, state_map, input_r, input_k, input_v, input_w, input_u, input_s, output_y, output_s);
+
+ }
+
+
+
+ template <typename T_in, typename T_out>
+ void HostApplyRWKV(cudaStream_t stream, int B, int T, int C, int H,
+         const T_in *input_r, const T_in *input_k, const T_in *input_v,
+         const T_in *input_w, const T_in *input_u, T_out *output_y) {
+     assert(H*_N_ == C);
+     //assert(_N_%4 == 0);
+     kernel_forward<<<dim3(B * H), dim3(_N_), _N_ * 4 * sizeof(float), stream>>>(B, T, C, H, input_r, input_k, input_v, input_w, input_u, output_y);
+
+ }
+ // TODO: set the correct shared-memory size for the kernel launches
+ template <typename T_in, typename T_out>
+ void HostApplyGradient(cudaStream_t stream, int B, int T, int C, int H,
+         T_in *r, T_in *k, T_in *v, T_in *w, T_in *u, T_out *gy, T_out *gr, T_out *gk, T_out *gv, T_out *gw, T_out *gu)
+ {
+     assert(H*_N_ == C);
+     kernel_backward_101<<<dim3(B * H), dim3(_N_), _N_ * 2 * sizeof(float), stream>>>(B, T, C, H, r, k, v, w, u, gy, gr, gu);
+     kernel_backward_102<<<dim3(B * H), dim3(_N_), _N_ * 2 * sizeof(float), stream>>>(B, T, C, H, r, k, v, w, u, gy, gk);
+     kernel_backward_103<<<dim3(B * H), dim3(_N_), _N_ * 4 * sizeof(float), stream>>>(B, T, C, H, r, k, v, w, u, gy, gv);
+     kernel_backward_201<<<dim3(B * H), dim3(_N_), _N_ * 2 * sizeof(float), stream>>>(B, T, C, H, r, k, v, w, u, gy, gw);
+ }
+
+ }
+
+
+ namespace gpu_ops {
+
+ void rwkv_forward_fn(cudaStream_t stream, void **buffers,
+                      const char *opaque,
+                      std::size_t opaque_len) {
+     const WKVDescriptor &d = *UnpackDescriptor<WKVDescriptor>(opaque, opaque_len);
+
+     DISPATCH_Vector_TYPES(
+         d.x_type, d.y_type, "rwkv_forward_kernel",
+         HostApplyRWKV<input_type, output_type>(
+             stream, d.B, d.T, d.C, d.H,
+             static_cast<input_type *>(buffers[0]), static_cast<input_type *>(buffers[1]), static_cast<input_type *>(buffers[2]),
+             static_cast<input_type *>(buffers[3]), static_cast<input_type *>(buffers[4]), static_cast<output_type *>(buffers[5])
+         );
+     )
+ }
+
+
+ void rwkv_forward_with_state_fn(cudaStream_t stream, void **buffers,
+                                 const char *opaque,
+                                 std::size_t opaque_len) {
+     const WKVDescriptor &d = *UnpackDescriptor<WKVDescriptor>(opaque, opaque_len);
+
+     DISPATCH_Vector_TYPES(
+         d.x_type, d.y_type, "rwkv_forward_with_state_kernel",
+         if (d.S) {
+             HostApplyRWKVWithState<input_type, output_type>(
+                 stream, d.B, d.T, d.C, d.H, true, static_cast<int32_t *>(buffers[0])/*map*/,
+                 static_cast<input_type *>(buffers[1])/*r*/, static_cast<input_type *>(buffers[2])/*k*/, static_cast<input_type *>(buffers[3])/*v*/,
+                 static_cast<input_type *>(buffers[4])/*w*/, static_cast<input_type *>(buffers[5])/*u*/, static_cast<output_type *>(buffers[6])/*s*/,
+                 static_cast<output_type *>(buffers[7])/*y*/, static_cast<output_type *>(buffers[8])/*ys*/
+             );
+         } else {
+             HostApplyRWKVWithState<input_type, output_type>(
+                 stream, d.B, d.T, d.C, d.H, false, nullptr,
+                 static_cast<input_type *>(buffers[0])/*r*/, static_cast<input_type *>(buffers[1])/*k*/, static_cast<input_type *>(buffers[2])/*v*/,
+                 static_cast<input_type *>(buffers[3])/*w*/, static_cast<input_type *>(buffers[4])/*u*/, nullptr/*s*/,
+                 static_cast<output_type *>(buffers[5])/*y*/, static_cast<output_type *>(buffers[6])/*ys*/
+             );
+         }
+     )
+ }
+
+
+ void rwkv_backward_fn(cudaStream_t stream, void **buffers,
+                       const char *opaque,
+                       std::size_t opaque_len) {
+     const WKVDescriptor &d = *UnpackDescriptor<WKVDescriptor>(opaque, opaque_len);
+
+     DISPATCH_Vector_TYPES(
+         d.x_type, d.y_type, "rwkv_backward_kernel",
+         HostApplyGradient<input_type, output_type>(
+             stream, d.B, d.T, d.C, d.H,
+             static_cast<input_type *>(buffers[0]), static_cast<input_type *>(buffers[1]), static_cast<input_type *>(buffers[2]),
+             static_cast<input_type *>(buffers[3]), static_cast<input_type *>(buffers[4]), static_cast<output_type *>(buffers[5]),
+             static_cast<output_type *>(buffers[6]), static_cast<output_type *>(buffers[7]), static_cast<output_type *>(buffers[8]),
+             static_cast<output_type *>(buffers[9]), static_cast<output_type *>(buffers[10])
+         );
+     )
+ }
+
+
+ } // namespace gpu_ops
+
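For reference, the per-head recurrence that `kernel_forward_core` implements (read directly off the inner loop above: thread i owns one output channel of its head, j runs over the head dimension N = `_N_`, and the decay is the doubly exponentiated `_w`) is:

\begin{aligned}
w_t[j] &= \exp\bigl(-\exp(\hat{w}_t[j])\bigr) \\
y_t[i] &= \sum_{j=0}^{N-1} r_t[j]\,\bigl(u[j]\,k_t[j]\,v_t[i] + S_{t-1}[i,j]\bigr) \\
S_t[i,j] &= w_t[j]\,S_{t-1}[i,j] + k_t[j]\,v_t[i]
\end{aligned}

with the state S starting at zero in `kernel_forward`, or at the incoming state selected by `map[b]` in `kernel_forward_state`.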
rwkv_ops/rwkv6_kernel/jax_kernel_hip/gpu_ops.cpp
@@ -0,0 +1,44 @@
+ /* Copyright 2024 The JAX Authors.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+
+ #include "kernels.h"
+ #include "pybind11_kernel_helpers.h"
+
+ namespace {
+ pybind11::dict WKVRegistrations() {
+   pybind11::dict dict;
+   dict["wkv_forward"] =
+       gpu_ops::EncapsulateFunction(gpu_ops::rwkv_forward_fn);
+   dict["wkv_backward"] =
+       gpu_ops::EncapsulateFunction(gpu_ops::rwkv_backward_fn);
+   dict["wkv_forward_with_state"] =
+       gpu_ops::EncapsulateFunction(gpu_ops::rwkv_forward_with_state_fn);
+   return dict;
+ }
+
+ PYBIND11_MODULE(gpu_ops, m) {
+   m.def("get_rwkv_registrations", &WKVRegistrations);
+   m.def("create_rwkv_descriptor",
+         [](int B, int T, int C, int H, bool S, gpu_ops::ElementType input_type, gpu_ops::ElementType output_type) {
+           return gpu_ops::PackDescriptor(gpu_ops::WKVDescriptor{B, T, C, H, S, input_type, output_type});
+         });
+
+   pybind11::enum_<gpu_ops::ElementType>(m, "ElementType")
+       .value("BF16", gpu_ops::ElementType::BF16)
+       .value("F16", gpu_ops::ElementType::F16)
+       .value("F32", gpu_ops::ElementType::F32);
+
+ }
+ } // namespace
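These registrations follow the usual JAX GPU custom-call pattern: `get_rwkv_registrations` returns capsules for the three entry points, and `create_rwkv_descriptor` packs the problem size into the opaque string that `UnpackDescriptor` reads on the C++ side. A minimal sketch of how the Python side might consume the compiled module; the `xla_client` registration calls, the temporary names, and the example shapes are assumptions for illustration, not code taken from this package:

# Hypothetical usage sketch; gpu_ops is the compiled PYBIND11_MODULE defined above.
from jax.lib import xla_client

import gpu_ops

# Expose wkv_forward / wkv_backward / wkv_forward_with_state to XLA as custom-call targets.
for _name, _capsule in gpu_ops.get_rwkv_registrations().items():
    xla_client.register_custom_call_target(_name, _capsule, platform="gpu")

# Pack (B, T, C, H, S) and the element types into the opaque descriptor.
B, T, C, H, has_state = 1, 16, 64, 1, False  # example shapes only
opaque = gpu_ops.create_rwkv_descriptor(
    B, T, C, H, has_state,
    gpu_ops.ElementType.BF16,  # input element type
    gpu_ops.ElementType.F32,   # output element type
)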
rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernel_helpers.h
@@ -0,0 +1,64 @@
+ /* Copyright 2024 The JAX Authors.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+
+ // This header is not specific to our application and you'll probably want
+ // something like this for any extension you're building. This includes the
+ // infrastructure needed to serialize descriptors that are used with the
+ // "opaque" parameter of the GPU custom call. In our example we'll use this
+ // parameter to pass the size of our problem.
+
+ #ifndef _GPU_OPS_KERNEL_HELPERS_H_
+ #define _GPU_OPS_KERNEL_HELPERS_H_
+
+ #include <cstdint>
+ #include <stdexcept>
+ #include <string>
+ #include <type_traits>
+
+ #define JAX_APEX_WARP_SIZE 32
+
+ namespace gpu_ops {
+
+ // https://en.cppreference.com/w/cpp/numeric/bit_cast
+ template <class To, class From>
+ typename std::enable_if<sizeof(To) == sizeof(From) &&
+                             std::is_trivially_copyable<From>::value &&
+                             std::is_trivially_copyable<To>::value,
+                         To>::type
+ bit_cast(const From &src) noexcept {
+   static_assert(std::is_trivially_constructible<To>::value,
+                 "This implementation additionally requires destination type to "
+                 "be trivially constructible");
+
+   To dst;
+   memcpy(&dst, &src, sizeof(To));
+   return dst;
+ }
+
+ template <typename T> std::string PackDescriptorAsString(const T &descriptor) {
+   return std::string(bit_cast<const char *>(&descriptor), sizeof(T));
+ }
+
+ template <typename T>
+ const T *UnpackDescriptor(const char *opaque, std::size_t opaque_len) {
+   if (opaque_len != sizeof(T)) {
+     throw std::runtime_error("Invalid opaque object size");
+   }
+   return bit_cast<const T *>(opaque);
+ }
+
+ } // namespace gpu_ops
+
+ #endif
rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernels.h
@@ -0,0 +1,56 @@
+ /* Copyright 2024 The JAX Authors.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+
+ #ifndef _GPU_OPS_KERNELS_H_
+ #define _GPU_OPS_KERNELS_H_
+
+ #include <hip/hip_runtime_api.h>
+
+ #include <cstddef>
+ #include <cstdint>
+
+ #ifndef _N_
+ #define _N_ 8
+ #endif
+ #ifndef _T_
+ #define _T_ 16
+ #endif
+ namespace gpu_ops {
+
+ enum ElementType { BF16, F16, F32 };
+
+ struct WKVDescriptor {
+   int B;
+   int T;
+   int C;
+   int H;
+   bool S;
+   ElementType x_type;
+   ElementType y_type;
+ };
+
+ void rwkv_forward_fn(hipStream_t stream, void **buffers,
+                      const char *opaque,
+                      std::size_t opaque_len);
+ void rwkv_backward_fn(hipStream_t stream, void **buffers,
+                       const char *opaque,
+                       std::size_t opaque_len);
+
+ void rwkv_forward_with_state_fn(hipStream_t stream, void **buffers,
+                                 const char *opaque,
+                                 std::size_t opaque_len);
+ } // namespace gpu_ops
+
+ #endif