rwkv-ops 0.6.1 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. rwkv_ops/__init__.py +45 -0
  2. rwkv_ops/mhc_kernel/__init__.py +50 -0
  3. rwkv_ops/mhc_kernel/common_kernel/include/mhc_types.h +66 -0
  4. rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_post_op.cuh +197 -0
  5. rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_pre_op.cuh +212 -0
  6. rwkv_ops/mhc_kernel/common_kernel/kernels/rmsnorm.cuh +152 -0
  7. rwkv_ops/mhc_kernel/common_kernel/kernels/sinkhorn_knopp.cuh +158 -0
  8. rwkv_ops/mhc_kernel/common_kernel/kernels/stream_aggregate.cuh +141 -0
  9. rwkv_ops/mhc_kernel/common_kernel/kernels/stream_distribute.cuh +111 -0
  10. rwkv_ops/mhc_kernel/common_kernel/kernels/stream_mix.cuh +164 -0
  11. rwkv_ops/mhc_kernel/common_kernel/kernels/type_conversions.cuh +52 -0
  12. rwkv_ops/mhc_kernel/jax_kernel/CMakeLists.txt +47 -0
  13. rwkv_ops/mhc_kernel/jax_kernel/mhu_ffi.cu +652 -0
  14. rwkv_ops/mhc_kernel/jax_kernel/mhu_jax.py +939 -0
  15. rwkv_ops/mhc_kernel/native_keras_op.py +193 -0
  16. rwkv_ops/mhc_kernel/torch_kernel/mhc_cuda.cu +207 -0
  17. rwkv_ops/mhc_kernel/torch_kernel/mhc_op.cpp +296 -0
  18. rwkv_ops/mhc_kernel/torch_kernel/mhc_torch.py +306 -0
  19. rwkv_ops/rwkv6_kernel/__init__.py +120 -0
  20. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/gpu_ops.cpp +44 -0
  21. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernel_helpers.h +64 -0
  22. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernels.h +56 -0
  23. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/pybind11_kernel_helpers.h +41 -0
  24. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/rwkv_kernels.cu +512 -0
  25. rwkv_ops/rwkv6_kernel/jax_kernel_hip/gpu_ops.cpp +44 -0
  26. rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernel_helpers.h +64 -0
  27. rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernels.h +56 -0
  28. rwkv_ops/rwkv6_kernel/jax_kernel_hip/pybind11_kernel_helpers.h +41 -0
  29. rwkv_ops/rwkv6_kernel/jax_kernel_hip/rwkv_kernels.hip +514 -0
  30. rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py +722 -0
  31. rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py +90 -0
  32. rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_cuda.cu +397 -0
  33. rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_op.cpp +93 -0
  34. rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py +305 -0
  35. rwkv_ops/rwkv7_kernel/__init__.py +113 -0
  36. rwkv_ops/rwkv7_kernel/get_jax_devices_info.py +220 -0
  37. rwkv_ops/rwkv7_kernel/get_torch_devices_info.py +250 -0
  38. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/CMakeLists.txt +42 -0
  39. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_ffi.cu +399 -0
  40. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_jax.py +311 -0
  41. rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/CMakeLists.txt +42 -0
  42. rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_ffi.cu +172 -0
  43. rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_jax.py +190 -0
  44. rwkv_ops/rwkv7_kernel/jax_kernel/__init__.py +9 -0
  45. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_bwd.py +95 -0
  46. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_fwd.py +60 -0
  47. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_bwd.py +78 -0
  48. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_fwd.py +80 -0
  49. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_bwd.py +150 -0
  50. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_fwd.py +45 -0
  51. rwkv_ops/rwkv7_kernel/jax_kernel/cumsum.py +34 -0
  52. rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_bwd.py +61 -0
  53. rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_fwd.py +86 -0
  54. rwkv_ops/rwkv7_kernel/jax_op.py +382 -0
  55. rwkv_ops/rwkv7_kernel/mlx_op.py +118 -0
  56. rwkv_ops/rwkv7_kernel/native_keras_op.py +108 -0
  57. rwkv_ops/rwkv7_kernel/tf_eager_kernel.py +155 -0
  58. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_cuda.cu +235 -0
  59. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_op.cpp +63 -0
  60. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_torch.py +233 -0
  61. rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_cuda.cu +101 -0
  62. rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_op.cpp +56 -0
  63. rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_torch.py +112 -0
  64. rwkv_ops/rwkv7_kernel/torch_kernel/__init__.py +13 -0
  65. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_bwd.py +96 -0
  66. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_fwd.py +64 -0
  67. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_bwd.py +74 -0
  68. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_fwd.py +75 -0
  69. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_bwd.py +148 -0
  70. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_fwd.py +44 -0
  71. rwkv_ops/rwkv7_kernel/torch_kernel/cumsum.py +31 -0
  72. rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_bwd.py +63 -0
  73. rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_fwd.py +79 -0
  74. rwkv_ops/rwkv7_kernel/torch_op.py +504 -0
  75. rwkv_ops/rwkv7_kernel/triton_kernel/__init__.py +34 -0
  76. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_bwd.py +328 -0
  77. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_fwd.py +186 -0
  78. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_bwd.py +157 -0
  79. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_fwd.py +160 -0
  80. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_bwd.py +382 -0
  81. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_fwd.py +137 -0
  82. rwkv_ops/rwkv7_kernel/triton_kernel/cumsum.py +86 -0
  83. rwkv_ops/rwkv7_kernel/triton_kernel/utils.py +20 -0
  84. rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_bwd.py +193 -0
  85. rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_fwd.py +326 -0
  86. rwkv_ops-0.6.1.dist-info/METADATA +495 -0
  87. rwkv_ops-0.6.1.dist-info/RECORD +89 -0
  88. rwkv_ops-0.6.1.dist-info/WHEEL +4 -0
  89. rwkv_ops-0.6.1.dist-info/licenses/LICENSE.txt +201 -0
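The rwkv6 JAX backend ships both CUDA and HIP kernel sources (entries 20–29 above); as the module shown below makes explicit, the choice between them is made at import time through the RWKV_USE_ROCM environment variable. A minimal sketch of forcing the ROCm build, assuming the variable is set before the kernel module is imported and that the head_size/max_sequence_length values are illustrative:

import os

# Must be set before the kernel module is imported, because the module reads
# RWKV_USE_ROCM at import time to pick jax_kernel_hip over jax_kernel_cuda.
os.environ["RWKV_USE_ROCM"] = "1"

# Import path taken from the file listing above; the package may also expose
# this class elsewhere (e.g. via rwkv_ops/__init__.py).
from rwkv_ops.rwkv6_kernel.jax_rwkv_kernel import RWKVKernelOperator

# head_size must be a positive multiple of 4; these values are illustrative.
op = RWKVKernelOperator(head_size=64, max_sequence_length=4096)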
rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py
@@ -0,0 +1,722 @@
+ import os
+ import pybind11
+ import importlib
+ import sysconfig
+ import subprocess
+ from functools import partial, reduce
+
+ import jax
+ import jax.numpy as jnp
+ from jax import core, dtypes
+ from jax.core import ShapedArray
+ from jax.interpreters import mlir, xla
+ from jax.interpreters.mlir import ir
+ from jax.lib import xla_client
+ from jaxlib.hlo_helpers import custom_call
+
+
+ use_rocm = os.environ.get("RWKV_USE_ROCM", "0") == "1"
+ if use_rocm:
+     kernel_dir_name = "jax_kernel_hip"
+     cuda_lib_dir = "/opt/rocm/include"
+     kernel_name = "gpu_ops"
+     platform = "ROCM"
+ else:
+     kernel_dir_name = "jax_kernel_cuda"
+     cuda_lib_dir = "/usr/local/cuda/include"
+     kernel_name = "gpu_ops"
+     platform = "gpu"
+
+
+ def default_layouts(*shapes):
+     return [range(len(shape) - 1, -1, -1) for shape in shapes]
+
+
+ class RWKVKernelOperator:
+     def __init__(self, head_size, max_sequence_length):
+         """
+         Load or build the RWKV kernel.
+         """
+         rwkv_kernel = RWKVKernelOperator._load_or_build_kernel(
+             head_size, max_sequence_length
+         )
+         """
+         Register the C++ operator entry points with MLIR.
+         """
+         for _name, _value in rwkv_kernel.get_rwkv_registrations().items():
+             xla_client.register_custom_call_target(_name, _value, platform=platform)
+
+         """
+         Build the mapping from JAX front-end dtypes to the C++ back end.
+         """
+
+         def element_type_to_descriptor_type_mapping(element_type):
+             _element_type_to_descriptor_type_mapping = {
+                 ir.BF16Type.get(): rwkv_kernel.ElementType.BF16,
+                 ir.F16Type.get(): rwkv_kernel.ElementType.F16,
+                 ir.F32Type.get(): rwkv_kernel.ElementType.F32,
+             }
+             return _element_type_to_descriptor_type_mapping.get(element_type)
+
+         """
+         Define the forward operator.
+         """
+
+         def build_forward_op():
+             _rwkv_fwd_p = core.Primitive("rwkv_fwd")
+             _rwkv_fwd_p.multiple_results = False
+             _rwkv_fwd_p.def_impl(partial(xla.apply_primitive, _rwkv_fwd_p))
+
+             """
+             Build the forward pass.
+             """
+
+             def _rwkv_fwd(r, k, v, w, u):
+                 y = _rwkv_fwd_p.bind(r, k, v, w, u)
+                 ctx = r, k, v, w, u
+                 return y, ctx
+
+             """
+             MLIR lowering for the forward primitive registered above; MLIR creates the output tensors for the operator.
+             """
+
+             def _rwkv_fwd_cuda_lowering(ctx, r, k, v, w, u):
+                 r_type = ir.RankedTensorType(r.type)
+                 k_type = ir.RankedTensorType(k.type)
+                 v_type = ir.RankedTensorType(v.type)
+                 w_type = ir.RankedTensorType(w.type)
+                 u_type = ir.RankedTensorType(u.type)
+                 assert all(
+                     [
+                         r_type.element_type == xx.element_type
+                         for xx in [k_type, v_type, w_type, u_type]
+                     ]
+                 )
+                 assert all(
+                     [r_type.shape == xx.shape for xx in [k_type, v_type, w_type]]
+                 )
+                 assert r_type.element_type in [
+                     ir.F32Type.get(),
+                     ir.BF16Type.get(),
+                     ir.F16Type.get(),
+                 ]
+                 bz, seq_len, hd_sz = r_type.shape
+
+                 assert hd_sz % head_size == 0
+                 assert reduce(lambda x, y: x * y, u_type.shape, 1) == hd_sz, (
+                     "the number of elements of u (time first) is not equal to hidden_size"
+                 )
+                 input_type = r_type.element_type
+
+                 if input_type in [ir.F32Type.get(), ir.BF16Type.get()]:
+                     output_type = input_type
+                 else:
+                     output_type = ir.F32Type.get()
+
+                 opaque = rwkv_kernel.create_rwkv_descriptor(
+                     bz,
+                     seq_len,
+                     hd_sz,
+                     hd_sz // head_size,
+                     False,  # unused
+                     element_type_to_descriptor_type_mapping(input_type),
+                     element_type_to_descriptor_type_mapping(output_type),
+                 )
+
+                 out = custom_call(
+                     b"wkv_forward",
+                     result_types=[
+                         ir.RankedTensorType.get(r_type.shape, output_type),
+                     ],
+                     operands=[r, k, v, w, u],
+                     backend_config=opaque,
+                     operand_layouts=default_layouts(
+                         r_type.shape,
+                         k_type.shape,
+                         v_type.shape,
+                         w_type.shape,
+                         u_type.shape,
+                     ),
+                     result_layouts=default_layouts(r_type.shape),
+                 ).results
+                 return out
+
+             """
+             Bind the operator to its C++ implementation.
+             """
+             mlir.register_lowering(
+                 _rwkv_fwd_p,
+                 _rwkv_fwd_cuda_lowering,
+                 platform="gpu",
+             )
+             """
+             Define the abstract evaluation, which tells JAX the output shape and dtype.
+             """
+
+             def _rwkv_fwd_abstract(r, k, v, w, u):
+                 assert all([r.shape == xx.shape for xx in [k, v, w]])
+                 assert len(r.shape) == 3
+                 bz, seq_len, channels = r.shape
+                 assert channels % head_size == 0
+                 assert seq_len <= max_sequence_length
+                 assert reduce(lambda x, y: x * y, u.shape, 1) == channels, (
+                     "the number of elements of u (time first) is not equal to hidden_size"
+                 )
+
+                 r_dtype = dtypes.canonicalize_dtype(r.dtype)
+                 k_dtype = dtypes.canonicalize_dtype(k.dtype)
+                 v_dtype = dtypes.canonicalize_dtype(v.dtype)
+                 w_dtype = dtypes.canonicalize_dtype(w.dtype)
+                 u_dtype = dtypes.canonicalize_dtype(u.dtype)
+
+                 assert all(
+                     [r_dtype == xx for xx in [k_dtype, v_dtype, w_dtype, u_dtype]]
+                 )
+                 assert r_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
+                 if r_dtype in [jnp.float32, jnp.bfloat16]:
+                     output_dtype = r_dtype
+                 else:
+                     output_dtype = jnp.float32
+                 return ShapedArray(
+                     r.shape, output_dtype, named_shape=r.named_shape
+                 )  # output
+
+             _rwkv_fwd_p.def_abstract_eval(_rwkv_fwd_abstract)
+             return _rwkv_fwd, _rwkv_fwd_p
+
+         rwkv_fwd, _ = build_forward_op()
+
+         def build_backward_op():
+             """
+             Register the backward primitive.
+             """
+             _rwkv_bwd_p = core.Primitive("rwkv_bwd")
+             _rwkv_bwd_p.multiple_results = True
+             _rwkv_bwd_p.def_impl(partial(xla.apply_primitive, _rwkv_bwd_p))
+
+             """
+             Backward pass.
+             """
+
+             def _rwkv_bwd(ctx, gy):
+                 r, k, v, w, u = ctx
+                 gr, gk, gv, gw, gu = _rwkv_bwd_p.bind(r, k, v, w, u, gy)
+                 gu = jnp.sum(gu, axis=0)
+
+                 return gr, gk, gv, gw, gu
+
+             """
+             MLIR back end for the backward pass.
+             """
+
+             def _rwkv_bwd_cuda_lowering(ctx, r, k, v, w, u, gy):
+                 r_type = ir.RankedTensorType(r.type)
+                 k_type = ir.RankedTensorType(k.type)
+                 v_type = ir.RankedTensorType(v.type)
+                 w_type = ir.RankedTensorType(w.type)
+                 u_type = ir.RankedTensorType(u.type)
+                 gy_type = ir.RankedTensorType(gy.type)
+
+                 assert all(
+                     [
+                         r_type.element_type == xx.element_type
+                         for xx in [k_type, v_type, w_type, u_type]
+                     ]
+                 )
+                 assert all(
+                     [
+                         r_type.shape == xx.shape
+                         for xx in [k_type, v_type, w_type, gy_type]
+                     ]
+                 )
+                 assert r_type.element_type in [
+                     ir.F32Type.get(),
+                     ir.BF16Type.get(),
+                     ir.F16Type.get(),
+                 ]
+                 bz, seq_len, hd_sz = r_type.shape
+
+                 assert hd_sz % head_size == 0
+                 assert reduce(lambda x, y: x * y, u_type.shape, 1) == hd_sz, (
+                     "the number of elements of u (time first) is not equal to hidden_size"
+                 )
+                 input_type = r_type.element_type
+
+                 if input_type in [ir.F32Type.get(), ir.BF16Type.get()]:
+                     output_type = input_type
+                 else:
+                     output_type = ir.F32Type.get()
+
+                 assert output_type == gy_type.element_type
+
+                 opaque = rwkv_kernel.create_rwkv_descriptor(
+                     bz,
+                     seq_len,
+                     hd_sz,
+                     hd_sz // head_size,
+                     False,  # unused
+                     element_type_to_descriptor_type_mapping(input_type),
+                     element_type_to_descriptor_type_mapping(output_type),
+                 )
+
+                 gu_shape = (bz,) + u_type.shape
+
+                 out = custom_call(
+                     b"wkv_backward",
+                     result_types=[
+                         ir.RankedTensorType.get(r_type.shape, output_type),  # gr
+                         ir.RankedTensorType.get(k_type.shape, output_type),  # gk
+                         ir.RankedTensorType.get(v_type.shape, output_type),  # gv
+                         ir.RankedTensorType.get(w_type.shape, output_type),  # gw
+                         ir.RankedTensorType.get(gu_shape, output_type),  # gu
+                     ],
+                     operands=[r, k, v, w, u, gy],
+                     backend_config=opaque,
+                     operand_layouts=default_layouts(
+                         r_type.shape,
+                         k_type.shape,
+                         v_type.shape,
+                         w_type.shape,
+                         u_type.shape,
+                         gy_type.shape,
+                     ),
+                     result_layouts=default_layouts(
+                         r_type.shape, k_type.shape, v_type.shape, w_type.shape, gu_shape
+                     ),
+                 ).results
+                 return out
+
+             """
+             Register the backward operator with MLIR.
+             """
+             mlir.register_lowering(
+                 _rwkv_bwd_p,
+                 _rwkv_bwd_cuda_lowering,
+                 platform="gpu",
+             )
+             """
+             Abstract evaluation for the backward pass.
+             """
+
+             def _rwkv_bwd_abstract(r, k, v, w, u, gy):
+                 assert all([r.shape == xx.shape for xx in [k, v, w]])
+                 assert len(r.shape) == 3
+                 bz, seq_len, channels = r.shape
+                 assert channels % head_size == 0
+                 assert seq_len <= max_sequence_length
+                 assert reduce(lambda x, y: x * y, u.shape, 1) == channels, (
+                     "the number of elements of u (time first) is not equal to hidden_size"
+                 )
+
+                 r_dtype = dtypes.canonicalize_dtype(r.dtype)
+                 k_dtype = dtypes.canonicalize_dtype(k.dtype)
+                 v_dtype = dtypes.canonicalize_dtype(v.dtype)
+                 w_dtype = dtypes.canonicalize_dtype(w.dtype)
+                 u_dtype = dtypes.canonicalize_dtype(u.dtype)
+                 gy_dtype = dtypes.canonicalize_dtype(gy.dtype)
+
+                 u_shape = (bz,) + u.shape
+                 assert all(
+                     [r_dtype == xx for xx in [k_dtype, v_dtype, w_dtype, u_dtype]]
+                 )
+                 assert r_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
+                 if r_dtype in [jnp.float32, jnp.bfloat16]:
+                     output_dtype = r_dtype
+                 else:
+                     output_dtype = jnp.float32
+                 assert output_dtype == gy_dtype
+
+                 outputs = [
+                     ShapedArray(r.shape, output_dtype, named_shape=r.named_shape),
+                     ShapedArray(k.shape, output_dtype, named_shape=k.named_shape),
+                     ShapedArray(v.shape, output_dtype, named_shape=v.named_shape),
+                     ShapedArray(w.shape, output_dtype, named_shape=w.named_shape),
+                     ShapedArray(u_shape, output_dtype, named_shape=u.named_shape),
+                 ]  # outputs
+                 return outputs
+
+             _rwkv_bwd_p.def_abstract_eval(_rwkv_bwd_abstract)
+             return _rwkv_bwd, _rwkv_bwd_p
+
+         rwkv_bwd, _ = build_backward_op()
+         """
+         Combine the operators into a differentiable function.
+         """
+
+         @jax.custom_vjp
+         def rwkv_normal_operator(r, k, v, w, u):
+             output, _ = rwkv_fwd(r, k, v, w, u)
+             return output
+
+         rwkv_normal_operator.defvjp(rwkv_fwd, rwkv_bwd)
+         self.rwkv_normal_operator = rwkv_normal_operator
+
+         def build_forward_with_state_op():
+             _rwkv_fwd_state_p = core.Primitive("rwkv_fwd_with_state")
+             _rwkv_fwd_state_p.multiple_results = True
+             _rwkv_fwd_state_p.def_impl(partial(xla.apply_primitive, _rwkv_fwd_state_p))
+
+             """
+             Build the forward pass that carries a recurrent state.
+             """
+
+             def _rwkv_fwd_with_state(r, k, v, w, u, init_state=None, state_map=None):
+                 bz = r.shape[0]
+                 if init_state is not None:  # shape=(B,H,D,D)
+                     assert len(init_state.shape) in [
+                         3,
+                         4,
+                     ], "init_state must have shape (batch_size, num_heads, head_size, head_size)"
+                     if len(init_state.shape) == 3:
+                         state_map = jnp.zeros((bz,), dtype=jnp.int32)
+
+                     if state_map is None:
+                         n_state = jnp.shape(init_state)[0]
+                         B = jnp.shape(r)[0]
+                         # print('ns:',n_state,'B:',B,r.shape,k.shape,v.shape)
+                         assert n_state == 1 or n_state == B, (
+                             "Cannot infer state_map automatically; please specify it explicitly."
+                         )
+                         if n_state == 1:
+                             state_map = jnp.array([0] * B, dtype=jnp.int32)
+                         elif n_state == B:
+                             state_map = jnp.array(
+                                 [i for i in range(B)], dtype=jnp.int32
+                             )
+                         else:
+                             assert False, "not implemented"
+                     else:
+                         # assert state_map is not None, "Pass a state_map: an int32 array of shape (bz,) mapping each batch element to a row of init_state."
+                         assert state_map.dtype in [
+                             jnp.int64,
+                             jnp.int32,
+                         ], "state_map must have an integer dtype (int32)"
+                         state_map = jnp.astype(state_map, jnp.int32)
+                         assert jnp.all(state_map >= 0) and jnp.all(state_map < bz), (
+                             f"state_map holds state indices, so its values must lie in [0, {bz})"
+                         )
+                 assert (init_state is None) == (state_map is None), (
+                     "init_state and state_map must be provided together"
+                 )
+
+                 if init_state is None:
+                     y, s = _rwkv_fwd_state_p.bind(r, k, v, w, u)
+                 else:
+                     # state_map = jnp.astype(state_map,jnp.int32)
+                     # print('sd',state_map.dtype)
+                     y, s = _rwkv_fwd_state_p.bind(r, k, v, w, u, init_state, state_map)
+                 return y, s
+
+             """
+             MLIR lowering for the stateful forward primitive registered above; MLIR creates the output tensors for the operator.
+             """
+
+             def _rwkv_fwd_with_state_cuda_lowering(
+                 ctx, r, k, v, w, u, s=None, s_map=None
+             ):
+                 r_type = ir.RankedTensorType(r.type)
+                 k_type = ir.RankedTensorType(k.type)
+                 v_type = ir.RankedTensorType(v.type)
+                 w_type = ir.RankedTensorType(w.type)
+                 u_type = ir.RankedTensorType(u.type)
+
+                 with_init_state = s is not None
+
+                 assert all(
+                     [
+                         r_type.element_type == xx.element_type
+                         for xx in [k_type, v_type, w_type, u_type]
+                     ]
+                 )
+                 assert all(
+                     [r_type.shape == xx.shape for xx in [k_type, v_type, w_type]]
+                 )
+                 assert r_type.element_type in [
+                     ir.F32Type.get(),
+                     ir.BF16Type.get(),
+                     ir.F16Type.get(),
+                 ]
+                 bz, seq_len, hd_sz = r_type.shape
+
+                 assert hd_sz % head_size == 0
+                 num_heads = hd_sz // head_size
+                 assert reduce(lambda x, y: x * y, u_type.shape, 1) == hd_sz, (
+                     "the number of elements of u (time first) is not equal to hidden_size"
+                 )
+                 input_type = r_type.element_type
+
+                 if input_type in [ir.F32Type.get(), ir.BF16Type.get()]:
+                     output_type = input_type
+                 else:
+                     output_type = ir.F32Type.get()
+                 state_shape = (bz, num_heads, head_size, head_size)
+
+                 if with_init_state:
+                     assert s_map is not None, (
+                         "init_state and state_map must be passed together, or both left as None."
+                     )
+
+                     s_type = ir.RankedTensorType(s.type)
+                     sm_type = ir.RankedTensorType(s_map.type)
+                     # print(sm_type, ir.IntegerType.get_signless(64))
+                     assert sm_type.element_type == ir.IntegerType.get_signless(32), (
+                         "state_map must have dtype int32"
+                     )
+                     # print(sm_type.shape,bz)
+                     assert tuple(sm_type.shape) == (bz,), (
+                         "state_map must have shape (batch_size,)"
+                     )
+
+                     assert s_type.element_type == output_type
+                     assert tuple(s_type.shape) == state_shape, (
+                         "the shape of init state must be (batch_size,num_heads,head_size,head_size)"
+                     )
+                     # assert s_type.shape[0] == bz and reduce(lambda x,y: x * y, s_type.shape[1:],1) == head_size * hd_sz,"the shape of init state must be (batch_size,num_heads,head_size,head_size)"
+
+                 opaque = rwkv_kernel.create_rwkv_descriptor(
+                     bz,
+                     seq_len,
+                     hd_sz,
+                     hd_sz // head_size,
+                     with_init_state,
+                     element_type_to_descriptor_type_mapping(input_type),
+                     element_type_to_descriptor_type_mapping(output_type),
+                 )
+
+                 if with_init_state:
+                     operands = [s_map, r, k, v, w, u, s]
+                     operand_layouts = default_layouts(
+                         sm_type.shape,
+                         r_type.shape,
+                         k_type.shape,
+                         v_type.shape,
+                         w_type.shape,
+                         u_type.shape,
+                         s_type.shape,
+                     )
+                 else:
+                     operands = [r, k, v, w, u]
+                     operand_layouts = default_layouts(
+                         r_type.shape,
+                         k_type.shape,
+                         v_type.shape,
+                         w_type.shape,
+                         u_type.shape,
+                     )
+
+                 out = custom_call(
+                     b"wkv_forward_with_state",
+                     result_types=[
+                         ir.RankedTensorType.get(r_type.shape, output_type),
+                         ir.RankedTensorType.get(state_shape, output_type),
+                     ],
+                     operands=operands,
+                     backend_config=opaque,
+                     operand_layouts=operand_layouts,
+                     result_layouts=default_layouts(r_type.shape, state_shape),
+                 ).results
+                 return out
+
+             """
+             Bind the operator to its C++ implementation.
+             """
+             mlir.register_lowering(
+                 _rwkv_fwd_state_p,
+                 _rwkv_fwd_with_state_cuda_lowering,
+                 platform="gpu",
+             )
+             """
+             Define the abstract evaluation, which tells JAX the output shapes and dtypes.
+             """
+
+             def _rwkv_fwd_with_state_abstract(r, k, v, w, u, s=None, s_map=None):
+                 assert all([r.shape == xx.shape for xx in [k, v, w]])
+                 assert len(r.shape) == 3
+                 bz, seq_len, channels = r.shape
+                 assert channels % head_size == 0
+                 assert seq_len <= max_sequence_length
+                 assert reduce(lambda x, y: x * y, u.shape, 1) == channels, (
+                     "the number of elements of u (time first) is not equal to hidden_size"
+                 )
+                 num_heads = channels // head_size
+                 r_dtype = dtypes.canonicalize_dtype(r.dtype)
+                 k_dtype = dtypes.canonicalize_dtype(k.dtype)
+                 v_dtype = dtypes.canonicalize_dtype(v.dtype)
+                 w_dtype = dtypes.canonicalize_dtype(w.dtype)
+                 u_dtype = dtypes.canonicalize_dtype(u.dtype)
+
+                 assert all(
+                     [r_dtype == xx for xx in [k_dtype, v_dtype, w_dtype, u_dtype]]
+                 )
+                 assert r_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
+                 if r_dtype in [jnp.float32, jnp.bfloat16]:
+                     output_dtype = r_dtype
+                 else:
+                     output_dtype = jnp.float32
+
+                 state_shape = (bz, num_heads, head_size, head_size)
+                 if s is not None:
+                     s_dtype = dtypes.canonicalize_dtype(s.dtype)
+                     assert s_dtype == output_dtype
+                     assert s.shape == state_shape, (
+                         "the shape of init_state must be (batch_size, num_heads, head_size, head_size)"
+                     )
+
+                 return [
+                     ShapedArray(
+                         r.shape, output_dtype, named_shape=r.named_shape
+                     ),  # output
+                     ShapedArray(state_shape, output_dtype, named_shape=r.named_shape),  # state
+                 ]
+
+             _rwkv_fwd_state_p.def_abstract_eval(_rwkv_fwd_with_state_abstract)
+             return _rwkv_fwd_with_state, _rwkv_fwd_state_p
+
+         self.rwkv_fwd_with_state, _ = build_forward_with_state_op()
+
+     def __call__(
+         self, r, k, v, w, u, with_state=False, init_state=None, state_map=None
+     ):
+         if with_state:
+             return self.rwkv_fwd_with_state(r, k, v, w, u, init_state, state_map)
+         else:
+             return self.rwkv_normal_operator(r, k, v, w, u), None
+
+     @staticmethod
+     def _load_or_build_kernel(head_size, max_sequence_length):
+         assert head_size % 4 == 0, f"head_size must be a multiple of 4, but {head_size} is not."
+         assert isinstance(head_size, int), "head_size must be an int"
+         assert isinstance(max_sequence_length, int), (
+             "max_sequence_length must be an int"
+         )
+         assert head_size > 0 and max_sequence_length > 0, (
+             "head_size and max_sequence_length must be positive integers."
+         )
+         assert os.path.exists(cuda_lib_dir) and len(os.listdir(cuda_lib_dir)) > 0, (
+             f"Please check that {cuda_lib_dir} exists; it should point to your CUDA (or ROCm) include directory."
+         )
+         kernel_dir = os.path.abspath(
+             os.path.join(os.path.dirname(__file__), kernel_dir_name)
+         )
+         builds_dir = os.path.join(kernel_dir, "builds")
+         assert os.path.exists(kernel_dir), (
+             f"Cannot find the {kernel_dir_name} directory; is the installation complete?"
+         )
+         if not os.path.exists(builds_dir):
+             os.mkdir(builds_dir)
+         target_dir_name = f"_N_{head_size}_T_{max_sequence_length}"
+         target_dir = os.path.join(builds_dir, target_dir_name)
+         if not os.path.exists(target_dir):
+             os.mkdir(target_dir)
+
+         def get_cflags():
+             getvar = sysconfig.get_config_var
+             flags = [
+                 "-I" + sysconfig.get_path("include"),
+                 "-I" + sysconfig.get_path("platinclude"),
+             ]
+
+             flags.extend(getvar("CFLAGS").split())
+             return " ".join(flags)
+
+         def get_suffix():
+             getvar = sysconfig.get_config_var
+             return getvar("EXT_SUFFIX")
+
+         build_cmds = []
+
+         # first, build the CUDA/HIP kernel
+         if use_rocm:
+             cu_src = os.path.join(kernel_dir, "rwkv_kernels.hip")
+             assert os.path.exists(cu_src)
+             cu_dst = os.path.join(target_dir, "rwkv_kernels.hip.o")
+             kernel_cmd = (
+                 "hipcc -O3 --hipstdpar -xhip -fopenmp -ffast-math"
+                 + " -munsafe-fp-atomics -enable-vectorize-compares"
+                 + f" -I{cuda_lib_dir} -I{pybind11.get_include()}"
+                 + " -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2"
+                 + " --gpu-max-threads-per-block=120"
+                 + f" -c {cu_src} -o {cu_dst} -D _N_={head_size} -D _T_={max_sequence_length}"
+             )
+         else:
+             cu_src = os.path.join(kernel_dir, "rwkv_kernels.cu")
+             assert os.path.exists(cu_src)
+             cu_dst = os.path.join(target_dir, "rwkv_kernels.cu.o")
+             kernel_cmd = (
+                 "nvcc --threads 4 -Xcompiler -Wall -ldl --expt-relaxed-constexpr -O3 -DNDEBUG -Xcompiler -O3"
+                 + " --generate-code=arch=compute_70,code=[compute_70,sm_70] --generate-code=arch=compute_75,code=[compute_75,sm_75] --generate-code=arch=compute_80,code=[compute_80,sm_80] --generate-code=arch=compute_86,code=[compute_86,sm_86]"
+                 + f" -Xcompiler=-fPIC -Xcompiler=-fvisibility=hidden -x cu -c {cu_src} -o {cu_dst} -D _N_={head_size} -D _T_={max_sequence_length}"
+             )
+         build_cmds.append(kernel_cmd)
+
+         so_dst = os.path.join(target_dir, f"{kernel_name}{get_suffix()}")
+         if not os.path.exists(so_dst):
+             # second, build the C++ wrapper.
+             cpp_src = os.path.join(kernel_dir, f"{kernel_name}.cpp")
+             cpp_dst = os.path.join(builds_dir, f"{kernel_name}.cpp.o")
+             if not os.path.exists(cpp_dst):
+                 if use_rocm:
+                     cpp_cmd = (
+                         f"c++ -I{cuda_lib_dir} -I{pybind11.get_include()} {get_cflags()}"
+                         + " -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2"
+                         + " -O3 -DNDEBUG -O3 -fPIC -fvisibility=hidden -flto -fno-fat-lto-objects"
+                         + f" -o {cpp_dst} -c {cpp_src}"
+                     )
+                 else:
+                     cpp_cmd = (
+                         f"c++ -I{cuda_lib_dir} -I{pybind11.get_include()} {get_cflags()}"
+                         + " -O3 -DNDEBUG -O3 -fPIC -fvisibility=hidden -flto -fno-fat-lto-objects"
+                         + f" -o {cpp_dst} -c {cpp_src}"
+                     )
+                 build_cmds.append(cpp_cmd)
+
+             # third, link the C++ and CUDA/HIP objects
+             if use_rocm:
+                 assembly_cmd = (
+                     f"c++ -fPIC -O3 -DNDEBUG -O3 -flto -shared -o {so_dst} {cpp_dst} {cu_dst}"
+                     + f" -fPIC -I{cuda_lib_dir} -I{pybind11.get_include()} {get_cflags()}"
+                     + " -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2"
+                     + " -L/opt/rocm/lib -lamdhip64 -lpthread -ldl"
+                 )
+             else:
+                 assembly_cmd = (
+                     f"c++ -fPIC -O3 -DNDEBUG -O3 -flto -shared -o {so_dst} {cpp_dst} {cu_dst}"
+                     + " -L/usr/local/cuda/lib64 -lcudadevrt -lcudart_static -lrt -lpthread -ldl"
+                 )
+             build_cmds.append(assembly_cmd)
+
+             # finally, strip the shared library
+             strip_cmd = f"strip {so_dst}"
+             build_cmds.append(strip_cmd)
+
+             print("-------------------starting build kernel -------------------")
+             for cmd in build_cmds:
+                 print("--------------- execute cmd ---------------")
+                 print(cmd)
+                 p = subprocess.Popen(cmd, shell=True)
+                 p.wait()
+             print("-------------------build kernel finished -------------------")
+         kernel_module_path = f"rwkv_ops.rwkv6_kernel.{kernel_dir_name}.builds.{target_dir_name}.{kernel_name}"
+         print("loading cuda kernel from the path:", kernel_module_path)
+         rwkv_op = importlib.import_module(kernel_module_path)
+         return rwkv_op
+
+
+ if __name__ == "__main__":
+     bz, seq_len, hd_sz = 1, 16, 8
+     r = jnp.zeros(shape=(bz, seq_len, hd_sz), dtype="float16") + 2
+     k = jnp.zeros(shape=(bz, seq_len, hd_sz), dtype="float16") + 2
+     v = jnp.zeros(shape=(bz, seq_len, hd_sz), dtype="float16") + 2
+     w = jnp.zeros(shape=(bz, seq_len, hd_sz), dtype="float16") - 2
+     u = jnp.zeros(shape=(hd_sz,), dtype="float16") + 2
+     rwkv_op = RWKVKernelOperator(head_size=hd_sz, max_sequence_length=seq_len)
+     out, _ = rwkv_op(r, k, v, w, u)
+
+     print(out.dtype)
+
+     def ref_loss(r, k, v, w, u):
+         predictions, _ = rwkv_op(r, k, v, w, u)
+         return -jnp.mean(predictions**2)
+
+     ref_out = jax.grad(ref_loss, argnums=(0, 1, 2, 3, 4))(r, k, v, w, u)
+     print(ref_out)
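
The __main__ block above only exercises the stateless path (with_state=False). For reference, here is a minimal sketch of driving the stateful path defined by __call__(..., with_state=True, init_state=..., state_map=...) so that the recurrent state is carried across two sequence chunks. It is not taken from the package: the shapes, dtypes, and identity state_map are illustrative assumptions, and it presumes a working CUDA or ROCm toolchain so the kernel can actually be built.

# Hypothetical usage sketch (not part of the package): thread the recurrent
# state through two successive chunks. Inputs are (batch, seq_len, hidden)
# with hidden = num_heads * head_size; the returned state has shape
# (batch, num_heads, head_size, head_size).
import jax.numpy as jnp

from rwkv_ops.rwkv6_kernel.jax_rwkv_kernel import RWKVKernelOperator  # path from the listing above

batch, seq_len, head_size, num_heads = 1, 16, 8, 2
hidden = num_heads * head_size
op = RWKVKernelOperator(head_size=head_size, max_sequence_length=seq_len)

r = jnp.ones((batch, seq_len, hidden), dtype="bfloat16")
k = jnp.ones((batch, seq_len, hidden), dtype="bfloat16")
v = jnp.ones((batch, seq_len, hidden), dtype="bfloat16")
w = jnp.ones((batch, seq_len, hidden), dtype="bfloat16") * -2.0
u = jnp.ones((hidden,), dtype="bfloat16")

# First chunk: no initial state; the kernel returns the final state alongside y.
y1, state = op(r, k, v, w, u, with_state=True)

# Second chunk: feed the previous state back in. state_map says which row of
# init_state each batch element starts from; here it is the identity mapping.
state_map = jnp.arange(batch, dtype=jnp.int32)
y2, state = op(r, k, v, w, u, with_state=True, init_state=state, state_map=state_map)
print(y2.shape, state.shape)  # (1, 16, 16) and (1, 2, 8, 8)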