rwkv-ops 0.1.1__py3-none-any.whl → 0.2__py3-none-any.whl


@@ -0,0 +1,724 @@
+ import os
+ import pybind11
+ import importlib
+ import sys
+ import sysconfig
+ import subprocess
+ from functools import partial, reduce
+
+ import jax
+ import jax.numpy as jnp
+ from jax import core, dtypes
+ from jax.core import ShapedArray
+ from jax.experimental.custom_partitioning import custom_partitioning
+ from jax.experimental.pjit import pjit
+ from jax.interpreters import batching, mlir, xla
+ from jax.interpreters.mlir import ir
+ from jax.lib import xla_client
+ from jax.sharding import Mesh, NamedSharding, PartitionSpec
+ from jaxlib.hlo_helpers import custom_call
+
+
+ use_rocm = os.environ.get("RWKV_USE_ROCM", "0") == "1"
+ if use_rocm:
+     kernel_dir_name = "jax_kernel_hip"
+     cuda_lib_dir = "/opt/rocm/include"
+     kernel_name = "gpu_ops"
+     platform = "ROCM"
+ else:
+     kernel_dir_name = "jax_kernel_cuda"
+     cuda_lib_dir = "/usr/local/cuda/include"
+     kernel_name = "gpu_ops"
+     platform = "gpu"
+
+
+ def default_layouts(*shapes):
+     return [range(len(shape) - 1, -1, -1) for shape in shapes]
+
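+ # Example: default_layouts((4, 16, 8)) returns [range(2, -1, -1)], i.e. the
+ # row-major (minor-to-major) layout [2, 1, 0], which is the layout assumed for
+ # every custom_call operand and result below.
+
+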
+ class RWKVKernelOperator:
+     def __init__(self, head_size, max_sequence_length):
+         """
+         Load or build the RWKV kernel.
+         """
+         rwkv_kernel = RWKVKernelOperator._load_or_build_kernel(
+             head_size, max_sequence_length
+         )
+         """
+         Register the C++ operator entry points with MLIR.
+         """
+         for _name, _value in rwkv_kernel.get_rwkv_registrations().items():
+             xla_client.register_custom_call_target(_name, _value, platform=platform)
+
+         """
+         Map JAX front-end dtypes to C++ back-end element types.
+         """
+
+         def element_type_to_descriptor_type_mapping(element_type):
+             _element_type_to_descriptor_type_mapping = {
+                 ir.BF16Type.get(): rwkv_kernel.ElementType.BF16,
+                 ir.F16Type.get(): rwkv_kernel.ElementType.F16,
+                 ir.F32Type.get(): rwkv_kernel.ElementType.F32,
+             }
+             return _element_type_to_descriptor_type_mapping.get(element_type)
+
+         """
+         Define the forward operator.
+         """
+
+         def build_forward_op():
+             _rwkv_fwd_p = core.Primitive("rwkv_fwd")
+             _rwkv_fwd_p.multiple_results = False
+             _rwkv_fwd_p.def_impl(partial(xla.apply_primitive, _rwkv_fwd_p))
+
+             """
+             Build the forward pass.
+             """
+
+             def _rwkv_fwd(r, k, v, w, u):
+                 y = _rwkv_fwd_p.bind(r, k, v, w, u)
+                 ctx = r, k, v, w, u
+                 return y, ctx
+
+             """
+             MLIR lowering for the forward primitive registered above;
+             MLIR allocates the output tensors for the operator.
+             """
+
+             def _rwkv_fwd_cuda_lowering(ctx, r, k, v, w, u):
+                 r_type = ir.RankedTensorType(r.type)
+                 k_type = ir.RankedTensorType(k.type)
+                 v_type = ir.RankedTensorType(v.type)
+                 w_type = ir.RankedTensorType(w.type)
+                 u_type = ir.RankedTensorType(u.type)
+                 assert all(
+                     [
+                         r_type.element_type == xx.element_type
+                         for xx in [k_type, v_type, w_type, u_type]
+                     ]
+                 )
+                 assert all(
+                     [r_type.shape == xx.shape for xx in [k_type, v_type, w_type]]
+                 )
+                 assert r_type.element_type in [
+                     ir.F32Type.get(),
+                     ir.BF16Type.get(),
+                     ir.F16Type.get(),
+                 ]
+                 bz, seq_len, hd_sz = r_type.shape
+
+                 assert hd_sz % head_size == 0
+                 assert reduce(lambda x, y: x * y, u_type.shape, 1) == hd_sz, (
+                     "the number of elements of u (time first) is not equal to hidden_size"
+                 )
+                 input_type = r_type.element_type
+
+                 if input_type in [ir.F32Type.get(), ir.BF16Type.get()]:
+                     output_type = input_type
+                 else:
+                     output_type = ir.F32Type.get()
+
+                 opaque = rwkv_kernel.create_rwkv_descriptor(
+                     bz,
+                     seq_len,
+                     hd_sz,
+                     hd_sz // head_size,
+                     False,  # unused
+                     element_type_to_descriptor_type_mapping(input_type),
+                     element_type_to_descriptor_type_mapping(output_type),
+                 )
+
+                 out = custom_call(
+                     b"wkv_forward",
+                     result_types=[
+                         ir.RankedTensorType.get(r_type.shape, output_type),
+                     ],
+                     operands=[r, k, v, w, u],
+                     backend_config=opaque,
+                     operand_layouts=default_layouts(
+                         r_type.shape,
+                         k_type.shape,
+                         v_type.shape,
+                         w_type.shape,
+                         u_type.shape,
+                     ),
+                     result_layouts=default_layouts(r_type.shape),
+                 ).results
+                 return out
+
+             """
+             Bind the operator to its C++ implementation.
+             """
+             mlir.register_lowering(
+                 _rwkv_fwd_p,
+                 _rwkv_fwd_cuda_lowering,
+                 platform="gpu",
+             )
+             """
+             Define the abstract evaluation, telling JAX the output shape and dtype.
+             """
+
+             def _rwkv_fwd_abstract(r, k, v, w, u):
+                 assert all([r.shape == xx.shape for xx in [k, v, w]])
+                 assert len(r.shape) == 3
+                 bz, seq_len, channels = r.shape
+                 assert channels % head_size == 0
+                 assert seq_len <= max_sequence_length
+                 assert reduce(lambda x, y: x * y, u.shape, 1) == channels, (
+                     "the number of elements of u (time first) is not equal to hidden_size"
+                 )
+
+                 r_dtype = dtypes.canonicalize_dtype(r.dtype)
+                 k_dtype = dtypes.canonicalize_dtype(k.dtype)
+                 v_dtype = dtypes.canonicalize_dtype(v.dtype)
+                 w_dtype = dtypes.canonicalize_dtype(w.dtype)
+                 u_dtype = dtypes.canonicalize_dtype(u.dtype)
+
+                 assert all(
+                     [r_dtype == xx for xx in [k_dtype, v_dtype, w_dtype, u_dtype]]
+                 )
+                 assert r_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
+                 if r_dtype in [jnp.float32, jnp.bfloat16]:
+                     output_dtype = r_dtype
+                 else:
+                     output_dtype = jnp.float32
+                 return ShapedArray(
+                     r.shape, output_dtype, named_shape=r.named_shape
+                 )  # output
+
+             _rwkv_fwd_p.def_abstract_eval(_rwkv_fwd_abstract)
+             return _rwkv_fwd, _rwkv_fwd_p
+
+         rwkv_fwd, _ = build_forward_op()
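+         # Output dtype policy implemented above: float32 -> float32,
+         # bfloat16 -> bfloat16, float16 -> float32 (half-precision inputs
+         # accumulate into a float32 output).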
192
+
193
+ def build_backward_op():
194
+ """
195
+ 注册反向传播过程算子
196
+ """
197
+ _rwkv_bwd_p = core.Primitive("rwkv_bwd")
198
+ _rwkv_bwd_p.multiple_results = True
199
+ _rwkv_bwd_p.def_impl(partial(xla.apply_primitive, _rwkv_bwd_p))
200
+
201
+ """
202
+ 反向传播过程
203
+ """
204
+
205
+ def _rwkv_bwd(ctx, gy):
206
+ r, k, v, w, u = ctx
207
+ gr, gk, gv, gw, gu = _rwkv_bwd_p.bind(r, k, v, w, u, gy)
208
+ gu = jnp.sum(gu, axis=0)
209
+
210
+ return gr, gk, gv, gw, gu
211
+
212
+ """
213
+ 反向传播mlir后端
214
+ """
215
+
216
+ def _rwkv_bwd_cuda_lowering(ctx, r, k, v, w, u, gy):
217
+ r_type = ir.RankedTensorType(r.type)
218
+ k_type = ir.RankedTensorType(k.type)
219
+ v_type = ir.RankedTensorType(v.type)
220
+ w_type = ir.RankedTensorType(w.type)
221
+ u_type = ir.RankedTensorType(u.type)
222
+ gy_type = ir.RankedTensorType(gy.type)
223
+
224
+ assert all(
225
+ [
226
+ r_type.element_type == xx.element_type
227
+ for xx in [k_type, v_type, w_type, u_type]
228
+ ]
229
+ )
230
+ assert all(
231
+ [
232
+ r_type.shape == xx.shape
233
+ for xx in [k_type, v_type, w_type, gy_type]
234
+ ]
235
+ )
236
+ assert r_type.element_type in [
237
+ ir.F32Type.get(),
238
+ ir.BF16Type.get(),
239
+ ir.F16Type.get(),
240
+ ]
241
+ bz, seq_len, hd_sz = r_type.shape
242
+
243
+ assert hd_sz % head_size == 0
244
+ assert reduce(lambda x, y: x * y, u_type.shape, 1) == hd_sz, (
245
+ "the elements of u (time first) is not equal to hidden_size"
246
+ )
247
+ input_type = r_type.element_type
248
+
249
+ if input_type in [ir.F32Type.get(), ir.BF16Type.get()]:
250
+ output_type = input_type
251
+ else:
252
+ output_type = ir.F32Type.get()
253
+
254
+ assert output_type == gy_type.element_type
255
+
256
+ opaque = rwkv_kernel.create_rwkv_descriptor(
257
+ bz,
258
+ seq_len,
259
+ hd_sz,
260
+ hd_sz // head_size,
261
+ False, # unused
262
+ element_type_to_descriptor_type_mapping(input_type),
263
+ element_type_to_descriptor_type_mapping(output_type),
264
+ )
265
+
266
+ gu_shape = (bz,) + u_type.shape
267
+
268
+ out = custom_call(
269
+ b"wkv_backward",
270
+ result_types=[
271
+ ir.RankedTensorType.get(r_type.shape, output_type), # gr
272
+ ir.RankedTensorType.get(k_type.shape, output_type), # gk
273
+ ir.RankedTensorType.get(v_type.shape, output_type), # gw
274
+ ir.RankedTensorType.get(w_type.shape, output_type), # gw
275
+ ir.RankedTensorType.get(gu_shape, output_type), # gu
276
+ ],
277
+ operands=[r, k, v, w, u, gy],
278
+ backend_config=opaque,
279
+ operand_layouts=default_layouts(
280
+ r_type.shape,
281
+ k_type.shape,
282
+ v_type.shape,
283
+ w_type.shape,
284
+ u_type.shape,
285
+ gy_type.shape,
286
+ ),
287
+ result_layouts=default_layouts(
288
+ r_type.shape, k_type.shape, v_type.shape, w_type.shape, gu_shape
289
+ ),
290
+ ).results
291
+ return out
292
+
293
+ """
294
+ 注册反向传播算子到mlir
295
+ """
296
+ mlir.register_lowering(
297
+ _rwkv_bwd_p,
298
+ _rwkv_bwd_cuda_lowering,
299
+ platform="gpu",
300
+ )
301
+ """
302
+ 反向传播的抽象过程
303
+ """
304
+
305
+ def _rwkv_bwd_abstract(r, k, v, w, u, gy):
306
+ assert all([r.shape == xx.shape for xx in [k, v, w]])
307
+ assert len(r.shape) == 3
308
+ bz, seq_len, channels = r.shape
309
+ assert channels % head_size == 0
310
+ assert seq_len <= max_sequence_length
311
+ assert reduce(lambda x, y: x * y, u.shape, 1) == channels, (
312
+ "the elements of u (time first) is not equal to hidden_size"
313
+ )
314
+
315
+ r_dtype = dtypes.canonicalize_dtype(r.dtype)
316
+ k_dtype = dtypes.canonicalize_dtype(k.dtype)
317
+ v_dtype = dtypes.canonicalize_dtype(v.dtype)
318
+ w_dtype = dtypes.canonicalize_dtype(w.dtype)
319
+ u_dtype = dtypes.canonicalize_dtype(u.dtype)
320
+ gy_dtype = dtypes.canonicalize_dtype(gy.dtype)
321
+
322
+ u_shape = (bz,) + u_shape
323
+ assert all(
324
+ [r_dtype == xx for xx in [k_dtype, v_dtype, w_dtype, u_dtype]]
325
+ )
326
+ assert r_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
327
+ if r_dtype in [jnp.float32, jnp.bfloat16]:
328
+ output_dtype = r_dtype
329
+ else:
330
+ output_dtype = jnp.float32
331
+ assert output_dtype == gy_dtype
332
+
333
+ outputs = [
334
+ ShapedArray(r.shape, output_dtype, named_shape=r.named_shape),
335
+ ShapedArray(k.shape, output_dtype, named_shape=k.named_shape),
336
+ ShapedArray(v.shape, output_dtype, named_shape=v.named_shape),
337
+ ShapedArray(w.shape, output_dtype, named_shape=w.named_shape),
338
+ ShapedArray(u_shape, output_dtype, named_shape=u.named_shape),
339
+ ] # output
340
+ return outputs
341
+
342
+ _rwkv_bwd_p.def_abstract_eval(_rwkv_bwd_abstract)
343
+ return _rwkv_bwd, _rwkv_bwd_p
344
+
345
+ rwkv_bwd, _ = build_backward_op()
+         """
+         Compose the operator with jax.custom_vjp.
+         """
+
+         @jax.custom_vjp
+         def rwkv_normal_operator(r, k, v, w, u):
+             output, _ = rwkv_fwd(r, k, v, w, u)
+             return output
+
+         rwkv_normal_operator.defvjp(rwkv_fwd, rwkv_bwd)
+         self.rwkv_normal_operator = rwkv_normal_operator
+
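+         # The defvjp wiring above follows jax.custom_vjp's (output, residuals)
+         # convention. A minimal sketch of the same pattern on a toy function
+         # (illustrative only, not part of this package):
+         #
+         #   @jax.custom_vjp
+         #   def f(x):
+         #       return x * x
+         #
+         #   def f_fwd(x):
+         #       return x * x, (x,)        # save residuals for the backward pass
+         #
+         #   def f_bwd(res, gy):
+         #       (x,) = res
+         #       return (2 * x * gy,)      # one cotangent per primal input
+         #
+         #   f.defvjp(f_fwd, f_bwd)
+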
+         def build_forward_with_state_op():
+             _rwkv_fwd_state_p = core.Primitive("rwkv_fwd_with_state")
+             _rwkv_fwd_state_p.multiple_results = True
+             _rwkv_fwd_state_p.def_impl(partial(xla.apply_primitive, _rwkv_fwd_state_p))
+
+             """
+             Build the stateful forward pass.
+             """
+
+             def _rwkv_fwd_with_state(r, k, v, w, u, init_state=None, state_map=None):
+                 bz = r.shape[0]
+                 if init_state is not None:  # shape=(B,H,D,D)
+                     assert len(init_state.shape) in [3, 4], (
+                         "init_state must have shape (batch_size, num_heads, head_size, head_size), "
+                         "or (num_heads, head_size, head_size) for a single shared state"
+                     )
+                     if len(init_state.shape) == 3:
+                         state_map = jnp.zeros((bz,), dtype=jnp.int32)
+
+                     if state_map is None:
+                         n_state = jnp.shape(init_state)[0]
+                         B = jnp.shape(r)[0]
+                         assert n_state == 1 or n_state == B, (
+                             "state_map cannot be inferred here; please pass it explicitly."
+                         )
+                         if n_state == 1:
+                             state_map = jnp.array([0] * B, dtype=jnp.int32)
+                         elif n_state == B:
+                             state_map = jnp.array(
+                                 [i for i in range(B)], dtype=jnp.int32
+                             )
+                         else:
+                             assert False, "not implemented"
+                     else:
+                         assert state_map.dtype in [jnp.int64, jnp.int32], (
+                             "state_map must have dtype int32 (int64 is cast down)"
+                         )
+                         state_map = jnp.astype(state_map, jnp.int32)
+                         assert jnp.all(state_map >= 0) and jnp.all(state_map < bz), (
+                             f"state_map holds indices into init_state, so its values must lie in [0, {bz})"
+                         )
+                 assert (init_state is None) == (state_map is None), (
+                     "init_state and state_map must be passed together"
+                 )
+
+                 if init_state is None:
+                     y, s = _rwkv_fwd_state_p.bind(r, k, v, w, u)
+                 else:
+                     y, s = _rwkv_fwd_state_p.bind(r, k, v, w, u, init_state, state_map)
+                 return y, s
+
+             """
+             MLIR lowering for the stateful forward primitive registered above;
+             MLIR allocates the output tensors for the operator.
+             """
+
+             def _rwkv_fwd_with_state_cuda_lowering(
+                 ctx, r, k, v, w, u, s=None, s_map=None
+             ):
+                 r_type = ir.RankedTensorType(r.type)
+                 k_type = ir.RankedTensorType(k.type)
+                 v_type = ir.RankedTensorType(v.type)
+                 w_type = ir.RankedTensorType(w.type)
+                 u_type = ir.RankedTensorType(u.type)
+
+                 with_init_state = s is not None
+
+                 assert all(
+                     [
+                         r_type.element_type == xx.element_type
+                         for xx in [k_type, v_type, w_type, u_type]
+                     ]
+                 )
+                 assert all(
+                     [r_type.shape == xx.shape for xx in [k_type, v_type, w_type]]
+                 )
+                 assert r_type.element_type in [
+                     ir.F32Type.get(),
+                     ir.BF16Type.get(),
+                     ir.F16Type.get(),
+                 ]
+                 bz, seq_len, hd_sz = r_type.shape
+
+                 assert hd_sz % head_size == 0
+                 num_heads = hd_sz // head_size
+                 assert reduce(lambda x, y: x * y, u_type.shape, 1) == hd_sz, (
+                     "the number of elements of u (time first) is not equal to hidden_size"
+                 )
+                 input_type = r_type.element_type
+
+                 if input_type in [ir.F32Type.get(), ir.BF16Type.get()]:
+                     output_type = input_type
+                 else:
+                     output_type = ir.F32Type.get()
+                 state_shape = (bz, num_heads, head_size, head_size)
+
+                 if with_init_state:
+                     assert s_map is not None, (
+                         "init_state and state_map must both be passed, or both be None."
+                     )
+
+                     s_type = ir.RankedTensorType(s.type)
+                     sm_type = ir.RankedTensorType(s_map.type)
+                     assert sm_type.element_type == ir.IntegerType.get_signless(32), (
+                         "state_map must have dtype int32"
+                     )
+                     assert tuple(sm_type.shape) == (bz,), (
+                         "state_map must have shape (batch_size,)"
+                     )
+
+                     assert s_type.element_type == output_type
+                     assert tuple(s_type.shape) == state_shape, (
+                         "the shape of init state must be (batch_size, num_heads, head_size, head_size)"
+                     )
+
+                 opaque = rwkv_kernel.create_rwkv_descriptor(
+                     bz,
+                     seq_len,
+                     hd_sz,
+                     hd_sz // head_size,
+                     with_init_state,
+                     element_type_to_descriptor_type_mapping(input_type),
+                     element_type_to_descriptor_type_mapping(output_type),
+                 )
+
+                 if with_init_state:
+                     operands = [s_map, r, k, v, w, u, s]
+                     operand_layouts = default_layouts(
+                         sm_type.shape,
+                         r_type.shape,
+                         k_type.shape,
+                         v_type.shape,
+                         w_type.shape,
+                         u_type.shape,
+                         s_type.shape,
+                     )
+                 else:
+                     operands = [r, k, v, w, u]
+                     operand_layouts = default_layouts(
+                         r_type.shape,
+                         k_type.shape,
+                         v_type.shape,
+                         w_type.shape,
+                         u_type.shape,
+                     )
+
+                 out = custom_call(
+                     b"wkv_forward_with_state",
+                     result_types=[
+                         ir.RankedTensorType.get(r_type.shape, output_type),
+                         ir.RankedTensorType.get(state_shape, output_type),
+                     ],
+                     operands=operands,
+                     backend_config=opaque,
+                     operand_layouts=operand_layouts,
+                     result_layouts=default_layouts(r_type.shape, state_shape),
+                 ).results
+                 return out
+
+             """
+             Bind the operator to its C++ implementation.
+             """
+             mlir.register_lowering(
+                 _rwkv_fwd_state_p,
+                 _rwkv_fwd_with_state_cuda_lowering,
+                 platform="gpu",
+             )
+             """
+             Define the abstract evaluation, telling JAX the output shapes and dtypes.
+             """
+
+             def _rwkv_fwd_with_state_abstract(r, k, v, w, u, s=None, s_map=None):
+                 assert all([r.shape == xx.shape for xx in [k, v, w]])
+                 assert len(r.shape) == 3
+                 bz, seq_len, channels = r.shape
+                 assert channels % head_size == 0
+                 assert seq_len <= max_sequence_length
+                 assert reduce(lambda x, y: x * y, u.shape, 1) == channels, (
+                     "the number of elements of u (time first) is not equal to hidden_size"
+                 )
+                 num_heads = channels // head_size
+                 r_dtype = dtypes.canonicalize_dtype(r.dtype)
+                 k_dtype = dtypes.canonicalize_dtype(k.dtype)
+                 v_dtype = dtypes.canonicalize_dtype(v.dtype)
+                 w_dtype = dtypes.canonicalize_dtype(w.dtype)
+                 u_dtype = dtypes.canonicalize_dtype(u.dtype)
+
+                 assert all(
+                     [r_dtype == xx for xx in [k_dtype, v_dtype, w_dtype, u_dtype]]
+                 )
+                 assert r_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
+                 if r_dtype in [jnp.float32, jnp.bfloat16]:
+                     output_dtype = r_dtype
+                 else:
+                     output_dtype = jnp.float32
+
+                 state_shape = (bz, num_heads, head_size, head_size)
+                 if s is not None:
+                     s_dtype = dtypes.canonicalize_dtype(s.dtype)
+                     assert s_dtype == output_dtype
+                     assert s.shape == state_shape, (
+                         "the shape of init_state must be (batch_size, num_heads, head_size, head_size)"
+                     )
+
+                 return [
+                     ShapedArray(
+                         r.shape, output_dtype, named_shape=r.named_shape
+                     ),  # output
+                     ShapedArray(state_shape, output_dtype, named_shape=r.named_shape),
+                 ]
+
+             _rwkv_fwd_state_p.def_abstract_eval(_rwkv_fwd_with_state_abstract)
+             return _rwkv_fwd_with_state, _rwkv_fwd_state_p
+
+         self.rwkv_fwd_with_state, _ = build_forward_with_state_op()
+
+     def __call__(
+         self, r, k, v, w, u, with_state=False, init_state=None, state_map=None
+     ):
+         if with_state:
+             return self.rwkv_fwd_with_state(r, k, v, w, u, init_state, state_map)
+         else:
+             # keep a uniform (output, state) return signature
+             return self.rwkv_normal_operator(r, k, v, w, u), None
+
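+     # Hedged usage sketch (illustrative; shapes, data, and variable names are
+     # assumed): threading the recurrent state through chunked calls.
+     #
+     #   op = RWKVKernelOperator(head_size=64, max_sequence_length=512)
+     #   y1, state = op(r1, k1, v1, w1, u, with_state=True)
+     #   y2, state = op(r2, k2, v2, w2, u, with_state=True,
+     #                  init_state=state,
+     #                  state_map=jnp.arange(r2.shape[0], dtype=jnp.int32))
+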
+     @staticmethod
+     def _load_or_build_kernel(head_size, max_sequence_length):
+         assert head_size % 4 == 0, f"head_size must be a multiple of 4, but {head_size} is not."
+         assert isinstance(head_size, int), "head_size must be an int"
+         assert isinstance(max_sequence_length, int), (
+             "max_sequence_length must be an int"
+         )
+         assert head_size > 0 and max_sequence_length > 0, (
+             "head_size and max_sequence_length must be positive integers."
+         )
+         assert os.path.exists(cuda_lib_dir) and len(os.listdir(cuda_lib_dir)) > 0, (
+             f"please check that {cuda_lib_dir} exists; it should point at your CUDA/ROCm headers."
+         )
+         kernel_dir = os.path.abspath(
+             os.path.join(os.path.dirname(__file__), kernel_dir_name)
+         )
+         builds_dir = os.path.join(kernel_dir, "builds")
+         assert os.path.exists(kernel_dir), (
+             f"cannot find the {kernel_dir_name} directory; is this installation complete?"
+         )
+         if not os.path.exists(builds_dir):
+             os.mkdir(builds_dir)
+         target_dir_name = f"_N_{head_size}_T_{max_sequence_length}"
+         target_dir = os.path.join(builds_dir, target_dir_name)
+         if not os.path.exists(target_dir):
+             os.mkdir(target_dir)
+
+         def get_cflags():
+             getvar = sysconfig.get_config_var
+             flags = [
+                 "-I" + sysconfig.get_path("include"),
+                 "-I" + sysconfig.get_path("platinclude"),
+             ]
+
+             flags.extend(getvar("CFLAGS").split())
+             return " ".join(flags)
+
+         def get_suffix():
+             getvar = sysconfig.get_config_var
+             return getvar("EXT_SUFFIX")
+
+         build_cmds = []
+
+         # first, build the CUDA/HIP kernel
+         if use_rocm:
+             cu_src = os.path.join(kernel_dir, "rwkv_kernels.hip")
+             assert os.path.exists(cu_src)
+             cu_dst = os.path.join(target_dir, "rwkv_kernels.hip.o")
+             kernel_cmd = (
+                 f"hipcc -O3 --hipstdpar -xhip -fopenmp -ffast-math"
+                 + f" -munsafe-fp-atomics -enable-vectorize-compares"
+                 + f" -I{cuda_lib_dir} -I{pybind11.get_include()}"
+                 + f" -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2"
+                 + f" --gpu-max-threads-per-block=120"
+                 + f" -c {cu_src} -o {cu_dst} -D _N_={head_size} -D _T_={max_sequence_length}"
+             )
+         else:
+             cu_src = os.path.join(kernel_dir, "rwkv_kernels.cu")
+             assert os.path.exists(cu_src)
+             cu_dst = os.path.join(target_dir, "rwkv_kernels.cu.o")
+             kernel_cmd = (
+                 f"nvcc --threads 4 -Xcompiler -Wall -ldl --expt-relaxed-constexpr -O3 -DNDEBUG -Xcompiler -O3"
+                 + f" --generate-code=arch=compute_70,code=[compute_70,sm_70] --generate-code=arch=compute_75,code=[compute_75,sm_75] --generate-code=arch=compute_80,code=[compute_80,sm_80] --generate-code=arch=compute_86,code=[compute_86,sm_86]"
+                 + f" -Xcompiler=-fPIC -Xcompiler=-fvisibility=hidden -x cu -c {cu_src} -o {cu_dst} -D _N_={head_size} -D _T_={max_sequence_length}"
+             )
+         build_cmds.append(kernel_cmd)
+
+         so_dst = os.path.join(target_dir, f"{kernel_name}{get_suffix()}")
+         if not os.path.exists(so_dst):
+             # second, build the C++ code
+             cpp_src = os.path.join(kernel_dir, f"{kernel_name}.cpp")
+             cpp_dst = os.path.join(builds_dir, f"{kernel_name}.cpp.o")
+             if not os.path.exists(cpp_dst):
+                 if use_rocm:
+                     cpp_cmd = (
+                         f"c++ -I{cuda_lib_dir} -I{pybind11.get_include()} {get_cflags()}"
+                         + f" -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2"
+                         + f" -O3 -DNDEBUG -O3 -fPIC -fvisibility=hidden -flto -fno-fat-lto-objects"
+                         + f" -o {cpp_dst} -c {cpp_src}"
+                     )
+                 else:
+                     cpp_cmd = (
+                         f"c++ -I{cuda_lib_dir} -I{pybind11.get_include()} {get_cflags()}"
+                         + f" -O3 -DNDEBUG -O3 -fPIC -fvisibility=hidden -flto -fno-fat-lto-objects"
+                         + f" -o {cpp_dst} -c {cpp_src}"
+                     )
+                 build_cmds.append(cpp_cmd)
+
+             # third, link the C++ and CUDA/HIP objects into a shared library
+             if use_rocm:
+                 assembly_cmd = (
+                     f"c++ -fPIC -O3 -DNDEBUG -O3 -flto -shared -o {so_dst} {cpp_dst} {cu_dst}"
+                     + f" -fPIC -I{cuda_lib_dir} -I{pybind11.get_include()} {get_cflags()}"
+                     + f" -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2"
+                     + f" -L/opt/rocm/lib -lamdhip64 -lpthread -ldl"
+                 )
+             else:
+                 assembly_cmd = (
+                     f"c++ -fPIC -O3 -DNDEBUG -O3 -flto -shared -o {so_dst} {cpp_dst} {cu_dst}"
+                     + f" -L/usr/local/cuda/lib64 -lcudadevrt -lcudart_static -lrt -lpthread -ldl"
+                 )
+             build_cmds.append(assembly_cmd)
+
+             # finally, strip the shared library
+             strip_cmd = f"strip {so_dst}"
+             build_cmds.append(strip_cmd)
+
+         print("-------------------starting build kernel -------------------")
+         for cmd in build_cmds:
+             print("--------------- execute cmd ---------------")
+             print(cmd)
+             p = subprocess.Popen(cmd, shell=True)
+             p.wait()
+         print("-------------------build kernel finished -------------------")
+         kernel_module_path = f"rwkv_ops.rwkv6_kernel.{kernel_dir_name}.builds.{target_dir_name}.{kernel_name}"
+         print("loading cuda kernel from the path:", kernel_module_path)
+         rwkv_op = importlib.import_module(kernel_module_path)
+         return rwkv_op
+
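+ # Backend selection happens at import time via the RWKV_USE_ROCM environment
+ # variable (see the top of this file). For example (illustrative):
+ #   RWKV_USE_ROCM=1 python your_script.py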
+
+
+ if __name__ == "__main__":
+     bz, seq_len, hd_sz = 1, 16, 8
+     r = jnp.zeros(shape=(bz, seq_len, hd_sz), dtype="float16") + 2
+     k = jnp.zeros(shape=(bz, seq_len, hd_sz), dtype="float16") + 2
+     v = jnp.zeros(shape=(bz, seq_len, hd_sz), dtype="float16") + 2
+     w = jnp.zeros(shape=(bz, seq_len, hd_sz), dtype="float16") - 2
+     u = jnp.zeros(shape=(hd_sz,), dtype="float16") + 2
+     rwkv_op = RWKVKernelOperator(head_size=hd_sz, max_sequence_length=seq_len)
+     # __call__ returns (output, state); the state is None without with_state=True
+     out, _ = rwkv_op(r, k, v, w, u)
+
+     print(out.dtype)
+
+     def ref_loss(r, k, v, w, u):
+         predictions, _ = rwkv_op(r, k, v, w, u)
+         return -jnp.mean(predictions**2)
+
+     ref_out = jax.grad(ref_loss, argnums=(0, 1, 2, 3, 4))(r, k, v, w, u)
+     print(ref_out)