rwkv_ops-0.6.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rwkv_ops/__init__.py +45 -0
- rwkv_ops/mhc_kernel/__init__.py +50 -0
- rwkv_ops/mhc_kernel/common_kernel/include/mhc_types.h +66 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_post_op.cuh +197 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_pre_op.cuh +212 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/rmsnorm.cuh +152 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/sinkhorn_knopp.cuh +158 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/stream_aggregate.cuh +141 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/stream_distribute.cuh +111 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/stream_mix.cuh +164 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/type_conversions.cuh +52 -0
- rwkv_ops/mhc_kernel/jax_kernel/CMakeLists.txt +47 -0
- rwkv_ops/mhc_kernel/jax_kernel/mhu_ffi.cu +652 -0
- rwkv_ops/mhc_kernel/jax_kernel/mhu_jax.py +939 -0
- rwkv_ops/mhc_kernel/native_keras_op.py +193 -0
- rwkv_ops/mhc_kernel/torch_kernel/mhc_cuda.cu +207 -0
- rwkv_ops/mhc_kernel/torch_kernel/mhc_op.cpp +296 -0
- rwkv_ops/mhc_kernel/torch_kernel/mhc_torch.py +306 -0
- rwkv_ops/rwkv6_kernel/__init__.py +120 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/gpu_ops.cpp +44 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernel_helpers.h +64 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernels.h +56 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/pybind11_kernel_helpers.h +41 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/rwkv_kernels.cu +512 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/gpu_ops.cpp +44 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernel_helpers.h +64 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernels.h +56 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/pybind11_kernel_helpers.h +41 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/rwkv_kernels.hip +514 -0
- rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py +722 -0
- rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py +90 -0
- rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_cuda.cu +397 -0
- rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_op.cpp +93 -0
- rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py +305 -0
- rwkv_ops/rwkv7_kernel/__init__.py +113 -0
- rwkv_ops/rwkv7_kernel/get_jax_devices_info.py +220 -0
- rwkv_ops/rwkv7_kernel/get_torch_devices_info.py +250 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel/CMakeLists.txt +42 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_ffi.cu +399 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_jax.py +311 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/CMakeLists.txt +42 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_ffi.cu +172 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_jax.py +190 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/__init__.py +9 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_bwd.py +95 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_fwd.py +60 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_bwd.py +78 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_fwd.py +80 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_bwd.py +150 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_fwd.py +45 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/cumsum.py +34 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_bwd.py +61 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_fwd.py +86 -0
- rwkv_ops/rwkv7_kernel/jax_op.py +382 -0
- rwkv_ops/rwkv7_kernel/mlx_op.py +118 -0
- rwkv_ops/rwkv7_kernel/native_keras_op.py +108 -0
- rwkv_ops/rwkv7_kernel/tf_eager_kernel.py +155 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_cuda.cu +235 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_op.cpp +63 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_torch.py +233 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_cuda.cu +101 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_op.cpp +56 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_torch.py +112 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/__init__.py +13 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_bwd.py +96 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_fwd.py +64 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_bwd.py +74 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_fwd.py +75 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_bwd.py +148 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_fwd.py +44 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/cumsum.py +31 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_bwd.py +63 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_fwd.py +79 -0
- rwkv_ops/rwkv7_kernel/torch_op.py +504 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/__init__.py +34 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_bwd.py +328 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_fwd.py +186 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_bwd.py +157 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_fwd.py +160 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_bwd.py +382 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_fwd.py +137 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/cumsum.py +86 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/utils.py +20 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_bwd.py +193 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_fwd.py +326 -0
- rwkv_ops-0.6.1.dist-info/METADATA +495 -0
- rwkv_ops-0.6.1.dist-info/RECORD +89 -0
- rwkv_ops-0.6.1.dist-info/WHEEL +4 -0
- rwkv_ops-0.6.1.dist-info/licenses/LICENSE.txt +201 -0
rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py
@@ -0,0 +1,305 @@
import os
import torch
from torch.utils.cpp_extension import load

kernel_dir_name = "torch_kernel"

use_rocm = "RWKV_USE_ROCM" in os.environ and os.environ["RWKV_USE_ROCM"] == "1"


class RWKVKernelOperator:
    def __init__(self, head_size, max_sequence_length):
        current_dir = os.path.dirname(__file__)
        # current_dir = os.pat
        if use_rocm:
            wkv6_cuda = load(
                name="wkv6",
                sources=[
                    os.path.join(current_dir, f"{kernel_dir_name}/wkv6_op.cpp"),
                    os.path.join(current_dir, f"{kernel_dir_name}/wkv6_cuda.cu"),
                ],
                # verbose=True, extra_cuda_cflags=[f"-D_N_={head_size}", f"-D_T_={max_sequence_length}"])
                verbose=True,
                extra_cuda_cflags=[
                    "-fopenmp -ffast-math -munsafe-fp-atomics --gpu-max-threads-per-block=120 -enable-vectorize-compares",
                    f"-D_N_={head_size}",
                    f"-D_T_={max_sequence_length}",
                ],
            )
        else:
            wkv6_cuda = load(
                name="wkv6",
                sources=[
                    os.path.join(current_dir, f"{kernel_dir_name}/wkv6_op.cpp"),
                    os.path.join(current_dir, f"{kernel_dir_name}/wkv6_cuda.cu"),
                ],
                # verbose=True, extra_cuda_cflags=[f"-D_N_={head_size}", f"-D_T_={max_sequence_length}"])
                verbose=True,
                extra_cuda_cflags=[
                    "-res-usage",
                    "--use_fast_math",
                    "-O3",
                    "-Xptxas -O3",
                    "--extra-device-vectorization",
                    f"-D_N_={head_size}",
                    f"-D_T_={max_sequence_length}",
                ],
            )

        # Defined inside __init__ so the autograd Function closes over
        # head_size and the freshly compiled wkv6_cuda extension.
        class RWKV_6(torch.autograd.Function):
            @staticmethod
            def forward(ctx, B, T, C, H, r, k, v, w, u):
                if not isinstance(u, torch.Tensor):
                    u = u.value
                with torch.no_grad():
                    assert r.dtype == k.dtype == v.dtype == w.dtype == u.dtype
                    assert r.dtype in [torch.float32, torch.bfloat16, torch.float16]

                    assert head_size == C // H
                    ctx.B = B
                    ctx.T = T
                    ctx.C = C
                    ctx.H = H
                    assert r.is_contiguous()
                    assert k.is_contiguous()
                    assert v.is_contiguous()
                    assert w.is_contiguous()
                    assert u.is_contiguous()
                    ctx.save_for_backward(r, k, v, w, u)

                    y_dtype = r.dtype if r.dtype != torch.float16 else torch.float32

                    y = torch.empty(
                        (B, T, C),
                        device=r.device,
                        dtype=y_dtype,
                        memory_format=torch.contiguous_format,
                    )  # .uniform_(-100, 100)

                    if r.dtype == torch.float32:
                        wkv6_cuda.forward_fp32(B, T, C, H, r, k, v, w, u, y)
                    elif r.dtype == torch.bfloat16:
                        wkv6_cuda.forward_bf16(B, T, C, H, r, k, v, w, u, y)
                    else:
                        wkv6_cuda.forward_fp16(B, T, C, H, r, k, v, w, u, y)
                    return y

            @staticmethod
            def backward(ctx, gy):
                assert gy.is_cuda
                with torch.no_grad():
                    assert gy.dtype in [torch.bfloat16, torch.float32]
                    B = ctx.B
                    T = ctx.T
                    C = ctx.C
                    H = ctx.H
                    assert gy.is_contiguous()
                    r, k, v, w, u = ctx.saved_tensors
                    y_dtype = r.dtype if r.dtype != torch.float16 else torch.float32

                    gr = torch.empty(
                        (B, T, C),
                        device=gy.device,
                        requires_grad=False,
                        dtype=y_dtype,
                        memory_format=torch.contiguous_format,
                    )  # .uniform_(-100, 100)
                    gk = torch.empty(
                        (B, T, C),
                        device=gy.device,
                        requires_grad=False,
                        dtype=y_dtype,
                        memory_format=torch.contiguous_format,
                    )  # .uniform_(-100, 100)
                    gv = torch.empty(
                        (B, T, C),
                        device=gy.device,
                        requires_grad=False,
                        dtype=y_dtype,
                        memory_format=torch.contiguous_format,
                    )  # .uniform_(-100, 100)
                    gw = torch.empty(
                        (B, T, C),
                        device=gy.device,
                        requires_grad=False,
                        dtype=y_dtype,
                        memory_format=torch.contiguous_format,
                    )  # .uniform_(-100, 100)
                    gu = torch.empty(
                        (B, C),
                        device=gy.device,
                        requires_grad=False,
                        dtype=y_dtype,
                        memory_format=torch.contiguous_format,
                    )  # .uniform_(-100, 100)

                    if r.dtype == torch.float32:
                        wkv6_cuda.backward_fp32(
                            B, T, C, H, r, k, v, w, u, gy, gr, gk, gv, gw, gu
                        )
                    elif r.dtype == torch.bfloat16:
                        wkv6_cuda.backward_bf16(
                            B, T, C, H, r, k, v, w, u, gy, gr, gk, gv, gw, gu
                        )
                    else:
                        wkv6_cuda.backward_fp16(
                            B, T, C, H, r, k, v, w, u, gy, gr, gk, gv, gw, gu
                        )

                    gu = torch.sum(gu, 0).view(H, C // H)

                    return (None, None, None, None, gr, gk, gv, gw, gu)

        class RWKV_6_with_state:
            @staticmethod
            def apply(B, T, C, H, S, s_map, r, k, v, w, u, s):
                with torch.no_grad():
                    assert s_map.dtype == torch.int64, (
                        "s_map must be None or an int64 array of length B."
                    )
                    assert (s is None and s_map is None) or (
                        s is not None and s_map is not None
                    ), "init_state and s_map must either both be None or both be given."
                    assert (
                        r.dtype == k.dtype == v.dtype == w.dtype == u.dtype
                        and r.dtype in [torch.float16, torch.float32, torch.bfloat16]
                    ), "r, k, v, w, u must share the same dtype, one of fp16, fp32 or bf16."
                    if r.dtype in [torch.float32, torch.bfloat16]:
                        o_dtype = r.dtype
                    else:
                        o_dtype = torch.float32
                    assert (
                        r.device
                        == k.device
                        == v.device
                        == w.device
                        == u.device
                        == s.device
                        == s_map.device
                    ), "r, k, v, w, u, s and s_map must all be on the same device."

                    y = torch.empty(
                        (B, T, C),
                        device=r.device,
                        dtype=o_dtype,
                        memory_format=torch.contiguous_format,
                    )
                    ys = torch.empty(
                        (B, H, head_size, head_size),
                        device=r.device,
                        dtype=o_dtype,
                        memory_format=torch.contiguous_format,
                    )
                    # print(ys)
                    if r.dtype == torch.bfloat16:
                        wkv6_cuda.forward_with_state_bf16(
                            B, T, C, H, S, s_map, r, k, v, w, u, s, y, ys
                        )
                    elif r.dtype == torch.float32:
                        wkv6_cuda.forward_with_state_fp32(
                            B, T, C, H, S, s_map, r, k, v, w, u, s, y, ys
                        )
                    else:
                        wkv6_cuda.forward_with_state_fp16(
                            B, T, C, H, S, s_map, r, k, v, w, u, s, y, ys
                        )

                    return y, ys

        self.head_size = head_size
        self.normal_kernel = RWKV_6
        self.kernel_with_state = RWKV_6_with_state

    def __call__(
        self, r, k, v, w, u, with_state=False, init_state=None, state_map=None
    ):
        B, T, C = r.shape
        assert C % self.head_size == 0
        H = C // self.head_size
        if not isinstance(u, torch.Tensor):
            u = u.value

        assert r.is_cuda
        assert k.is_cuda
        assert v.is_cuda
        assert w.is_cuda
        assert u.is_cuda

        if isinstance(r, torch.Tensor):
            assert r.device == k.device == v.device == w.device == u.device
        else:
            assert (
                r.get_device() == k.get_device() == v.get_device()
                == w.get_device() == u.get_device()
            )

        assert r.dtype == k.dtype == v.dtype == w.dtype == u.dtype

        if r.dtype in [torch.float32, torch.bfloat16]:
            s_dtype = r.dtype
        else:
            s_dtype = torch.float32

        is_custom_init = init_state is not None

        if init_state is not None:
            assert len(init_state.shape) in [3, 4], (
                "init_state must have shape (state_kinds /*<= batch_size*/, num_heads, head_size, head_size) or (num_heads, head_size, head_size)."
            )
            if len(init_state.shape) == 3:
                init_state = init_state[None, :]
            assert (
                init_state.shape[1:] == (H, self.head_size, self.head_size)
                and init_state.shape[0] <= B
            ), (
                "init_state must have shape (state_kinds /*<= batch_size*/, num_heads, head_size, head_size) or (num_heads, head_size, head_size)."
            )

            assert init_state.dtype == s_dtype, f"init_state dtype must be {s_dtype}."
            assert init_state.device == r.device

        if state_map is not None:
            if isinstance(state_map, list):
                state_map = torch.tensor(state_map, dtype=torch.int64)
            elif isinstance(state_map, torch.Tensor):
                assert state_map.dtype in [
                    torch.int32,
                    torch.int64,
                ], "state_map must be an int64 index array of length batch_size."
                state_map = state_map.to(torch.int64)
            assert state_map.shape == (B,), "state_map must have shape (batch_size,)."
            assert state_map.device == r.device

        if with_state:
            if init_state is None:
                assert state_map is None, (
                    "state_map can only be used when init_state is given."
                )
                init_state = torch.zeros((0,), device=r.device, dtype=s_dtype)
                state_map = torch.zeros((0,), device=r.device, dtype=torch.int64)
            else:
                n_state = init_state.shape[0]
                if state_map is None:
                    assert n_state == 1 or n_state == B, (
                        "state_map cannot be inferred from this init_state; please specify it explicitly."
                    )
                    if n_state == 1:
                        state_map = torch.tensor(
                            [0] * B, dtype=torch.int64, device=r.device
                        )
                    elif n_state == B:
                        state_map = torch.tensor(
                            [i for i in range(B)], dtype=torch.int64, device=r.device
                        )
                    else:
                        assert False, "not implemented"
                else:
                    assert state_map.shape == (B,), "state_map must have shape (batch_size,)."
                    assert (state_map >= 0).all() and (state_map < n_state).all(), (
                        f"state_map values must be integers in [0, {n_state})."
                    )
            # print('state map:',state_map)
            o, ys = self.kernel_with_state.apply(
                B, T, C, H, is_custom_init, state_map, r, k, v, w, u, init_state
            )
            return o, ys
        else:
            o = self.normal_kernel.apply(B, T, C, H, r, k, v, w, u)
            return o, None
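A minimal usage sketch of the RWKV6 operator defined in the hunk above, assuming a CUDA-capable PyTorch install and the import path below; the shapes follow the assertions in `__call__`, and the (num_heads, head_size) layout for `u` is an assumption inferred from the `gu` reshape in `backward`.

```python
# Hypothetical usage sketch for the RWKV6 torch operator above.
# Assumptions: CUDA-capable PyTorch, the import path below, and a
# (num_heads, head_size) layout for u (inferred from the gu reshape in backward).
import torch
from rwkv_ops.rwkv6_kernel.torch_rwkv_kernel import RWKVKernelOperator

head_size, max_len = 64, 512
op = RWKVKernelOperator(head_size=head_size, max_sequence_length=max_len)  # JIT-builds the wkv6 extension

B, T, C = 2, max_len, 8 * head_size  # C must be a multiple of head_size
H = C // head_size
r, k, v, w = (torch.randn(B, T, C, device="cuda", dtype=torch.bfloat16) for _ in range(4))
u = torch.randn(H, head_size, device="cuda", dtype=torch.bfloat16)

y, _ = op(r, k, v, w, u)                       # stateless: (B, T, C) output, None
y, state = op(r, k, v, w, u, with_state=True)  # stateful: also returns a (B, H, head_size, head_size) state
```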
rwkv_ops/rwkv7_kernel/__init__.py
@@ -0,0 +1,113 @@
import keras
from distutils.util import strtobool
import os
from keras import ops


def transpose_head(x, head_first):
    if head_first:
        return ops.transpose(x, (0, 2, 1, 3))
    else:
        return x


def get_generalized_delta_rule(HEAD_SIZE=64, KERNEL_TYPE="native"):
    assert HEAD_SIZE % 4 == 0
    from .native_keras_op import generalized_delta_rule as native_op

    if keras.config.backend() == "torch":
        import torch

        if torch.cuda.is_available():
            if KERNEL_TYPE.lower() == "triton":
                from .torch_op import generalized_delta_rule

                return generalized_delta_rule, generalized_delta_rule, True
            elif KERNEL_TYPE.lower() == "cuda":
                from .torch_cuda_kernel.wkv7_torch import (
                    get_torch_generalized_delta_rule,
                )

                return get_torch_generalized_delta_rule(HEAD_SIZE) + [False]
    elif keras.config.backend() == "jax":
        import jax
        import os

        if jax.devices()[0].platform == "gpu":
            if KERNEL_TYPE.lower() == "triton":
                os.environ["JAX_LOG_COMPUTATION"] = "0"
                from .jax_op import generalized_delta_rule

                return generalized_delta_rule, native_op, False
            elif KERNEL_TYPE.lower() == "cuda":
                from .jax_cuda_kernel.wkv7_jax import get_jax_generalized_delta_rule

                return get_jax_generalized_delta_rule(HEAD_SIZE) + [False]
    elif keras.config.backend() == "tensorflow":
        import tensorflow as tf

        if len(tf.config.list_physical_devices("GPU")) > 0:
            if KERNEL_TYPE.lower() == "cuda" and HEAD_SIZE:
                try:
                    from jax.lib import xla_bridge

                    assert xla_bridge.get_backend().platform == "gpu"
                except Exception:
                    raise RuntimeError(
                        "The TensorFlow kernel depends on the JAX CUDA kernel; "
                        "JAX must be able to see the GPU before the TensorFlow kernel can be used."
                    )
                print("🎉" * 10)
                print("The TensorFlow CUDA kernel only supports the forward pass; gradients are not available.")
                print("🎉" * 10)
                from .tf_eager_kernel import get_tf_generalized_delta_rule

                generalized_delta_rule_inference = get_tf_generalized_delta_rule(
                    HEAD_SIZE
                )
                return native_op, generalized_delta_rule_inference, False
    elif keras.config.backend() == "mlx" and KERNEL_TYPE.lower() == "cuda":
        from .mlx_op import generalized_delta_rule

        return native_op, generalized_delta_rule, False
    return native_op, native_op, False


def get_rnn_generalized_delta_rule(HEAD_SIZE=64, KERNEL_TYPE="native"):
    assert HEAD_SIZE % 4 == 0
    from .native_keras_op import generalized_delta_rule

    if KERNEL_TYPE == "cuda":
        if keras.config.backend() == "jax":
            import jax

            if jax.devices()[0].platform == "gpu":
                from .jax_cuda_kernel_single.wkv7_single_step_jax import (
                    get_jax_generalized_delta_rule_single_step,
                )

                return get_jax_generalized_delta_rule_single_step(HEAD_SIZE)
        elif keras.config.backend() == "torch":
            import torch

            if torch.cuda.is_available():
                from .torch_cuda_kernel_single.wkv7_single_step_torch import (
                    get_torch_generalized_delta_rule_single_step,
                )

                return get_torch_generalized_delta_rule_single_step(HEAD_SIZE)
        elif keras.config.backend() == "tensorflow":
            import tensorflow as tf

            if len(tf.config.list_physical_devices("GPU")) > 0:
                try:
                    import jax
                except ImportError:
                    return generalized_delta_rule
                if jax.devices()[0].platform == "gpu":
                    from .tf_eager_kernel import (
                        get_tf_generalized_delta_rule_single_step,
                    )

                    return get_tf_generalized_delta_rule_single_step(HEAD_SIZE)
    return generalized_delta_rule
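The selection logic above returns a 3-tuple: a training kernel, an inference kernel, and a boolean flag whose meaning is not documented in this hunk. A small wiring sketch, assuming the function is imported from the rwkv7_kernel subpackage listed at the top of this diff:

```python
# Hypothetical wiring sketch; the import path and the interpretation of the
# returned flag are assumptions. Only the 3-tuple shape comes from the code above.
from rwkv_ops.rwkv7_kernel import get_generalized_delta_rule

# Chooses a triton / CUDA / native implementation based on keras.config.backend()
# and on whether an accelerator is visible to that backend.
train_op, infer_op, flag = get_generalized_delta_rule(HEAD_SIZE=64, KERNEL_TYPE="cuda")
```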
rwkv_ops/rwkv7_kernel/get_jax_devices_info.py
@@ -0,0 +1,220 @@
import os
from functools import lru_cache
from typing import Literal
import functools
import triton
import jax
import jax.numpy as jnp
from enum import Enum
import contextlib


@lru_cache(maxsize=None)
def get_multiprocessor_count(tensor_idx: int = 0) -> int:
    return triton.runtime.driver.active.utils.get_device_properties(tensor_idx)[
        "multiprocessor_count"
    ]


@lru_cache(maxsize=None)
def get_available_device() -> str:
    try:
        return triton.runtime.driver.active.get_current_target().backend
    except BaseException:
        import warnings

        warnings.warn(
            ("Triton is not supported on the current platform; falling back to CPU."),
            stacklevel=1,
        )
        return "cpu"


@lru_cache(maxsize=None)
def _check_platform() -> Literal["nvidia", "amd", "intel", "musa"]:
    device = get_available_device()
    if device == "cuda":
        return "nvidia"
    elif device == "hip":
        return "amd"
    elif device == "xpu":
        return "intel"
    else:
        return device


# For AMD GPUs, the triton backend is 'hip', while for Nvidia GPUs, the triton backend is 'cuda'.
# However, the torch backend is 'cuda' for both Nvidia and AMD GPUs.
# Therefore, we need to check the triton backend to determine the actual GPU vendor.
device = get_available_device() if get_available_device() != "hip" else "cuda"

device_platform = _check_platform()

is_intel = device_platform == "intel"
is_nvidia = device_platform == "nvidia"
is_amd = device_platform == "amd"

use_cuda_graph = is_nvidia and os.environ.get("FLA_USE_CUDA_GRAPH", "0") == "1"


device = get_available_device() if get_available_device() != "hip" else "cuda"

is_intel_a770 = False
device = jax.devices()
is_tf32_supported = is_nvidia


def get_all_max_shared_memory():
    return [
        triton.runtime.driver.active.utils.get_device_properties(i)["max_shared_mem"]
        for i in range(len(jax.devices()))
    ]


device_shared_mem_list = get_all_max_shared_memory()


@lru_cache(maxsize=None)
def is_triton_shared_mem_enough(
    max_shared_mem: int = 102400, tensor_idx: int = 0
) -> bool:
    max_shared_memory = device_shared_mem_list[tensor_idx]
    return max_shared_memory >= max_shared_mem


device_capacity = is_triton_shared_mem_enough()


def _cpu_device_warning():
    import warnings

    warnings.warn(
        ("Triton is not supported on the current platform; falling back to CPU."), stacklevel=1
    )


def get_all_max_shared_mem():
    try:
        return [
            triton.runtime.driver.active.utils.get_device_properties(i)[
                "max_shared_mem"
            ]
            for i in range(len(jax.devices()))
        ]
    except BaseException:
        _cpu_device_warning()
        return [-1]


class Backend(Enum):
    ADA = 101376  # RTX 4090
    AMPERE = 166912  # A100
    HOPPER = 232448  # H100
    DEFAULT = 102400  # Default

    @classmethod
    def get_shared_memory(cls, arch: str) -> int:
        try:
            return cls[arch.upper()].value
        except KeyError:
            return cls.DEFAULT.value


@lru_cache(maxsize=None)
def check_shared_mem(arch: str = "none", tensor_idx: int = 0) -> bool:
    try:
        device_shared_mem_list = get_all_max_shared_mem()
        max_shared_memory = device_shared_mem_list[tensor_idx]
        return max_shared_memory >= Backend.get_shared_memory(arch)
    except Exception:
        return False


def tensor_cache(fn):
    """
    A decorator that caches the most recent result of a function with tensor inputs.

    This decorator will store the output of the decorated function for the most recent set of input tensors.
    If the function is called again with the same input tensors, it will return the cached result.

    Args:
        fn (Callable[..., jax.Array]):
            The function to be decorated. It should take tensor inputs and return tensor outputs.

    Returns:
        Callable[..., jax.Array]:
            A wrapped version of the input function with single-entry caching.
    """
    last_args = None
    last_kwargs = None
    last_result = None

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        nonlocal last_args, last_kwargs, last_result

        if last_args is not None and last_kwargs is not None:
            if len(args) == len(last_args) and len(kwargs) == len(last_kwargs):
                if all(a is b for a, b in zip(args, last_args)) and all(
                    k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()
                ):
                    return last_result

        result = fn(*args, **kwargs)
        last_args, last_kwargs, last_result = args, kwargs, result
        return result

    return wrapper


@tensor_cache
def prepare_lens(cu_seqlens):
    return cu_seqlens[1:] - cu_seqlens[:-1]


@tensor_cache
def prepare_chunk_indices(cu_seqlens, chunk_size: int):
    indices = jnp.concatenate(
        [
            jnp.arange(n)
            for n in triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()
        ]
    )

    return jnp.stack([jnp.cumsum(jnp.equal(indices, 0), 0) - 1, indices], 1)


def input_guard(fn):
    """
    A decorator to make sure all input tensors are contiguous and set the device based on input tensors.
    """

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        contiguous_args = (i for i in args)
        contiguous_kwargs = {k: v for k, v in kwargs.items()}

        tensor = None
        for arg in args:
            if isinstance(arg, jax.Array):
                tensor = arg
                break
        if tensor is None:
            for value in kwargs.values():
                if isinstance(value, jax.Array):
                    tensor = value
                    break

        if tensor is not None:
            ctx = tensor.device
        else:
            ctx = contextlib.nullcontext()

        with ctx:
            return fn(*contiguous_args, **contiguous_kwargs)

    return wrapper


is_intel_alchemist = False