rwkv-ops 0.1.0__py3-none-any.whl → 0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rwkv-ops might be problematic. Click here for more details.
- rwkv_ops/__init__.py +3 -0
- rwkv_ops/rwkv6_kernel/__init__.py +126 -0
- rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py +724 -0
- rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py +86 -0
- rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py +305 -0
- rwkv_ops/rwkv7_kernel/__init__.py +10 -7
- rwkv_ops/rwkv7_kernel/get_jax_devices_info.py +2 -3
- rwkv_ops/rwkv7_kernel/get_torch_devices_info.py +2 -2
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_bwd.py +3 -3
- rwkv_ops/rwkv7_kernel/jax_op.py +0 -2
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_bwd.py +1 -1
- rwkv_ops/rwkv7_kernel/torch_op.py +20 -53
- rwkv_ops-0.2.dist-info/METADATA +258 -0
- {rwkv_ops-0.1.0.dist-info → rwkv_ops-0.2.dist-info}/RECORD +17 -13
- rwkv_ops-0.1.0.dist-info/METADATA +0 -118
- {rwkv_ops-0.1.0.dist-info → rwkv_ops-0.2.dist-info}/LICENSE.txt +0 -0
- {rwkv_ops-0.1.0.dist-info → rwkv_ops-0.2.dist-info}/WHEEL +0 -0
- {rwkv_ops-0.1.0.dist-info → rwkv_ops-0.2.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,724 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import pybind11
|
|
3
|
+
import importlib
|
|
4
|
+
import sys
|
|
5
|
+
import sysconfig
|
|
6
|
+
import subprocess
|
|
7
|
+
from functools import partial, reduce
|
|
8
|
+
|
|
9
|
+
import jax
|
|
10
|
+
import jax.numpy as jnp
|
|
11
|
+
from jax import core, dtypes
|
|
12
|
+
from jax.core import ShapedArray
|
|
13
|
+
from jax.experimental.custom_partitioning import custom_partitioning
|
|
14
|
+
from jax.experimental.pjit import pjit
|
|
15
|
+
from jax.interpreters import batching, mlir, xla
|
|
16
|
+
from jax.interpreters.mlir import ir
|
|
17
|
+
from jax.lib import xla_client
|
|
18
|
+
from jax.sharding import Mesh, NamedSharding, PartitionSpec
|
|
19
|
+
from jaxlib.hlo_helpers import custom_call
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
# Pick the kernel back end from the environment: RWKV_USE_ROCM=1 selects the
# HIP/ROCm sources, any other value (or unset) selects the CUDA sources.
use_rocm = os.environ.get("RWKV_USE_ROCM", "0") == "1"

# Both back ends produce an extension module with the same name.
kernel_name = "gpu_ops"

# (kernel source directory, toolkit include directory, XLA custom-call platform)
kernel_dir_name, cuda_lib_dir, platform = (
    ("jax_kernel_hip", "/opt/rocm/include", "ROCM")
    if use_rocm
    else ("jax_kernel_cuda", "/usr/local/cuda/include", "gpu")
)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def default_layouts(*shapes):
    """Return one row-major layout per shape.

    Each layout lists the dimension indices from the last (fastest varying)
    to the first, e.g. ``(2, 3, 4)`` -> ``range(2, -1, -1)``.
    """
    layouts = []
    for dims in shapes:
        layouts.append(range(len(dims))[::-1])
    return layouts
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class RWKVKernelOperator:
    """JAX wrapper around the compiled RWKV6 WKV kernels (CUDA or HIP).

    Construction loads the native kernel module for the requested
    ``head_size`` / ``max_sequence_length``, building it first if no cached
    build exists. It then registers the module's custom-call targets with XLA
    and defines three JAX primitives:

    * ``rwkv_fwd`` - stateless forward pass, exposed as
      ``self.rwkv_normal_operator`` with a custom VJP backed by ``rwkv_bwd``;
    * ``rwkv_bwd`` - the matching backward pass;
    * ``rwkv_fwd_with_state`` - forward pass that also returns a state tensor
      and can consume an initial state, exposed as ``self.rwkv_fwd_with_state``
      (no differentiation rule is registered for it).
    """

    def __init__(self, head_size, max_sequence_length):
        """Load or build the kernel and register the JAX primitives.

        Args:
            head_size: size of one head; the hidden size of every input must
                be a multiple of it. Validated by ``_load_or_build_kernel``.
            max_sequence_length: largest ``seq_len`` accepted by the abstract
                evaluation rules; also passed to the kernel build as ``_T_``.
        """
        rwkv_kernel = RWKVKernelOperator._load_or_build_kernel(
            head_size, max_sequence_length
        )

        # Register the C++ custom-call entry points with XLA.
        for _name, _value in rwkv_kernel.get_rwkv_registrations().items():
            xla_client.register_custom_call_target(_name, _value, platform=platform)

        def element_type_to_descriptor_type_mapping(element_type):
            """Map an MLIR element type to the kernel's ``ElementType`` (None if unsupported)."""
            _element_type_to_descriptor_type_mapping = {
                ir.BF16Type.get(): rwkv_kernel.ElementType.BF16,
                ir.F16Type.get(): rwkv_kernel.ElementType.F16,
                ir.F32Type.get(): rwkv_kernel.ElementType.F32,
            }
            return _element_type_to_descriptor_type_mapping.get(element_type)

        def _check_lowering_inputs(r_type, k_type, v_type, w_type, u_type):
            # Validate r/k/v/w/u MLIR types; return (bz, seq_len, hd_sz, input_type, output_type).
            assert all(
                r_type.element_type == xx.element_type
                for xx in [k_type, v_type, w_type, u_type]
            )
            assert all(r_type.shape == xx.shape for xx in [k_type, v_type, w_type])
            assert r_type.element_type in [
                ir.F32Type.get(),
                ir.BF16Type.get(),
                ir.F16Type.get(),
            ]
            bz, seq_len, hd_sz = r_type.shape
            assert hd_sz % head_size == 0
            assert reduce(lambda x, y: x * y, u_type.shape, 1) == hd_sz, (
                "the elements of u (time first) is not equal to hidden_size"
            )
            input_type = r_type.element_type
            # f32 / bf16 inputs keep their precision; f16 inputs produce f32 outputs.
            if input_type in [ir.F32Type.get(), ir.BF16Type.get()]:
                output_type = input_type
            else:
                output_type = ir.F32Type.get()
            return bz, seq_len, hd_sz, input_type, output_type

        def _check_abstract_inputs(r, k, v, w, u):
            # Validate r/k/v/w/u abstract values; return (bz, seq_len, channels, output_dtype).
            assert all(r.shape == xx.shape for xx in [k, v, w])
            assert len(r.shape) == 3
            bz, seq_len, channels = r.shape
            assert channels % head_size == 0
            assert seq_len <= max_sequence_length
            assert reduce(lambda x, y: x * y, u.shape, 1) == channels, (
                "the elements of u (time first) is not equal to hidden_size"
            )
            r_dtype = dtypes.canonicalize_dtype(r.dtype)
            assert all(
                r_dtype == dtypes.canonicalize_dtype(xx.dtype) for xx in [k, v, w, u]
            )
            assert r_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
            # Same precision rule as the lowering: f16 is promoted to f32.
            if r_dtype in [jnp.float32, jnp.bfloat16]:
                output_dtype = r_dtype
            else:
                output_dtype = jnp.float32
            return bz, seq_len, channels, output_dtype

        def _make_descriptor(bz, seq_len, hd_sz, with_init_state, input_type, output_type):
            # Build the opaque descriptor passed to the custom call as backend_config.
            return rwkv_kernel.create_rwkv_descriptor(
                bz,
                seq_len,
                hd_sz,
                hd_sz // head_size,
                with_init_state,
                element_type_to_descriptor_type_mapping(input_type),
                element_type_to_descriptor_type_mapping(output_type),
            )

        def build_forward_op():
            """Define the stateless forward primitive; return (fwd_rule, primitive)."""
            _rwkv_fwd_p = core.Primitive("rwkv_fwd")
            _rwkv_fwd_p.multiple_results = False
            _rwkv_fwd_p.def_impl(partial(xla.apply_primitive, _rwkv_fwd_p))

            def _rwkv_fwd(r, k, v, w, u):
                # custom_vjp forward rule: output plus the residuals for the backward pass.
                y = _rwkv_fwd_p.bind(r, k, v, w, u)
                ctx = r, k, v, w, u
                return y, ctx

            def _rwkv_fwd_cuda_lowering(ctx, r, k, v, w, u):
                # Lower to the "wkv_forward" custom call; MLIR creates the output buffer.
                r_type = ir.RankedTensorType(r.type)
                k_type = ir.RankedTensorType(k.type)
                v_type = ir.RankedTensorType(v.type)
                w_type = ir.RankedTensorType(w.type)
                u_type = ir.RankedTensorType(u.type)
                bz, seq_len, hd_sz, input_type, output_type = _check_lowering_inputs(
                    r_type, k_type, v_type, w_type, u_type
                )
                opaque = _make_descriptor(
                    bz, seq_len, hd_sz, False, input_type, output_type  # False: unused here
                )
                return custom_call(
                    b"wkv_forward",
                    result_types=[
                        ir.RankedTensorType.get(r_type.shape, output_type),
                    ],
                    operands=[r, k, v, w, u],
                    backend_config=opaque,
                    operand_layouts=default_layouts(
                        r_type.shape,
                        k_type.shape,
                        v_type.shape,
                        w_type.shape,
                        u_type.shape,
                    ),
                    result_layouts=default_layouts(r_type.shape),
                ).results

            mlir.register_lowering(
                _rwkv_fwd_p,
                _rwkv_fwd_cuda_lowering,
                platform="gpu",
            )

            def _rwkv_fwd_abstract(r, k, v, w, u):
                # Tell JAX the output shape/dtype without running the kernel.
                _, _, _, output_dtype = _check_abstract_inputs(r, k, v, w, u)
                return ShapedArray(r.shape, output_dtype, named_shape=r.named_shape)

            _rwkv_fwd_p.def_abstract_eval(_rwkv_fwd_abstract)
            return _rwkv_fwd, _rwkv_fwd_p

        def build_backward_op():
            """Define the backward primitive; return (bwd_rule, primitive)."""
            _rwkv_bwd_p = core.Primitive("rwkv_bwd")
            _rwkv_bwd_p.multiple_results = True
            _rwkv_bwd_p.def_impl(partial(xla.apply_primitive, _rwkv_bwd_p))

            def _rwkv_bwd(ctx, gy):
                # custom_vjp backward rule: cotangents for (r, k, v, w, u).
                r, k, v, w, u = ctx
                gr, gk, gv, gw, gu = _rwkv_bwd_p.bind(r, k, v, w, u, gy)
                # The kernel returns one u-gradient per batch row (shape
                # (bz, *u.shape)); u is shared across the batch, so reduce axis 0.
                gu = jnp.sum(gu, axis=0)
                # JAX requires cotangents to match the primal dtypes; f16
                # inputs get f32 gradients from the kernel, so cast back.
                return (
                    gr.astype(r.dtype),
                    gk.astype(k.dtype),
                    gv.astype(v.dtype),
                    gw.astype(w.dtype),
                    gu.astype(u.dtype),
                )

            def _rwkv_bwd_cuda_lowering(ctx, r, k, v, w, u, gy):
                # Lower to the "wkv_backward" custom call.
                r_type = ir.RankedTensorType(r.type)
                k_type = ir.RankedTensorType(k.type)
                v_type = ir.RankedTensorType(v.type)
                w_type = ir.RankedTensorType(w.type)
                u_type = ir.RankedTensorType(u.type)
                gy_type = ir.RankedTensorType(gy.type)

                bz, seq_len, hd_sz, input_type, output_type = _check_lowering_inputs(
                    r_type, k_type, v_type, w_type, u_type
                )
                assert r_type.shape == gy_type.shape
                assert output_type == gy_type.element_type

                opaque = _make_descriptor(
                    bz, seq_len, hd_sz, False, input_type, output_type  # False: unused here
                )

                # MLIR shapes may be lists, so unpack instead of tuple concatenation.
                gu_shape = (bz, *u_type.shape)

                return custom_call(
                    b"wkv_backward",
                    result_types=[
                        ir.RankedTensorType.get(r_type.shape, output_type),  # gr
                        ir.RankedTensorType.get(k_type.shape, output_type),  # gk
                        ir.RankedTensorType.get(v_type.shape, output_type),  # gv
                        ir.RankedTensorType.get(w_type.shape, output_type),  # gw
                        ir.RankedTensorType.get(gu_shape, output_type),  # gu
                    ],
                    operands=[r, k, v, w, u, gy],
                    backend_config=opaque,
                    operand_layouts=default_layouts(
                        r_type.shape,
                        k_type.shape,
                        v_type.shape,
                        w_type.shape,
                        u_type.shape,
                        gy_type.shape,
                    ),
                    result_layouts=default_layouts(
                        r_type.shape, k_type.shape, v_type.shape, w_type.shape, gu_shape
                    ),
                ).results

            mlir.register_lowering(
                _rwkv_bwd_p,
                _rwkv_bwd_cuda_lowering,
                platform="gpu",
            )

            def _rwkv_bwd_abstract(r, k, v, w, u, gy):
                # Shapes/dtypes of (gr, gk, gv, gw, per-batch gu).
                bz, _, _, output_dtype = _check_abstract_inputs(r, k, v, w, u)
                gy_dtype = dtypes.canonicalize_dtype(gy.dtype)
                assert output_dtype == gy_dtype
                gu_shape = (bz, *u.shape)
                return [
                    ShapedArray(r.shape, output_dtype, named_shape=r.named_shape),
                    ShapedArray(k.shape, output_dtype, named_shape=k.named_shape),
                    ShapedArray(v.shape, output_dtype, named_shape=v.named_shape),
                    ShapedArray(w.shape, output_dtype, named_shape=w.named_shape),
                    ShapedArray(gu_shape, output_dtype, named_shape=u.named_shape),
                ]

            _rwkv_bwd_p.def_abstract_eval(_rwkv_bwd_abstract)
            return _rwkv_bwd, _rwkv_bwd_p

        rwkv_fwd, _ = build_forward_op()
        rwkv_bwd, _ = build_backward_op()

        # Differentiable stateless operator composed from the two primitives.
        @jax.custom_vjp
        def rwkv_normal_operator(r, k, v, w, u):
            output, _ = rwkv_fwd(r, k, v, w, u)
            return output

        rwkv_normal_operator.defvjp(rwkv_fwd, rwkv_bwd)
        self.rwkv_normal_operator = rwkv_normal_operator

        def build_forward_with_state_op():
            """Define the stateful forward primitive; return (fwd_fn, primitive)."""
            _rwkv_fwd_state_p = core.Primitive("rwkv_fwd_with_state")
            _rwkv_fwd_state_p.multiple_results = True
            _rwkv_fwd_state_p.def_impl(partial(xla.apply_primitive, _rwkv_fwd_state_p))

            def _rwkv_fwd_with_state(r, k, v, w, u, init_state=None, state_map=None):
                """Run the stateful forward pass; return ``(y, state)``.

                ``init_state`` is checked downstream against
                (batch, num_heads, head_size, head_size). Per the messages below,
                ``state_map`` is an int32 array of shape (batch,) holding, for
                each batch row, an index into the initial states. It is inferred
                when omitted and there is either one state or one state per row.
                """
                bz = r.shape[0]
                if init_state is not None:
                    assert len(init_state.shape) in [3, 4], (
                        "init_state的shape为(Batch_size,num_heads,head_size,head_size)"
                    )
                    # NOTE(review): a 3-D init_state, or a single 4-D state with
                    # bz > 1, is still rejected by the abstract/lowering shape
                    # checks, which require (bz, num_heads, head_size, head_size).
                    # TODO: confirm the intended handling of these cases.
                    if len(init_state.shape) == 3:
                        state_map = jnp.zeros((bz,), dtype=jnp.int32)

                    if state_map is None:
                        n_state = jnp.shape(init_state)[0]
                        assert n_state == 1 or n_state == bz, (
                            "我无法为您推断state_map的形状,请手动指定。"
                        )
                        if n_state == 1:
                            state_map = jnp.zeros((bz,), dtype=jnp.int32)
                        else:
                            state_map = jnp.arange(bz, dtype=jnp.int32)
                    else:
                        # Accept lists / NumPy arrays as well as JAX arrays.
                        state_map = jnp.asarray(state_map)
                        assert state_map.dtype in [jnp.int64, jnp.int32], (
                            "state_map的数值类型必须为int32"
                        )
                        state_map = state_map.astype(jnp.int32)
                        # The range check needs concrete values; under jit the
                        # map is a tracer and cannot be converted to a bool.
                        if not isinstance(state_map, core.Tracer):
                            assert bool(
                                jnp.all(state_map >= 0) & jnp.all(state_map < bz)
                            ), f"state_map内为state的映射下标,因此范围为: [0,{bz})"
                assert (init_state is None) == (state_map is None), (
                    "init_state与state_map必须同时传入"
                )

                if init_state is None:
                    y, s = _rwkv_fwd_state_p.bind(r, k, v, w, u)
                else:
                    y, s = _rwkv_fwd_state_p.bind(r, k, v, w, u, init_state, state_map)
                return y, s

            def _rwkv_fwd_with_state_cuda_lowering(
                ctx, r, k, v, w, u, s=None, s_map=None
            ):
                # Lower to "wkv_forward_with_state"; outputs are (y, state).
                r_type = ir.RankedTensorType(r.type)
                k_type = ir.RankedTensorType(k.type)
                v_type = ir.RankedTensorType(v.type)
                w_type = ir.RankedTensorType(w.type)
                u_type = ir.RankedTensorType(u.type)

                with_init_state = s is not None

                bz, seq_len, hd_sz, input_type, output_type = _check_lowering_inputs(
                    r_type, k_type, v_type, w_type, u_type
                )
                num_heads = hd_sz // head_size
                state_shape = (bz, num_heads, head_size, head_size)

                operands = [r, k, v, w, u]
                operand_shapes = [
                    r_type.shape,
                    k_type.shape,
                    v_type.shape,
                    w_type.shape,
                    u_type.shape,
                ]
                if with_init_state:
                    assert s_map is not None, (
                        "您必须同时传入init_state与state_map 或者都赋值为None."
                    )
                    s_type = ir.RankedTensorType(s.type)
                    sm_type = ir.RankedTensorType(s_map.type)
                    assert sm_type.element_type == ir.IntegerType.get_signless(32), (
                        "state_map的数据类型必须为int32"
                    )
                    assert tuple(sm_type.shape) == (bz,), (
                        "state_map的shape 形状必须为(batch_size,)"
                    )
                    assert s_type.element_type == output_type
                    assert tuple(s_type.shape) == state_shape, (
                        "the shape of init state must be (batch_size,num_heads,head_size,head_size)"
                    )
                    # Operand order with an initial state: state_map first,
                    # the initial state last.
                    operands = [s_map, *operands, s]
                    operand_shapes = [sm_type.shape, *operand_shapes, s_type.shape]

                opaque = _make_descriptor(
                    bz, seq_len, hd_sz, with_init_state, input_type, output_type
                )

                return custom_call(
                    b"wkv_forward_with_state",
                    result_types=[
                        ir.RankedTensorType.get(r_type.shape, output_type),
                        ir.RankedTensorType.get(state_shape, output_type),
                    ],
                    operands=operands,
                    backend_config=opaque,
                    operand_layouts=default_layouts(*operand_shapes),
                    result_layouts=default_layouts(r_type.shape, state_shape),
                ).results

            mlir.register_lowering(
                _rwkv_fwd_state_p,
                _rwkv_fwd_with_state_cuda_lowering,
                platform="gpu",
            )

            def _rwkv_fwd_with_state_abstract(r, k, v, w, u, s=None, s_map=None):
                # Shapes/dtypes of (y, state); validates an optional initial state.
                bz, _, channels, output_dtype = _check_abstract_inputs(r, k, v, w, u)
                num_heads = channels // head_size
                state_shape = (bz, num_heads, head_size, head_size)
                if s is not None:
                    s_dtype = dtypes.canonicalize_dtype(s.dtype)
                    assert s_dtype == output_dtype
                    assert s.shape == state_shape, (
                        "the shape of init_state must be (batch_size, num_heads, head_size, head_size)"
                    )
                return [
                    ShapedArray(r.shape, output_dtype, named_shape=r.named_shape),  # output
                    ShapedArray(state_shape, output_dtype, named_shape=r.named_shape),
                ]

            _rwkv_fwd_state_p.def_abstract_eval(_rwkv_fwd_with_state_abstract)
            return _rwkv_fwd_with_state, _rwkv_fwd_state_p

        self.rwkv_fwd_with_state, _ = build_forward_with_state_op()

    def __call__(
        self, r, k, v, w, u, with_state=False, init_state=None, state_map=None
    ):
        """Apply the operator.

        Returns:
            ``(y, state)`` from the stateful forward when ``with_state`` is
            true; otherwise ``(y, None)`` from the differentiable stateless
            operator (``init_state`` / ``state_map`` are ignored in that case).
        """
        if with_state:
            return self.rwkv_fwd_with_state(r, k, v, w, u, init_state, state_map)
        return self.rwkv_normal_operator(r, k, v, w, u), None

    @staticmethod
    def _load_or_build_kernel(head_size, max_sequence_length):
        """Import the compiled kernel module, building it first if needed.

        Builds are cached per configuration under
        ``<kernel_dir>/builds/_N_{head_size}_T_{max_sequence_length}``. The
        device code is compiled with ``_N_`` / ``_T_`` set to those values.

        Args:
            head_size: positive int, multiple of 4.
            max_sequence_length: positive int.

        Returns:
            The imported extension module.

        Raises:
            AssertionError: invalid arguments, or missing toolkit/kernel dirs.
            subprocess.CalledProcessError: a required build step failed.
        """
        # Type checks come first: `%` on a non-int would raise an unrelated
        # error (or perform string formatting) instead of the intended message.
        assert isinstance(head_size, int), "你是在搞笑吗? head_size肯定得是int类型的啊"
        assert isinstance(max_sequence_length, int), (
            "你是在搞笑吗? max_sequence_length肯定得是int类型的啊"
        )
        assert head_size > 0 and max_sequence_length > 0, (
            "难绷,head_size与max_sequence_length肯定得是大于0的正整数啊。"
        )
        assert head_size % 4 == 0, f"head size必须是4的倍数,而{head_size}显然不是."
        assert os.path.exists(cuda_lib_dir) and len(os.listdir(cuda_lib_dir)) > 0, (
            f"请检查{cuda_lib_dir}文件夹是否存在,这个文件本质是是您的cuda library的超链接。"
        )
        kernel_dir = os.path.abspath(
            os.path.join(os.path.dirname(__file__), kernel_dir_name)
        )
        builds_dir = os.path.join(kernel_dir, "builds")
        assert os.path.exists(kernel_dir), (
            f"找不到{kernel_dir_name}文件夹,请问您的文件是完整的吗?"
        )
        target_dir_name = f"_N_{head_size}_T_{max_sequence_length}"
        target_dir = os.path.join(builds_dir, target_dir_name)
        os.makedirs(target_dir, exist_ok=True)

        def get_cflags():
            # Python include dirs plus the interpreter's CFLAGS (may be unset).
            flags = [
                "-I" + sysconfig.get_path("include"),
                "-I" + sysconfig.get_path("platinclude"),
            ]
            flags.extend((sysconfig.get_config_var("CFLAGS") or "").split())
            return " ".join(flags)

        def get_suffix():
            # Platform-specific extension-module suffix, e.g. ".cpython-311-x86_64-linux-gnu.so".
            return sysconfig.get_config_var("EXT_SUFFIX")

        so_dst = os.path.join(target_dir, f"{kernel_name}{get_suffix()}")
        # Only build when no cached shared library exists for this configuration.
        if not os.path.exists(so_dst):
            build_cmds = []

            # First, compile the device kernel for this (head_size, max_T).
            if use_rocm:
                cu_src = os.path.join(kernel_dir, "rwkv_kernels.hip")
                assert os.path.exists(cu_src)
                cu_dst = os.path.join(target_dir, "rwkv_kernels.hip.o")
                kernel_cmd = (
                    f"hipcc -O3 --hipstdpar -xhip -fopenmp -ffast-math"
                    + f" -munsafe-fp-atomics -enable-vectorize-compares"
                    + f" -I{cuda_lib_dir} -I{pybind11.get_include()}"
                    + f" -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2"
                    + f" --gpu-max-threads-per-block=120"
                    + f" -c {cu_src} -o {cu_dst} -D _N_={head_size} -D _T_={max_sequence_length}"
                )
            else:
                cu_src = os.path.join(kernel_dir, "rwkv_kernels.cu")
                assert os.path.exists(cu_src)
                cu_dst = os.path.join(target_dir, "rwkv_kernels.cu.o")
                kernel_cmd = (
                    f"nvcc --threads 4 -Xcompiler -Wall -ldl --expt-relaxed-constexpr -O3 -DNDEBUG -Xcompiler -O3"
                    + f" --generate-code=arch=compute_70,code=[compute_70,sm_70] --generate-code=arch=compute_75,code=[compute_75,sm_75] --generate-code=arch=compute_80,code=[compute_80,sm_80] --generate-code=arch=compute_86,code=[compute_86,sm_86]"
                    + f" -Xcompiler=-fPIC -Xcompiler=-fvisibility=hidden -x cu -c {cu_src} -o {cu_dst} -D _N_={head_size} -D _T_={max_sequence_length}"
                )
            build_cmds.append(kernel_cmd)

            # Second, compile the C++ bindings (shared by all configurations).
            cpp_src = os.path.join(kernel_dir, f"{kernel_name}.cpp")
            cpp_dst = os.path.join(builds_dir, f"{kernel_name}.cpp.o")
            if not os.path.exists(cpp_dst):
                if use_rocm:
                    cpp_cmd = (
                        f"c++ -I{cuda_lib_dir} -I{pybind11.get_include()} {get_cflags()}"
                        + f" -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2"
                        + f" -O3 -DNDEBUG -O3 -fPIC -fvisibility=hidden -flto -fno-fat-lto-objects"
                        + f" -o {cpp_dst} -c {cpp_src}"
                    )
                else:
                    cpp_cmd = (
                        f"c++ -I{cuda_lib_dir} -I{pybind11.get_include()} {get_cflags()}"
                        + f" -O3 -DNDEBUG -O3 -fPIC -fvisibility=hidden -flto -fno-fat-lto-objects"
                        + f" -o {cpp_dst} -c {cpp_src}"
                    )
                build_cmds.append(cpp_cmd)

            # Third, link the C++ and device objects into the extension module.
            if use_rocm:
                assembly_cmd = (
                    f"c++ -fPIC -O3 -DNDEBUG -O3 -flto -shared -o {so_dst} {cpp_dst} {cu_dst}"
                    + f" -fPIC -I{cuda_lib_dir} -I{pybind11.get_include()} {get_cflags()}"
                    + f" -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2"
                    + f" -L/opt/rocm/lib -lamdhip64 -lpthread -ldl"
                )
            else:
                assembly_cmd = (
                    f"c++ -fPIC -O3 -DNDEBUG -O3 -flto -shared -o {so_dst} {cpp_dst} {cu_dst}"
                    + f" -L/usr/local/cuda/lib64 -lcudadevrt -lcudart_static -lrt -lpthread -ldl"
                )
            build_cmds.append(assembly_cmd)

            print("-------------------starting build kernel -------------------")
            for cmd in build_cmds:
                print("--------------- execute cmd ---------------")
                print(cmd)
                # Fail fast: a failed step would otherwise only show up later
                # as a confusing import error.
                subprocess.run(cmd, shell=True, check=True)
            # Finally strip the library; this only shrinks it, so it stays best-effort.
            strip_cmd = f"strip {so_dst}"
            print("--------------- execute cmd ---------------")
            print(strip_cmd)
            subprocess.run(strip_cmd, shell=True, check=False)
            print("-------------------build kernel finished -------------------")

        kernel_module_path = f"rwkv_ops.rwkv6_kernel.{kernel_dir_name}.builds.{target_dir_name}.{kernel_name}"
        print("loading cuda kernel from the path:", kernel_module_path)
        # The module (and possibly its directory) may have been created after
        # the import system cached directory listings.
        importlib.invalidate_caches()
        rwkv_op = importlib.import_module(kernel_module_path)
        return rwkv_op
|
|
705
|
+
|
|
706
|
+
|
|
707
|
+
if __name__ == "__main__":
    # Smoke test: one forward pass and gradients w.r.t. all five inputs.
    bz, seq_len, hd_sz = 1, 16, 8
    r = jnp.zeros(shape=(bz, seq_len, hd_sz), dtype="float16") + 2
    k = jnp.zeros(shape=(bz, seq_len, hd_sz), dtype="float16") + 2
    v = jnp.zeros(shape=(bz, seq_len, hd_sz), dtype="float16") + 2
    w = jnp.zeros(shape=(bz, seq_len, hd_sz), dtype="float16") - 2
    u = jnp.zeros(shape=(hd_sz,), dtype="float16") + 2
    rwkv_op = RWKVKernelOperator(head_size=hd_sz, max_sequence_length=seq_len)

    # The operator returns (output, state); state is None unless with_state=True.
    out, _ = rwkv_op(r, k, v, w, u)

    print(out.dtype)

    def ref_loss(r, k, v, w, u):
        predictions, _ = rwkv_op(r, k, v, w, u)
        return -jnp.mean(predictions**2)

    ref_out = jax.grad(ref_loss, argnums=(0, 1, 2, 3, 4))(r, k, v, w, u)
    print(ref_out)
|