rwkv-ops 0.6.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rwkv_ops/__init__.py +45 -0
- rwkv_ops/mhc_kernel/__init__.py +50 -0
- rwkv_ops/mhc_kernel/common_kernel/include/mhc_types.h +66 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_post_op.cuh +197 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_pre_op.cuh +212 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/rmsnorm.cuh +152 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/sinkhorn_knopp.cuh +158 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/stream_aggregate.cuh +141 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/stream_distribute.cuh +111 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/stream_mix.cuh +164 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/type_conversions.cuh +52 -0
- rwkv_ops/mhc_kernel/jax_kernel/CMakeLists.txt +47 -0
- rwkv_ops/mhc_kernel/jax_kernel/mhu_ffi.cu +652 -0
- rwkv_ops/mhc_kernel/jax_kernel/mhu_jax.py +939 -0
- rwkv_ops/mhc_kernel/native_keras_op.py +193 -0
- rwkv_ops/mhc_kernel/torch_kernel/mhc_cuda.cu +207 -0
- rwkv_ops/mhc_kernel/torch_kernel/mhc_op.cpp +296 -0
- rwkv_ops/mhc_kernel/torch_kernel/mhc_torch.py +306 -0
- rwkv_ops/rwkv6_kernel/__init__.py +120 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/gpu_ops.cpp +44 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernel_helpers.h +64 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernels.h +56 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/pybind11_kernel_helpers.h +41 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/rwkv_kernels.cu +512 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/gpu_ops.cpp +44 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernel_helpers.h +64 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernels.h +56 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/pybind11_kernel_helpers.h +41 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/rwkv_kernels.hip +514 -0
- rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py +722 -0
- rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py +90 -0
- rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_cuda.cu +397 -0
- rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_op.cpp +93 -0
- rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py +305 -0
- rwkv_ops/rwkv7_kernel/__init__.py +113 -0
- rwkv_ops/rwkv7_kernel/get_jax_devices_info.py +220 -0
- rwkv_ops/rwkv7_kernel/get_torch_devices_info.py +250 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel/CMakeLists.txt +42 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_ffi.cu +399 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_jax.py +311 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/CMakeLists.txt +42 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_ffi.cu +172 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_jax.py +190 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/__init__.py +9 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_bwd.py +95 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_fwd.py +60 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_bwd.py +78 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_fwd.py +80 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_bwd.py +150 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_fwd.py +45 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/cumsum.py +34 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_bwd.py +61 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_fwd.py +86 -0
- rwkv_ops/rwkv7_kernel/jax_op.py +382 -0
- rwkv_ops/rwkv7_kernel/mlx_op.py +118 -0
- rwkv_ops/rwkv7_kernel/native_keras_op.py +108 -0
- rwkv_ops/rwkv7_kernel/tf_eager_kernel.py +155 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_cuda.cu +235 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_op.cpp +63 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_torch.py +233 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_cuda.cu +101 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_op.cpp +56 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_torch.py +112 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/__init__.py +13 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_bwd.py +96 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_fwd.py +64 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_bwd.py +74 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_fwd.py +75 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_bwd.py +148 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_fwd.py +44 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/cumsum.py +31 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_bwd.py +63 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_fwd.py +79 -0
- rwkv_ops/rwkv7_kernel/torch_op.py +504 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/__init__.py +34 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_bwd.py +328 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_fwd.py +186 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_bwd.py +157 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_fwd.py +160 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_bwd.py +382 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_fwd.py +137 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/cumsum.py +86 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/utils.py +20 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_bwd.py +193 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_fwd.py +326 -0
- rwkv_ops-0.6.1.dist-info/METADATA +495 -0
- rwkv_ops-0.6.1.dist-info/RECORD +89 -0
- rwkv_ops-0.6.1.dist-info/WHEEL +4 -0
- rwkv_ops-0.6.1.dist-info/licenses/LICENSE.txt +201 -0
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
/* Copyright 2024 The JAX Authors.
|
|
2
|
+
|
|
3
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
you may not use this file except in compliance with the License.
|
|
5
|
+
You may obtain a copy of the License at
|
|
6
|
+
|
|
7
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
See the License for the specific language governing permissions and
|
|
13
|
+
limitations under the License.
|
|
14
|
+
==============================================================================*/
|
|
15
|
+
|
|
16
|
+
// This header extends kernel_helpers.h with the pybind11 specific interface to
|
|
17
|
+
// serializing descriptors. It also adds a pybind11 function for wrapping our
|
|
18
|
+
// custom calls in a Python capsule. This is separate from kernel_helpers so
|
|
19
|
+
// that the CUDA code itself doesn't include pybind11. I don't think that this
|
|
20
|
+
// is strictly necessary, but they do it in jaxlib, so let's do it here too.
|
|
21
|
+
|
|
22
|
+
#ifndef _GPU_OPS_PYBIND11_KERNEL_HELPERS_H_
|
|
23
|
+
#define _GPU_OPS_PYBIND11_KERNEL_HELPERS_H_
|
|
24
|
+
|
|
25
|
+
#include <pybind11/pybind11.h>
|
|
26
|
+
|
|
27
|
+
#include "kernel_helpers.h"
|
|
28
|
+
|
|
29
|
+
namespace gpu_ops {
|
|
30
|
+
|
|
31
|
+
// Serializes `descriptor` with PackDescriptorAsString and wraps the result in
// a Python `bytes` object so it can be handed across the pybind11 boundary.
template <typename T>
pybind11::bytes PackDescriptor(const T &descriptor) {
  const auto serialized = PackDescriptorAsString(descriptor);
  return pybind11::bytes(serialized);
}
|
|
34
|
+
|
|
35
|
+
// Wraps the function pointer `fn` in a Python capsule named
// "xla._CUSTOM_CALL_TARGET". The pointer is reinterpreted as `void *` via
// bit_cast before being stored in the capsule.
template <typename T>
pybind11::capsule EncapsulateFunction(T *fn) {
  void *const raw_target = bit_cast<void *>(fn);
  return pybind11::capsule(raw_target, "xla._CUSTOM_CALL_TARGET");
}
|
|
38
|
+
|
|
39
|
+
} // namespace gpu_ops
|
|
40
|
+
|
|
41
|
+
#endif
|
|
@@ -0,0 +1,514 @@
|
|
|
1
|
+
/* Copyright 2024 The JAX Authors.
|
|
2
|
+
|
|
3
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
you may not use this file except in compliance with the License.
|
|
5
|
+
You may obtain a copy of the License at
|
|
6
|
+
|
|
7
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
See the License for the specific language governing permissions and
|
|
13
|
+
limitations under the License.
|
|
14
|
+
==============================================================================*/
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
#include <hip/hip_runtime.h>
|
|
18
|
+
#include "kernel_helpers.h"
|
|
19
|
+
#include "kernels.h"
|
|
20
|
+
#include "stdio.h"
|
|
21
|
+
#include <hip/hip_bf16.h>
|
|
22
|
+
#include <hip/hip_fp16.h>
|
|
23
|
+
#include <iostream>
|
|
24
|
+
#include <assert.h>
|
|
25
|
+
namespace {
|
|
26
|
+
|
|
27
|
+
// Inner half of DISPATCH_Vector_TYPES: switches on the runtime output element
// type, binds `output_type` to the matching C++ type and runs the statement(s)
// given as the variadic argument. Unsupported values fall through to `default`
// and run nothing.
#define GPU_OPS_DISPATCH_OUTPUT_TYPE(TYPEOUT, ...)  \
  switch (TYPEOUT) {                                \
  case gpu_ops::ElementType::F32: {                 \
    using output_type = float;                      \
    __VA_ARGS__;                                    \
    break;                                          \
  }                                                 \
  case gpu_ops::ElementType::F16: {                 \
    using output_type = __half;                     \
    __VA_ARGS__;                                    \
    break;                                          \
  }                                                 \
  case gpu_ops::ElementType::BF16: {                \
    using output_type = hip_bfloat16;               \
    __VA_ARGS__;                                    \
    break;                                          \
  }                                                 \
  default:                                          \
    break;                                          \
  }

// Runs the variadic statement(s) with `input_type` / `output_type` aliased to
// the C++ types selected by the runtime tags TYPEIN / TYPEOUT (F32 -> float,
// F16 -> __half, BF16 -> hip_bfloat16). All nine combinations are
// instantiated. Any other tag value silently does nothing. NAME is accepted
// but not used by the expansion.
#define DISPATCH_Vector_TYPES(TYPEIN, TYPEOUT, NAME, ...)     \
  switch (TYPEIN) {                                           \
  case gpu_ops::ElementType::F32: {                           \
    using input_type = float;                                 \
    GPU_OPS_DISPATCH_OUTPUT_TYPE(TYPEOUT, __VA_ARGS__)        \
    break;                                                    \
  }                                                           \
  case gpu_ops::ElementType::F16: {                           \
    using input_type = __half;                                \
    GPU_OPS_DISPATCH_OUTPUT_TYPE(TYPEOUT, __VA_ARGS__)        \
    break;                                                    \
  }                                                           \
  case gpu_ops::ElementType::BF16: {                          \
    using input_type = hip_bfloat16;                          \
    GPU_OPS_DISPATCH_OUTPUT_TYPE(TYPEOUT, __VA_ARGS__)        \
    break;                                                    \
  }                                                           \
  default:                                                    \
    break;                                                    \
  }
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
// The forward kernel walks the per-head buffers four floats at a time (as
// float4), so the head size _N_ has to be divisible by 4.
static_assert((_N_ % 4) == 0, "the size of head must be the times of 4.");
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
template <typename F_in,typename F_out>
|
|
110
|
+
__device__ void kernel_forward_core(const int B, const int T, const int C, const int H,const int b, const int h, const int i, const float* state,
|
|
111
|
+
const F_in *__restrict__ const _r, const F_in *__restrict__ const _k, const F_in *__restrict__ const _v, const F_in *__restrict__ _w, const F_in *__restrict__ _u,
|
|
112
|
+
F_out *__restrict__ const _y)
|
|
113
|
+
{
|
|
114
|
+
|
|
115
|
+
_u += h*_N_;
|
|
116
|
+
|
|
117
|
+
__shared__ float r[_N_], k[_N_], u[_N_], w[_N_];
|
|
118
|
+
//float state[_N_] = {0};
|
|
119
|
+
|
|
120
|
+
__syncthreads();
|
|
121
|
+
u[i] = float(_u[i]);
|
|
122
|
+
__syncthreads();
|
|
123
|
+
|
|
124
|
+
for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C)
|
|
125
|
+
{
|
|
126
|
+
__syncthreads();
|
|
127
|
+
w[i] = __expf(-__expf(float(_w[t])));
|
|
128
|
+
r[i] = float(_r[t]);
|
|
129
|
+
k[i] = float(_k[t]);
|
|
130
|
+
__syncthreads();
|
|
131
|
+
|
|
132
|
+
const float v = float(_v[t]);
|
|
133
|
+
float y = 0;
|
|
134
|
+
|
|
135
|
+
#pragma unroll
|
|
136
|
+
for (int j = 0; j < _N_; j+=4)
|
|
137
|
+
{
|
|
138
|
+
const float4& r_ = (float4&)(r[j]);
|
|
139
|
+
const float4& k_ = (float4&)(k[j]);
|
|
140
|
+
const float4& w_ = (float4&)(w[j]);
|
|
141
|
+
const float4& u_ = (float4&)(u[j]);
|
|
142
|
+
float4& s = (float4&)(state[j]);
|
|
143
|
+
float4 x;
|
|
144
|
+
|
|
145
|
+
x.x = k_.x * v;
|
|
146
|
+
x.y = k_.y * v;
|
|
147
|
+
x.z = k_.z * v;
|
|
148
|
+
x.w = k_.w * v;
|
|
149
|
+
|
|
150
|
+
y += r_.x * (u_.x * x.x + s.x);
|
|
151
|
+
y += r_.y * (u_.y * x.y + s.y);
|
|
152
|
+
y += r_.z * (u_.z * x.z + s.z);
|
|
153
|
+
y += r_.w * (u_.w * x.w + s.w);
|
|
154
|
+
|
|
155
|
+
s.x = s.x * w_.x + x.x;
|
|
156
|
+
s.y = s.y * w_.y + x.y;
|
|
157
|
+
s.z = s.z * w_.z + x.z;
|
|
158
|
+
s.w = s.w * w_.w + x.w;
|
|
159
|
+
}
|
|
160
|
+
_y[t] = F_out(y);
|
|
161
|
+
}
|
|
162
|
+
}
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
template <typename F_in,typename F_out>
|
|
168
|
+
__global__ void kernel_forward_state(const int B, const int T, const int C, const int H,const bool is_custom_state,const int32_t* map,
|
|
169
|
+
const F_in *__restrict__ const _r, const F_in *__restrict__ const _k, const F_in *__restrict__ const _v, const F_in *__restrict__ _w, const F_in *__restrict__ _u,
|
|
170
|
+
const F_out *__restrict__ _s, F_out *__restrict__ const _y, F_out *__restrict__ const _ys)
|
|
171
|
+
{
|
|
172
|
+
const int b = blockIdx.x / H;
|
|
173
|
+
const int h = blockIdx.x % H;
|
|
174
|
+
const int i = threadIdx.x;
|
|
175
|
+
float state[_N_] = {0};
|
|
176
|
+
if(is_custom_state){
|
|
177
|
+
assert(map[b] >=0 && map[b] < B);
|
|
178
|
+
|
|
179
|
+
const int64_t input_state_offset = map[b] * H * _N_ *_N_ + h * _N_ * _N_ + i;
|
|
180
|
+
|
|
181
|
+
for(int j= 0; j< _N_; j++){
|
|
182
|
+
state[j] = float(_s[j*_N_ + input_state_offset]);
|
|
183
|
+
}
|
|
184
|
+
}
|
|
185
|
+
|
|
186
|
+
const int64_t current_state_offset = b * H * _N_ *_N_ + h * _N_ * _N_ + i;
|
|
187
|
+
|
|
188
|
+
kernel_forward_core(B, T, C, H, b, h, i, state, _r, _k, _v, _w, _u, _y);
|
|
189
|
+
for(int j=0; j< _N_; j++){
|
|
190
|
+
_ys[j*_N_ + current_state_offset] = F_out(state[j]);
|
|
191
|
+
}
|
|
192
|
+
}
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
template <typename F_in,typename F_out>
|
|
196
|
+
__global__ void kernel_forward(const int B, const int T, const int C, const int H,
|
|
197
|
+
const F_in *__restrict__ const _r, const F_in *__restrict__ const _k, const F_in *__restrict__ const _v, const F_in *__restrict__ _w, const F_in *__restrict__ _u,
|
|
198
|
+
F_out *__restrict__ const _y)
|
|
199
|
+
{
|
|
200
|
+
const int b = blockIdx.x / H;
|
|
201
|
+
const int h = blockIdx.x % H;
|
|
202
|
+
const int i = threadIdx.x;
|
|
203
|
+
float state[_N_] = {0};
|
|
204
|
+
kernel_forward_core(B, T, C, H, b, h, i, state, _r, _k, _v, _w, _u, _y);
|
|
205
|
+
}
|
|
206
|
+
|
|
207
|
+
template <typename F_in, typename F_out>
|
|
208
|
+
__global__ void kernel_backward_101(const int B, const int T, const int C, const int H,
|
|
209
|
+
const F_in *__restrict__ const _r, const F_in *__restrict__ const _k, const F_in *__restrict__ const _v, const F_in *__restrict__ _w,
|
|
210
|
+
const F_in *__restrict__ _u, const F_out *__restrict__ const _gy,
|
|
211
|
+
F_out *__restrict__ const _gr, F_out *__restrict__ const _gu)
|
|
212
|
+
{
|
|
213
|
+
const int b = blockIdx.x / H;
|
|
214
|
+
const int h = blockIdx.x % H;
|
|
215
|
+
const int i = threadIdx.x;
|
|
216
|
+
|
|
217
|
+
__shared__ float v[_N_], gy[_N_];
|
|
218
|
+
|
|
219
|
+
const float u = float(_u[h*_N_ + i]);
|
|
220
|
+
|
|
221
|
+
float state[_N_] = {0};
|
|
222
|
+
|
|
223
|
+
const int t_0 = b*T*C + h*_N_ + i;
|
|
224
|
+
const int t_T = t_0 + T*C;
|
|
225
|
+
|
|
226
|
+
float gu = 0;
|
|
227
|
+
for (int t = t_0; t < t_T; t += C)
|
|
228
|
+
{
|
|
229
|
+
__syncthreads();
|
|
230
|
+
v[i] = float(_v[t]);
|
|
231
|
+
gy[i] = float(_gy[t]);
|
|
232
|
+
__syncthreads();
|
|
233
|
+
|
|
234
|
+
const float k = float(_k[t]);
|
|
235
|
+
const float w = __expf(-__expf(float(_w[t])));
|
|
236
|
+
float gr = 0, gu_ = 0;
|
|
237
|
+
|
|
238
|
+
#pragma unroll
|
|
239
|
+
for (int j = 0; j < _N_; j++)
|
|
240
|
+
{
|
|
241
|
+
float& s = state[j];
|
|
242
|
+
float x = k * v[j];
|
|
243
|
+
|
|
244
|
+
gr += (u * x + s) * gy[j];
|
|
245
|
+
gu_ += x * gy[j];
|
|
246
|
+
s = s * w + x;
|
|
247
|
+
}
|
|
248
|
+
_gr[t] = F_out(gr);
|
|
249
|
+
gu += float(_r[t]) * gu_;
|
|
250
|
+
}
|
|
251
|
+
_gu[b*C + h*_N_ + i] = F_out(gu);
|
|
252
|
+
}
|
|
253
|
+
|
|
254
|
+
template <typename F_in, typename F_out>
|
|
255
|
+
__global__ void kernel_backward_102(const int B, const int T, const int C, const int H,
|
|
256
|
+
const F_in *__restrict__ const _r, const F_in *__restrict__ const _k, const F_in *__restrict__ const _v,
|
|
257
|
+
const F_in *__restrict__ _w, const F_in *__restrict__ _u, const F_out *__restrict__ const _gy,
|
|
258
|
+
F_out *__restrict__ const _gk)
|
|
259
|
+
{
|
|
260
|
+
const int b = blockIdx.x / H;
|
|
261
|
+
const int h = blockIdx.x % H;
|
|
262
|
+
const int i = threadIdx.x;
|
|
263
|
+
|
|
264
|
+
__shared__ float v[_N_], gy[_N_];
|
|
265
|
+
|
|
266
|
+
const float u = float(_u[h*_N_ + i]);
|
|
267
|
+
|
|
268
|
+
float scccc[_N_] = {0};
|
|
269
|
+
|
|
270
|
+
const int t_0 = b*T*C + h*_N_ + i;
|
|
271
|
+
const int t_T_1 = t_0 + (T-1)*C;
|
|
272
|
+
|
|
273
|
+
for (int t = t_T_1; t >= t_0; t -= C)
|
|
274
|
+
{
|
|
275
|
+
__syncthreads();
|
|
276
|
+
v[i] = float(_v[t]);
|
|
277
|
+
gy[i] = float(_gy[t]);
|
|
278
|
+
__syncthreads();
|
|
279
|
+
|
|
280
|
+
const float rr = float(_r[t]);
|
|
281
|
+
const float w = __expf(-__expf(float(_w[t])));
|
|
282
|
+
float gk = 0;
|
|
283
|
+
|
|
284
|
+
#pragma unroll
|
|
285
|
+
for (int j = 0; j < _N_; j++)
|
|
286
|
+
{
|
|
287
|
+
float& s = scccc[j];
|
|
288
|
+
float x = rr * gy[j];
|
|
289
|
+
|
|
290
|
+
gk += (u * x + s) * v[j];
|
|
291
|
+
s = x + s * w;
|
|
292
|
+
}
|
|
293
|
+
_gk[t] = F_out(gk);
|
|
294
|
+
}
|
|
295
|
+
}
|
|
296
|
+
|
|
297
|
+
template <typename F_in, typename F_out>
|
|
298
|
+
__global__ void kernel_backward_103(const int B, const int T, const int C, const int H,
|
|
299
|
+
const F_in *__restrict__ const _r, const F_in *__restrict__ const _k, const F_in *__restrict__ const _v,
|
|
300
|
+
const F_in *__restrict__ _w, const F_in *__restrict__ _u, const F_out *__restrict__ const _gy,
|
|
301
|
+
F_out *__restrict__ const _gv)
|
|
302
|
+
{
|
|
303
|
+
const int b = blockIdx.x / H;
|
|
304
|
+
const int h = blockIdx.x % H;
|
|
305
|
+
const int i = threadIdx.x;
|
|
306
|
+
_u += h*_N_;
|
|
307
|
+
|
|
308
|
+
__shared__ float u_[_N_], r[_N_], k[_N_], w_[_N_];
|
|
309
|
+
__syncthreads();
|
|
310
|
+
u_[i] = float(_u[i]);
|
|
311
|
+
__syncthreads();
|
|
312
|
+
|
|
313
|
+
float sdddd[_N_] = {0};
|
|
314
|
+
|
|
315
|
+
const int t_0 = b*T*C + h*_N_ + i;
|
|
316
|
+
const int t_T_1 = t_0 + (T-1)*C;
|
|
317
|
+
|
|
318
|
+
for (int t = t_T_1; t >= t_0; t -= C)
|
|
319
|
+
{
|
|
320
|
+
__syncthreads();
|
|
321
|
+
r[i] = float(_r[t]);
|
|
322
|
+
k[i] = float(_k[t]);
|
|
323
|
+
w_[i] = __expf(-__expf(float(_w[t])));
|
|
324
|
+
__syncthreads();
|
|
325
|
+
|
|
326
|
+
const float gyy = float(_gy[t]);
|
|
327
|
+
float gv = 0;
|
|
328
|
+
|
|
329
|
+
#pragma unroll
|
|
330
|
+
for (int j = 0; j < _N_; j++)
|
|
331
|
+
{
|
|
332
|
+
float& s = sdddd[j];
|
|
333
|
+
float x = gyy * r[j];
|
|
334
|
+
|
|
335
|
+
gv += (u_[j] * x + s) * k[j];
|
|
336
|
+
s = x + s * w_[j];
|
|
337
|
+
}
|
|
338
|
+
_gv[t] = F_out(gv);
|
|
339
|
+
}
|
|
340
|
+
}
|
|
341
|
+
|
|
342
|
+
template <typename F_in, typename F_out>
|
|
343
|
+
__global__ void kernel_backward_201(const int B, const int T, const int C, const int H,
|
|
344
|
+
const F_in *__restrict__ const _r, const F_in *__restrict__ const _k, const F_in *__restrict__ const _v, const F_in *__restrict__ _w,
|
|
345
|
+
const F_in *__restrict__ _u, const F_out *__restrict__ const _gy,
|
|
346
|
+
F_out *__restrict__ const _gw)
|
|
347
|
+
{
|
|
348
|
+
const int b = blockIdx.x / H;
|
|
349
|
+
const int h = blockIdx.x % H;
|
|
350
|
+
const int i = threadIdx.x;
|
|
351
|
+
|
|
352
|
+
__shared__ float v[_N_], gy[_N_];
|
|
353
|
+
float saaaa[_N_] = {0}, sbbbb[_T_-2] = {0}, scccc[_N_] = {0};
|
|
354
|
+
|
|
355
|
+
const int t_0 = b*T*C + h*_N_ + i;
|
|
356
|
+
const int t_1 = t_0 + C;
|
|
357
|
+
const int t_2 = t_0 + 2*C;
|
|
358
|
+
const int t_T_1 = t_0 + (T-1)*C;
|
|
359
|
+
|
|
360
|
+
for (int t = t_T_1; t > t_1; t -= C)
|
|
361
|
+
{
|
|
362
|
+
__syncthreads();
|
|
363
|
+
gy[i] = float(_gy[t]);
|
|
364
|
+
v[i] = float(_v[t-2*C]);
|
|
365
|
+
__syncthreads();
|
|
366
|
+
|
|
367
|
+
const float r = float(_r[t]);
|
|
368
|
+
const float w = __expf(-__expf(float(_w[t-C])));
|
|
369
|
+
float sum = 0.0f;
|
|
370
|
+
|
|
371
|
+
#pragma unroll
|
|
372
|
+
for (int j = 0; j < _N_; j++)
|
|
373
|
+
{
|
|
374
|
+
float& s = saaaa[j];
|
|
375
|
+
float x = r * gy[j];
|
|
376
|
+
s = (s + x) * w;
|
|
377
|
+
sum += s * v[j];
|
|
378
|
+
}
|
|
379
|
+
sbbbb[(t-t_2)/C] = sum * float(_k[t-2*C]);
|
|
380
|
+
}
|
|
381
|
+
|
|
382
|
+
float sss = sbbbb[0];
|
|
383
|
+
_gw[t_0] = 0;
|
|
384
|
+
_gw[t_1] = F_out(sss * -__expf(float(_w[t_1])));
|
|
385
|
+
|
|
386
|
+
for (int t = t_2; t < t_T_1; t += C)
|
|
387
|
+
{
|
|
388
|
+
__syncthreads();
|
|
389
|
+
gy[i] = float(_gy[t]);
|
|
390
|
+
v[i] = float(_v[t-2*C]);
|
|
391
|
+
__syncthreads();
|
|
392
|
+
|
|
393
|
+
const float w = __expf(-__expf(float(_w[t-C])));
|
|
394
|
+
const float k = float(_k[t-2*C]);
|
|
395
|
+
float sum = 0.0f;
|
|
396
|
+
|
|
397
|
+
#pragma unroll
|
|
398
|
+
for (int j = 0; j < _N_; j++)
|
|
399
|
+
{
|
|
400
|
+
float& s = scccc[j];
|
|
401
|
+
float x = k * v[j];
|
|
402
|
+
s = (s + x) * w;
|
|
403
|
+
sum += s * gy[j];
|
|
404
|
+
}
|
|
405
|
+
sss += sbbbb[(t-t_1)/C] - (sum * float(_r[t]));
|
|
406
|
+
_gw[t] = F_out(sss * -__expf(float(_w[t])));
|
|
407
|
+
}
|
|
408
|
+
_gw[t_T_1] = 0;
|
|
409
|
+
}
|
|
410
|
+
|
|
411
|
+
|
|
412
|
+
|
|
413
|
+
|
|
414
|
+
|
|
415
|
+
template <typename T_in, typename T_out>
|
|
416
|
+
void HostApplyRWKVWithState(hipStream_t stream,int B, int T, int C, int H, bool S, const int32_t* state_map,
|
|
417
|
+
const T_in *input_r,const T_in *input_k,const T_in *input_v,
|
|
418
|
+
const T_in *input_w,const T_in *input_u,T_out *input_s, T_out *output_y, T_out *output_s) {
|
|
419
|
+
assert(H*_N_ == C);
|
|
420
|
+
//assert(_N_%4 == 0);
|
|
421
|
+
kernel_forward_state<<<dim3(B * H), dim3(_N_), _N_ * 4 * sizeof(float),stream>>>(B, T, C, H, S, state_map, input_r, input_k, input_v, input_w, input_u,input_s, output_y,output_s);
|
|
422
|
+
|
|
423
|
+
}
|
|
424
|
+
|
|
425
|
+
|
|
426
|
+
|
|
427
|
+
template <typename T_in, typename T_out>
|
|
428
|
+
void HostApplyRWKV(hipStream_t stream,int B, int T, int C, int H,
|
|
429
|
+
const T_in *input_r,const T_in *input_k,const T_in *input_v,
|
|
430
|
+
const T_in *input_w,const T_in *input_u,T_out *output_y) {
|
|
431
|
+
assert(H*_N_ == C);
|
|
432
|
+
//assert(_N_%4 == 0);
|
|
433
|
+
kernel_forward<<<dim3(B * H), dim3(_N_), _N_ * 4 * sizeof(float),stream>>>(B, T, C, H, input_r, input_k, input_v, input_w, input_u, output_y);
|
|
434
|
+
|
|
435
|
+
}
|
|
436
|
+
//todo 为kernel设置正确的sharememory大小
|
|
437
|
+
template <typename T_in, typename T_out>
|
|
438
|
+
void HostApplyGradient(hipStream_t stream,int B, int T, int C, int H,
|
|
439
|
+
T_in *r, T_in *k, T_in *v, T_in *w, T_in *u, T_out *gy, T_out *gr, T_out *gk, T_out *gv, T_out *gw, T_out *gu)
|
|
440
|
+
{
|
|
441
|
+
assert(H*_N_ == C);
|
|
442
|
+
kernel_backward_101<<<dim3(B * H), dim3(_N_),_N_ * 2 * sizeof(float),stream >>>(B, T, C, H, r, k, v, w, u, gy, gr, gu);
|
|
443
|
+
kernel_backward_102<<<dim3(B * H), dim3(_N_),_N_ * 2 * sizeof(float),stream >>>(B, T, C, H, r, k, v, w, u, gy, gk);
|
|
444
|
+
kernel_backward_103<<<dim3(B * H), dim3(_N_),_N_ * 4 * sizeof(float),stream >>>(B, T, C, H, r, k, v, w, u, gy, gv);
|
|
445
|
+
kernel_backward_201<<<dim3(B * H), dim3(_N_),_N_ * 2 * sizeof(float),stream >>>(B, T, C, H, r, k, v, w, u, gy, gw);
|
|
446
|
+
}
|
|
447
|
+
|
|
448
|
+
}
|
|
449
|
+
|
|
450
|
+
|
|
451
|
+
namespace gpu_ops {
|
|
452
|
+
|
|
453
|
+
// Custom-call entry point for the stateless forward pass.
//
// `opaque` / `opaque_len` carry a serialized WKVDescriptor. Its x_type and
// y_type pick the input / output element types; B, T, C and H are passed on
// to HostApplyRWKV.
// Buffer layout: [0] r, [1] k, [2] v, [3] w, [4] u, [5] y (output).
// Unsupported element types make the call a silent no-op.
void rwkv_forward_fn(hipStream_t stream, void **buffers,
                     const char *opaque,
                     std::size_t opaque_len) {
  const WKVDescriptor &desc = *UnpackDescriptor<WKVDescriptor>(opaque, opaque_len);

  DISPATCH_Vector_TYPES(
      desc.x_type, desc.y_type, "rwkv_forward_kernel",
      input_type *const r = static_cast<input_type *>(buffers[0]);
      input_type *const k = static_cast<input_type *>(buffers[1]);
      input_type *const v = static_cast<input_type *>(buffers[2]);
      input_type *const w = static_cast<input_type *>(buffers[3]);
      input_type *const u = static_cast<input_type *>(buffers[4]);
      output_type *const y = static_cast<output_type *>(buffers[5]);
      HostApplyRWKV<input_type, output_type>(
          stream, desc.B, desc.T, desc.C, desc.H, r, k, v, w, u, y);
  )
}
|
|
467
|
+
|
|
468
|
+
|
|
469
|
+
// Custom-call entry point for the forward pass that also returns the final
// state.
//
// `opaque` / `opaque_len` carry a serialized WKVDescriptor. Its x_type and
// y_type pick the input / output element types, and its S field selects the
// buffer layout:
//   S set:   [0] map, [1] r, [2] k, [3] v, [4] w, [5] u, [6] s (initial
//            state), [7] y (output), [8] ys (final state)
//   S unset: [0] r, [1] k, [2] v, [3] w, [4] u, [5] y (output),
//            [6] ys (final state); no map or initial state is passed
// Unsupported element types make the call a silent no-op.
void rwkv_forward_with_state_fn(hipStream_t stream, void **buffers,
                                const char *opaque,
                                std::size_t opaque_len) {
  const WKVDescriptor &desc = *UnpackDescriptor<WKVDescriptor>(opaque, opaque_len);

  DISPATCH_Vector_TYPES(
      desc.x_type, desc.y_type, "rwkv_forward_with_state_kernel",
      if (!desc.S) {
        /* No initial state: start from zeros, no map / s buffers present. */
        input_type *const r = static_cast<input_type *>(buffers[0]);
        input_type *const k = static_cast<input_type *>(buffers[1]);
        input_type *const v = static_cast<input_type *>(buffers[2]);
        input_type *const w = static_cast<input_type *>(buffers[3]);
        input_type *const u = static_cast<input_type *>(buffers[4]);
        output_type *const y = static_cast<output_type *>(buffers[5]);
        output_type *const ys = static_cast<output_type *>(buffers[6]);
        HostApplyRWKVWithState<input_type, output_type>(
            stream, desc.B, desc.T, desc.C, desc.H, false, nullptr,
            r, k, v, w, u, nullptr, y, ys);
      } else {
        /* Initial state supplied: map and s precede / follow the inputs. */
        int32_t *const state_map = static_cast<int32_t *>(buffers[0]);
        input_type *const r = static_cast<input_type *>(buffers[1]);
        input_type *const k = static_cast<input_type *>(buffers[2]);
        input_type *const v = static_cast<input_type *>(buffers[3]);
        input_type *const w = static_cast<input_type *>(buffers[4]);
        input_type *const u = static_cast<input_type *>(buffers[5]);
        output_type *const s = static_cast<output_type *>(buffers[6]);
        output_type *const y = static_cast<output_type *>(buffers[7]);
        output_type *const ys = static_cast<output_type *>(buffers[8]);
        HostApplyRWKVWithState<input_type, output_type>(
            stream, desc.B, desc.T, desc.C, desc.H, true, state_map,
            r, k, v, w, u, s, y, ys);
      }
  )
}
|
|
493
|
+
|
|
494
|
+
|
|
495
|
+
// Custom-call entry point for the backward pass.
//
// `opaque` / `opaque_len` carry a serialized WKVDescriptor. Its x_type and
// y_type pick the input / output element types; B, T, C and H are passed on
// to HostApplyGradient.
// Buffer layout: [0] r, [1] k, [2] v, [3] w, [4] u, [5] gy,
//                [6] gr, [7] gk, [8] gv, [9] gw, [10] gu.
// Unsupported element types make the call a silent no-op.
void rwkv_backward_fn(hipStream_t stream, void **buffers,
                      const char *opaque,
                      std::size_t opaque_len) {
  const WKVDescriptor &desc = *UnpackDescriptor<WKVDescriptor>(opaque, opaque_len);

  DISPATCH_Vector_TYPES(
      desc.x_type, desc.y_type, "rwkv_backward_kernel",
      input_type *const r = static_cast<input_type *>(buffers[0]);
      input_type *const k = static_cast<input_type *>(buffers[1]);
      input_type *const v = static_cast<input_type *>(buffers[2]);
      input_type *const w = static_cast<input_type *>(buffers[3]);
      input_type *const u = static_cast<input_type *>(buffers[4]);
      output_type *const gy = static_cast<output_type *>(buffers[5]);
      output_type *const gr = static_cast<output_type *>(buffers[6]);
      output_type *const gk = static_cast<output_type *>(buffers[7]);
      output_type *const gv = static_cast<output_type *>(buffers[8]);
      output_type *const gw = static_cast<output_type *>(buffers[9]);
      output_type *const gu = static_cast<output_type *>(buffers[10]);
      HostApplyGradient<input_type, output_type>(
          stream, desc.B, desc.T, desc.C, desc.H,
          r, k, v, w, u, gy, gr, gk, gv, gw, gu);
  )
}
|
|
511
|
+
|
|
512
|
+
|
|
513
|
+
} // namespace gpu_ops
|
|
514
|
+
|