rwkv-ops 0.2.2__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff shows the changes between publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
Potentially problematic release.
This version of rwkv-ops might be problematic.
- rwkv_ops/__init__.py +5 -6
- rwkv_ops/rwkv6_kernel/__init__.py +0 -6
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/gpu_ops.cpp +44 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernel_helpers.h +64 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernels.h +56 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/pybind11_kernel_helpers.h +41 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/rwkv_kernels.cu +512 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/gpu_ops.cpp +44 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernel_helpers.h +64 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernels.h +56 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/pybind11_kernel_helpers.h +41 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/rwkv_kernels.hip +514 -0
- rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py +21 -23
- rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py +14 -10
- rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_cuda.cu +397 -0
- rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_op.cpp +93 -0
- rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py +4 -4
- rwkv_ops/rwkv7_kernel/__init__.py +80 -29
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel/CMakeLists.txt +42 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_ffi.cu +279 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_jax.py +237 -0
- rwkv_ops/rwkv7_kernel/jax_op.py +6 -5
- rwkv_ops/rwkv7_kernel/native_keras_op.py +5 -6
- rwkv_ops/rwkv7_kernel/tf_eager_kernel.py +123 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_cuda.cu +165 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_op.cpp +35 -0
- {rwkv_ops-0.2.2.dist-info → rwkv_ops-0.3.1.dist-info}/METADATA +28 -27
- {rwkv_ops-0.2.2.dist-info → rwkv_ops-0.3.1.dist-info}/RECORD +30 -13
- {rwkv_ops-0.2.2.dist-info → rwkv_ops-0.3.1.dist-info}/WHEEL +1 -2
- rwkv_ops-0.2.2.dist-info/top_level.txt +0 -1
- {rwkv_ops-0.2.2.dist-info → rwkv_ops-0.3.1.dist-info/licenses}/LICENSE.txt +0 -0
pybind11_kernel_helpers.h (new file):

```diff
@@ -0,0 +1,41 @@
+/* Copyright 2024 The JAX Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This header extends kernel_helpers.h with the pybind11 specific interface to
+// serializing descriptors. It also adds a pybind11 function for wrapping our
+// custom calls in a Python capsule. This is separate from kernel_helpers so
+// that the CUDA code itself doesn't include pybind11. I don't think that this
+// is strictly necessary, but they do it in jaxlib, so let's do it here too.
+
+#ifndef _GPU_OPS_PYBIND11_KERNEL_HELPERS_H_
+#define _GPU_OPS_PYBIND11_KERNEL_HELPERS_H_
+
+#include <pybind11/pybind11.h>
+
+#include "kernel_helpers.h"
+
+namespace gpu_ops {
+
+template <typename T> pybind11::bytes PackDescriptor(const T &descriptor) {
+  return pybind11::bytes(PackDescriptorAsString(descriptor));
+}
+
+template <typename T> pybind11::capsule EncapsulateFunction(T *fn) {
+  return pybind11::capsule(bit_cast<void *>(fn), "xla._CUSTOM_CALL_TARGET");
+}
+
+} // namespace gpu_ops
+
+#endif
```
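For context, the capsules produced by `EncapsulateFunction` are what the Python side hands to XLA as custom-call targets. A minimal sketch of that registration pattern, assuming the compiled extension is importable as `gpu_ops` and exposes a `registrations()` dict (an assumed convention from the jaxlib custom-call examples, not this package's confirmed API):

```python
# Sketch only: `gpu_ops` is the compiled pybind11 extension; a
# `registrations()` dict of target-name -> PyCapsule is an assumption,
# not this package's confirmed export.
import gpu_ops
from jax.lib import xla_client

for name, capsule in gpu_ops.registrations().items():
    # Each capsule wraps a raw function pointer tagged
    # "xla._CUSTOM_CALL_TARGET", matching EncapsulateFunction above.
    xla_client.register_custom_call_target(name, capsule, platform="gpu")
```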
rwkv_ops/rwkv6_kernel/jax_kernel_hip/rwkv_kernels.hip (new file):

```diff
@@ -0,0 +1,514 @@
+/* Copyright 2024 The JAX Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+
+#include <hip/hip_runtime.h>
+#include "kernel_helpers.h"
+#include "kernels.h"
+#include "stdio.h"
+#include <hip/hip_bf16.h>
+#include <hip/hip_fp16.h>
+#include <iostream>
+#include <assert.h>
+namespace {
+
+#define DISPATCH_Vector_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
+  switch (TYPEIN) { \
+  case gpu_ops::ElementType::F32: { \
+    using input_type = float; \
+    switch (TYPEOUT) { \
+    case gpu_ops::ElementType::F32: { \
+      using output_type = float; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    case gpu_ops::ElementType::F16: { \
+      using output_type = __half; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    case gpu_ops::ElementType::BF16: { \
+      using output_type = hip_bfloat16; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    default: \
+      break; \
+    } \
+    break; \
+  } \
+  case gpu_ops::ElementType::F16: { \
+    using input_type = __half; \
+    switch (TYPEOUT) { \
+    case gpu_ops::ElementType::F32: { \
+      using output_type = float; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    case gpu_ops::ElementType::F16: { \
+      using output_type = __half; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    case gpu_ops::ElementType::BF16: { \
+      using output_type = hip_bfloat16; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    default: \
+      break; \
+    } \
+    break; \
+  } \
+  case gpu_ops::ElementType::BF16: { \
+    using input_type = hip_bfloat16; \
+    switch (TYPEOUT) { \
+    case gpu_ops::ElementType::F32: { \
+      using output_type = float; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    case gpu_ops::ElementType::F16: { \
+      using output_type = __half; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    case gpu_ops::ElementType::BF16: { \
+      using output_type = hip_bfloat16; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    default: \
+      break; \
+    } \
+    break; \
+  } \
+  default: \
+    break; \
+  }
+
+
+static_assert(_N_ % 4 == 0, "the size of head must be the times of 4.");
+
+
+
+
+
+template <typename F_in, typename F_out>
+__device__ void kernel_forward_core(const int B, const int T, const int C, const int H, const int b, const int h, const int i, const float* state,
+                                    const F_in *__restrict__ const _r, const F_in *__restrict__ const _k, const F_in *__restrict__ const _v, const F_in *__restrict__ _w, const F_in *__restrict__ _u,
+                                    F_out *__restrict__ const _y)
+{
+
+    _u += h*_N_;
+
+    __shared__ float r[_N_], k[_N_], u[_N_], w[_N_];
+    //float state[_N_] = {0};
+
+    __syncthreads();
+    u[i] = float(_u[i]);
+    __syncthreads();
+
+    for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C)
+    {
+        __syncthreads();
+        w[i] = __expf(-__expf(float(_w[t])));
+        r[i] = float(_r[t]);
+        k[i] = float(_k[t]);
+        __syncthreads();
+
+        const float v = float(_v[t]);
+        float y = 0;
+
+        #pragma unroll
+        for (int j = 0; j < _N_; j+=4)
+        {
+            const float4& r_ = (float4&)(r[j]);
+            const float4& k_ = (float4&)(k[j]);
+            const float4& w_ = (float4&)(w[j]);
+            const float4& u_ = (float4&)(u[j]);
+            float4& s = (float4&)(state[j]);
+            float4 x;
+
+            x.x = k_.x * v;
+            x.y = k_.y * v;
+            x.z = k_.z * v;
+            x.w = k_.w * v;
+
+            y += r_.x * (u_.x * x.x + s.x);
+            y += r_.y * (u_.y * x.y + s.y);
+            y += r_.z * (u_.z * x.z + s.z);
+            y += r_.w * (u_.w * x.w + s.w);
+
+            s.x = s.x * w_.x + x.x;
+            s.y = s.y * w_.y + x.y;
+            s.z = s.z * w_.z + x.z;
+            s.w = s.w * w_.w + x.w;
+        }
+        _y[t] = F_out(y);
+    }
+}
+
+
+
+
+template <typename F_in, typename F_out>
+__global__ void kernel_forward_state(const int B, const int T, const int C, const int H, const bool is_custom_state, const int32_t* map,
+                                     const F_in *__restrict__ const _r, const F_in *__restrict__ const _k, const F_in *__restrict__ const _v, const F_in *__restrict__ _w, const F_in *__restrict__ _u,
+                                     const F_out *__restrict__ _s, F_out *__restrict__ const _y, F_out *__restrict__ const _ys)
+{
+    const int b = blockIdx.x / H;
+    const int h = blockIdx.x % H;
+    const int i = threadIdx.x;
+    float state[_N_] = {0};
+    if(is_custom_state){
+        assert(map[b] >= 0 && map[b] < B);
+
+        const int64_t input_state_offset = map[b] * H * _N_ *_N_ + h * _N_ * _N_ + i;
+
+        for(int j = 0; j < _N_; j++){
+            state[j] = float(_s[j*_N_ + input_state_offset]);
+        }
+    }
+
+    const int64_t current_state_offset = b * H * _N_ *_N_ + h * _N_ * _N_ + i;
+
+    kernel_forward_core(B, T, C, H, b, h, i, state, _r, _k, _v, _w, _u, _y);
+    for(int j = 0; j < _N_; j++){
+        _ys[j*_N_ + current_state_offset] = F_out(state[j]);
+    }
+}
+
+
+template <typename F_in, typename F_out>
+__global__ void kernel_forward(const int B, const int T, const int C, const int H,
+                               const F_in *__restrict__ const _r, const F_in *__restrict__ const _k, const F_in *__restrict__ const _v, const F_in *__restrict__ _w, const F_in *__restrict__ _u,
+                               F_out *__restrict__ const _y)
+{
+    const int b = blockIdx.x / H;
+    const int h = blockIdx.x % H;
+    const int i = threadIdx.x;
+    float state[_N_] = {0};
+    kernel_forward_core(B, T, C, H, b, h, i, state, _r, _k, _v, _w, _u, _y);
+}
+
+template <typename F_in, typename F_out>
+__global__ void kernel_backward_101(const int B, const int T, const int C, const int H,
+                                    const F_in *__restrict__ const _r, const F_in *__restrict__ const _k, const F_in *__restrict__ const _v, const F_in *__restrict__ _w,
+                                    const F_in *__restrict__ _u, const F_out *__restrict__ const _gy,
+                                    F_out *__restrict__ const _gr, F_out *__restrict__ const _gu)
+{
+    const int b = blockIdx.x / H;
+    const int h = blockIdx.x % H;
+    const int i = threadIdx.x;
+
+    __shared__ float v[_N_], gy[_N_];
+
+    const float u = float(_u[h*_N_ + i]);
+
+    float state[_N_] = {0};
+
+    const int t_0 = b*T*C + h*_N_ + i;
+    const int t_T = t_0 + T*C;
+
+    float gu = 0;
+    for (int t = t_0; t < t_T; t += C)
+    {
+        __syncthreads();
+        v[i] = float(_v[t]);
+        gy[i] = float(_gy[t]);
+        __syncthreads();
+
+        const float k = float(_k[t]);
+        const float w = __expf(-__expf(float(_w[t])));
+        float gr = 0, gu_ = 0;
+
+        #pragma unroll
+        for (int j = 0; j < _N_; j++)
+        {
+            float& s = state[j];
+            float x = k * v[j];
+
+            gr += (u * x + s) * gy[j];
+            gu_ += x * gy[j];
+            s = s * w + x;
+        }
+        _gr[t] = F_out(gr);
+        gu += float(_r[t]) * gu_;
+    }
+    _gu[b*C + h*_N_ + i] = F_out(gu);
+}
+
+template <typename F_in, typename F_out>
+__global__ void kernel_backward_102(const int B, const int T, const int C, const int H,
+                                    const F_in *__restrict__ const _r, const F_in *__restrict__ const _k, const F_in *__restrict__ const _v,
+                                    const F_in *__restrict__ _w, const F_in *__restrict__ _u, const F_out *__restrict__ const _gy,
+                                    F_out *__restrict__ const _gk)
+{
+    const int b = blockIdx.x / H;
+    const int h = blockIdx.x % H;
+    const int i = threadIdx.x;
+
+    __shared__ float v[_N_], gy[_N_];
+
+    const float u = float(_u[h*_N_ + i]);
+
+    float scccc[_N_] = {0};
+
+    const int t_0 = b*T*C + h*_N_ + i;
+    const int t_T_1 = t_0 + (T-1)*C;
+
+    for (int t = t_T_1; t >= t_0; t -= C)
+    {
+        __syncthreads();
+        v[i] = float(_v[t]);
+        gy[i] = float(_gy[t]);
+        __syncthreads();
+
+        const float rr = float(_r[t]);
+        const float w = __expf(-__expf(float(_w[t])));
+        float gk = 0;
+
+        #pragma unroll
+        for (int j = 0; j < _N_; j++)
+        {
+            float& s = scccc[j];
+            float x = rr * gy[j];
+
+            gk += (u * x + s) * v[j];
+            s = x + s * w;
+        }
+        _gk[t] = F_out(gk);
+    }
+}
+
+template <typename F_in, typename F_out>
+__global__ void kernel_backward_103(const int B, const int T, const int C, const int H,
+                                    const F_in *__restrict__ const _r, const F_in *__restrict__ const _k, const F_in *__restrict__ const _v,
+                                    const F_in *__restrict__ _w, const F_in *__restrict__ _u, const F_out *__restrict__ const _gy,
+                                    F_out *__restrict__ const _gv)
+{
+    const int b = blockIdx.x / H;
+    const int h = blockIdx.x % H;
+    const int i = threadIdx.x;
+    _u += h*_N_;
+
+    __shared__ float u_[_N_], r[_N_], k[_N_], w_[_N_];
+    __syncthreads();
+    u_[i] = float(_u[i]);
+    __syncthreads();
+
+    float sdddd[_N_] = {0};
+
+    const int t_0 = b*T*C + h*_N_ + i;
+    const int t_T_1 = t_0 + (T-1)*C;
+
+    for (int t = t_T_1; t >= t_0; t -= C)
+    {
+        __syncthreads();
+        r[i] = float(_r[t]);
+        k[i] = float(_k[t]);
+        w_[i] = __expf(-__expf(float(_w[t])));
+        __syncthreads();
+
+        const float gyy = float(_gy[t]);
+        float gv = 0;
+
+        #pragma unroll
+        for (int j = 0; j < _N_; j++)
+        {
+            float& s = sdddd[j];
+            float x = gyy * r[j];
+
+            gv += (u_[j] * x + s) * k[j];
+            s = x + s * w_[j];
+        }
+        _gv[t] = F_out(gv);
+    }
+}
+
+template <typename F_in, typename F_out>
+__global__ void kernel_backward_201(const int B, const int T, const int C, const int H,
+                                    const F_in *__restrict__ const _r, const F_in *__restrict__ const _k, const F_in *__restrict__ const _v, const F_in *__restrict__ _w,
+                                    const F_in *__restrict__ _u, const F_out *__restrict__ const _gy,
+                                    F_out *__restrict__ const _gw)
+{
+    const int b = blockIdx.x / H;
+    const int h = blockIdx.x % H;
+    const int i = threadIdx.x;
+
+    __shared__ float v[_N_], gy[_N_];
+    float saaaa[_N_] = {0}, sbbbb[_T_-2] = {0}, scccc[_N_] = {0};
+
+    const int t_0 = b*T*C + h*_N_ + i;
+    const int t_1 = t_0 + C;
+    const int t_2 = t_0 + 2*C;
+    const int t_T_1 = t_0 + (T-1)*C;
+
+    for (int t = t_T_1; t > t_1; t -= C)
+    {
+        __syncthreads();
+        gy[i] = float(_gy[t]);
+        v[i] = float(_v[t-2*C]);
+        __syncthreads();
+
+        const float r = float(_r[t]);
+        const float w = __expf(-__expf(float(_w[t-C])));
+        float sum = 0.0f;
+
+        #pragma unroll
+        for (int j = 0; j < _N_; j++)
+        {
+            float& s = saaaa[j];
+            float x = r * gy[j];
+            s = (s + x) * w;
+            sum += s * v[j];
+        }
+        sbbbb[(t-t_2)/C] = sum * float(_k[t-2*C]);
+    }
+
+    float sss = sbbbb[0];
+    _gw[t_0] = 0;
+    _gw[t_1] = F_out(sss * -__expf(float(_w[t_1])));
+
+    for (int t = t_2; t < t_T_1; t += C)
+    {
+        __syncthreads();
+        gy[i] = float(_gy[t]);
+        v[i] = float(_v[t-2*C]);
+        __syncthreads();
+
+        const float w = __expf(-__expf(float(_w[t-C])));
+        const float k = float(_k[t-2*C]);
+        float sum = 0.0f;
+
+        #pragma unroll
+        for (int j = 0; j < _N_; j++)
+        {
+            float& s = scccc[j];
+            float x = k * v[j];
+            s = (s + x) * w;
+            sum += s * gy[j];
+        }
+        sss += sbbbb[(t-t_1)/C] - (sum * float(_r[t]));
+        _gw[t] = F_out(sss * -__expf(float(_w[t])));
+    }
+    _gw[t_T_1] = 0;
+}
+
+
+
+
+
+template <typename T_in, typename T_out>
+void HostApplyRWKVWithState(hipStream_t stream, int B, int T, int C, int H, bool S, const int32_t* state_map,
+                            const T_in *input_r, const T_in *input_k, const T_in *input_v,
+                            const T_in *input_w, const T_in *input_u, T_out *input_s, T_out *output_y, T_out *output_s) {
+    assert(H*_N_ == C);
+    //assert(_N_%4 == 0);
+    kernel_forward_state<<<dim3(B * H), dim3(_N_), _N_ * 4 * sizeof(float), stream>>>(B, T, C, H, S, state_map, input_r, input_k, input_v, input_w, input_u, input_s, output_y, output_s);
+
+}
+
+
+
+template <typename T_in, typename T_out>
+void HostApplyRWKV(hipStream_t stream, int B, int T, int C, int H,
+                   const T_in *input_r, const T_in *input_k, const T_in *input_v,
+                   const T_in *input_w, const T_in *input_u, T_out *output_y) {
+    assert(H*_N_ == C);
+    //assert(_N_%4 == 0);
+    kernel_forward<<<dim3(B * H), dim3(_N_), _N_ * 4 * sizeof(float), stream>>>(B, T, C, H, input_r, input_k, input_v, input_w, input_u, output_y);
+
+}
+// TODO: set the correct shared-memory size for each kernel
+template <typename T_in, typename T_out>
+void HostApplyGradient(hipStream_t stream, int B, int T, int C, int H,
+                       T_in *r, T_in *k, T_in *v, T_in *w, T_in *u, T_out *gy, T_out *gr, T_out *gk, T_out *gv, T_out *gw, T_out *gu)
+{
+    assert(H*_N_ == C);
+    kernel_backward_101<<<dim3(B * H), dim3(_N_), _N_ * 2 * sizeof(float), stream>>>(B, T, C, H, r, k, v, w, u, gy, gr, gu);
+    kernel_backward_102<<<dim3(B * H), dim3(_N_), _N_ * 2 * sizeof(float), stream>>>(B, T, C, H, r, k, v, w, u, gy, gk);
+    kernel_backward_103<<<dim3(B * H), dim3(_N_), _N_ * 4 * sizeof(float), stream>>>(B, T, C, H, r, k, v, w, u, gy, gv);
+    kernel_backward_201<<<dim3(B * H), dim3(_N_), _N_ * 2 * sizeof(float), stream>>>(B, T, C, H, r, k, v, w, u, gy, gw);
+}
+
+}
+
+
+namespace gpu_ops {
+
+void rwkv_forward_fn(hipStream_t stream, void **buffers,
+                     const char *opaque,
+                     std::size_t opaque_len) {
+  const WKVDescriptor &d = *UnpackDescriptor<WKVDescriptor>(opaque, opaque_len);
+
+  DISPATCH_Vector_TYPES(
+      d.x_type, d.y_type, "rwkv_forward_kernel",
+      HostApplyRWKV<input_type, output_type>(
+          stream, d.B, d.T, d.C, d.H,
+          static_cast<input_type *>(buffers[0]), static_cast<input_type *>(buffers[1]), static_cast<input_type *>(buffers[2]),
+          static_cast<input_type *>(buffers[3]), static_cast<input_type *>(buffers[4]), static_cast<output_type *>(buffers[5])
+      );
+  )
+}
+
+
+void rwkv_forward_with_state_fn(hipStream_t stream, void **buffers,
+                                const char *opaque,
+                                std::size_t opaque_len) {
+  const WKVDescriptor &d = *UnpackDescriptor<WKVDescriptor>(opaque, opaque_len);
+
+  DISPATCH_Vector_TYPES(
+      d.x_type, d.y_type, "rwkv_forward_with_state_kernel",
+      if(d.S){
+        HostApplyRWKVWithState<input_type, output_type>(
+            stream, d.B, d.T, d.C, d.H, true, static_cast<int32_t *>(buffers[0])/*map*/,
+            static_cast<input_type *>(buffers[1])/*r*/, static_cast<input_type *>(buffers[2])/*k*/, static_cast<input_type *>(buffers[3])/*v*/,
+            static_cast<input_type *>(buffers[4])/*w*/, static_cast<input_type *>(buffers[5])/*u*/, static_cast<output_type *>(buffers[6])/*s*/,
+            static_cast<output_type *>(buffers[7])/*y*/, static_cast<output_type *>(buffers[8])/*ys*/
+        );
+      }else{
+        HostApplyRWKVWithState<input_type, output_type>(
+            stream, d.B, d.T, d.C, d.H, false, nullptr,
+            static_cast<input_type *>(buffers[0])/*r*/, static_cast<input_type *>(buffers[1])/*k*/, static_cast<input_type *>(buffers[2])/*v*/,
+            static_cast<input_type *>(buffers[3])/*w*/, static_cast<input_type *>(buffers[4])/*u*/, nullptr/*s*/,
+            static_cast<output_type *>(buffers[5])/*y*/, static_cast<output_type *>(buffers[6])/*ys*/
+        );
+      }
+  )
+}
+
+
+void rwkv_backward_fn(hipStream_t stream, void **buffers,
+                      const char *opaque,
+                      std::size_t opaque_len) {
+  const WKVDescriptor &d = *UnpackDescriptor<WKVDescriptor>(opaque, opaque_len);
+
+  DISPATCH_Vector_TYPES(
+      d.x_type, d.y_type, "rwkv_backward_kernel",
+      HostApplyGradient<input_type, output_type>(
+          stream, d.B, d.T, d.C, d.H,
+          static_cast<input_type *>(buffers[0]), static_cast<input_type *>(buffers[1]), static_cast<input_type *>(buffers[2]),
+          static_cast<input_type *>(buffers[3]), static_cast<input_type *>(buffers[4]), static_cast<output_type *>(buffers[5]),
+          static_cast<output_type *>(buffers[6]), static_cast<output_type *>(buffers[7]), static_cast<output_type *>(buffers[8]),
+          static_cast<output_type *>(buffers[9]), static_cast<output_type *>(buffers[10])
+      );
+  )
+}
+
+
+} // namespace gpu_ops
+
```
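In plain math, `kernel_forward_core` computes, per head and per output channel $i$ (one thread each), the RWKV-6 WKV recurrence over time, with head size $N$ (`_N_`) and an $N \times N$ state $S$:

$$
y_t^{(i)} = \sum_{j=1}^{N} r_t^{(j)}\left(u^{(j)}\,k_t^{(j)}\,v_t^{(i)} + S_{ij}\right),
\qquad
S_{ij} \leftarrow S_{ij}\,w_t^{(j)} + k_t^{(j)}\,v_t^{(i)},
\qquad
w_t^{(j)} = e^{-e^{\hat{w}_t^{(j)}}}.
$$

The `float4` loads in the inner loop simply process four $j$ channels per iteration, which is why the `static_assert` above requires the head size `_N_` to be a multiple of 4.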
rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py:

```diff
@@ -1,7 +1,6 @@
 import os
 import pybind11
 import importlib
-import sys
 import sysconfig
 import subprocess
 from functools import partial, reduce
@@ -10,12 +9,9 @@ import jax
 import jax.numpy as jnp
 from jax import core, dtypes
 from jax.core import ShapedArray
-from jax.
-from jax.experimental.pjit import pjit
-from jax.interpreters import batching, mlir, xla
+from jax.interpreters import mlir, xla
 from jax.interpreters.mlir import ir
 from jax.lib import xla_client
-from jax.sharding import Mesh, NamedSharding, PartitionSpec
 from jaxlib.hlo_helpers import custom_call


```
```diff
@@ -367,9 +363,10 @@ class RWKVKernelOperator:
     def _rwkv_fwd_with_state(r, k, v, w, u, init_state=None, state_map=None):
         bz = r.shape[0]
         if init_state is not None:  # shape=(B,H,D,D)
-            assert len(init_state.shape) in [
-
-
+            assert len(init_state.shape) in [
+                3,
+                4,
+            ], "init_state must have shape (batch_size, num_heads, head_size, head_size)"
             if len(init_state.shape) == 3:
                 state_map = jnp.zeros((bz,), dtype=jnp.int32)

```
```diff
@@ -390,9 +387,10 @@ class RWKVKernelOperator:
             assert False, "not implemented"
         else:
             # assert state_map is not None, "please pass a state_map: an int32 array of shape (bz,) that maps each batch row to an init_state index"
-            assert state_map.dtype in [
-
-
+            assert state_map.dtype in [
+                jnp.int64,
+                jnp.int32,
+            ], "state_map dtype must be int32"
             state_map = jnp.astype(state_map, jnp.int32)
             assert jnp.all(state_map >= 0) and jnp.all(state_map < bz), (
                 f"state_map holds state mapping indices, so the valid range is [0,{bz})"
```
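A small illustration of the `state_map` contract enforced above (values are made up; the surrounding operator call is omitted):

```python
# Illustrative only: demonstrates the checks performed by the asserts above.
import jax.numpy as jnp

bz = 4  # batch size of r/k/v/w
# A 3-D init_state (num_heads, head_size, head_size) is shared: the code above
# builds state_map = jnp.zeros((bz,), dtype=jnp.int32), so every batch row
# starts from state 0. With a batched 4-D init_state, state_map selects one
# state per batch row; it must be int32/int64 with values in [0, bz).
state_map = jnp.array([0, 0, 1, 1], dtype=jnp.int32)
assert state_map.dtype in (jnp.int32, jnp.int64)
assert bool(jnp.all(state_map >= 0)) and bool(jnp.all(state_map < bz))
```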
```diff
@@ -633,11 +631,11 @@
             assert os.path.exists(cu_src)
             cu_dst = os.path.join(target_dir, "rwkv_kernels.hip.o")
             kernel_cmd = (
-
-                +
+                "hipcc -O3 --hipstdpar -xhip -fopenmp -ffast-math"
+                + " -munsafe-fp-atomics -enable-vectorize-compares"
                 + f" -I{cuda_lib_dir} -I{pybind11.get_include()}"
-                +
-                +
+                + " -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2"
+                + " --gpu-max-threads-per-block=120"
                 + f" -c {cu_src} -o {cu_dst} -D _N_={head_size} -D _T_={max_sequence_length}"
             )
         else:
```
```diff
@@ -645,8 +643,8 @@
             assert os.path.exists(cu_src)
             cu_dst = os.path.join(target_dir, "rwkv_kernels.cu.o")
             kernel_cmd = (
-
-                +
+                "nvcc --threads 4 -Xcompiler -Wall -ldl --expt-relaxed-constexpr -O3 -DNDEBUG -Xcompiler -O3"
+                + " --generate-code=arch=compute_70,code=[compute_70,sm_70] --generate-code=arch=compute_75,code=[compute_75,sm_75] --generate-code=arch=compute_80,code=[compute_80,sm_80] --generate-code=arch=compute_86,code=[compute_86,sm_86]"
                 + f" -Xcompiler=-fPIC -Xcompiler=-fvisibility=hidden -x cu -c {cu_src} -o {cu_dst} -D _N_={head_size} -D _T_={max_sequence_length}"
             )
             build_cmds.append(kernel_cmd)
```
```diff
@@ -660,14 +658,14 @@
         if use_rocm:
             cpp_cmd = (
                 f"c++ -I{cuda_lib_dir} -I{pybind11.get_include()} {get_cflags()}"
-                +
-                +
+                + " -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2"
+                + " -O3 -DNDEBUG -O3 -fPIC -fvisibility=hidden -flto -fno-fat-lto-objects"
                 + f" -o {cpp_dst} -c {cpp_src}"
             )
         else:
             cpp_cmd = (
                 f"c++ -I{cuda_lib_dir} -I{pybind11.get_include()} {get_cflags()}"
-                +
+                + " -O3 -DNDEBUG -O3 -fPIC -fvisibility=hidden -flto -fno-fat-lto-objects"
                 + f" -o {cpp_dst} -c {cpp_src}"
             )
         build_cmds.append(cpp_cmd)
```
```diff
@@ -677,13 +675,13 @@
             assembly_cmd = (
                 f"c++ -fPIC -O3 -DNDEBUG -O3 -flto -shared -o {so_dst} {cpp_dst} {cu_dst}"
                 + f" -fPIC -I{cuda_lib_dir} -I{pybind11.get_include()} {get_cflags()}"
-                +
-                +
+                + " -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2"
+                + " -L/opt/rocm/lib -lamdhip64 -lpthread -ldl"
             )
         else:
             assembly_cmd = (
                 f"c++ -fPIC -O3 -DNDEBUG -O3 -flto -shared -o {so_dst} {cpp_dst} {cu_dst}"
-                +
+                + " -L/usr/local/cuda/lib64 -lcudadevrt -lcudart_static -lrt -lpthread -ldl"
             )
         build_cmds.append(assembly_cmd)

```
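The three commands above (device compile, host compile, link) are accumulated in `build_cmds`. A minimal sketch of how such a list could be executed with the `subprocess` module imported at the top of the file; the helper name is illustrative, and the package's actual runner is not part of this diff:

```python
# Illustrative runner; not the package's confirmed implementation.
import subprocess

def run_build(build_cmds):
    for cmd in build_cmds:
        # The commands are assembled as single shell strings above, so run
        # them through the shell and fail fast on a non-zero exit code.
        subprocess.run(cmd, shell=True, check=True)
```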