rwkv-ops 0.3.2__py3-none-any.whl → 0.3.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rwkv-ops might be problematic.
- rwkv_ops/__init__.py +1 -1
- rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py +54 -54
- rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py +3 -3
- rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py +16 -18
- rwkv_ops/rwkv7_kernel/__init__.py +7 -3
- rwkv_ops/rwkv7_kernel/get_torch_devices_info.py +3 -3
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_jax.py +1 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_bwd.py +6 -6
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_fwd.py +3 -3
- rwkv_ops/rwkv7_kernel/mlx_op.py +132 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_bwd.py +6 -6
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_fwd.py +3 -3
- {rwkv_ops-0.3.2.dist-info → rwkv_ops-0.3.3.dist-info}/METADATA +3 -1
- {rwkv_ops-0.3.2.dist-info → rwkv_ops-0.3.3.dist-info}/RECORD +16 -15
- {rwkv_ops-0.3.2.dist-info → rwkv_ops-0.3.3.dist-info}/WHEEL +0 -0
- {rwkv_ops-0.3.2.dist-info → rwkv_ops-0.3.3.dist-info}/licenses/LICENSE.txt +0 -0
rwkv_ops/__init__.py
CHANGED

rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py
CHANGED
```diff
@@ -103,9 +103,9 @@ class RWKVKernelOperator:
         bz, seq_len, hd_sz = r_type.shape

         assert hd_sz % head_size == 0
-        assert
-
-        )
+        assert (
+            reduce(lambda x, y: x * y, u_type.shape, 1) == hd_sz
+        ), "the elements of u (time first) is not equal to hidden_size"
         input_type = r_type.element_type

         if input_type in [ir.F32Type.get(), ir.BF16Type.get()]:
@@ -159,9 +159,9 @@ class RWKVKernelOperator:
         bz, seq_len, channels = r.shape
         assert channels % head_size == 0
         assert seq_len <= max_sequence_length
-        assert
-
-        )
+        assert (
+            reduce(lambda x, y: x * y, u.shape, 1) == channels
+        ), "the elements of u (time first) is not equal to hidden_size"

         r_dtype = dtypes.canonicalize_dtype(r.dtype)
         k_dtype = dtypes.canonicalize_dtype(k.dtype)
@@ -237,9 +237,9 @@ class RWKVKernelOperator:
         bz, seq_len, hd_sz = r_type.shape

         assert hd_sz % head_size == 0
-        assert
-
-        )
+        assert (
+            reduce(lambda x, y: x * y, u_type.shape, 1) == hd_sz
+        ), "the elements of u (time first) is not equal to hidden_size"
         input_type = r_type.element_type

         if input_type in [ir.F32Type.get(), ir.BF16Type.get()]:
@@ -304,9 +304,9 @@ class RWKVKernelOperator:
         bz, seq_len, channels = r.shape
         assert channels % head_size == 0
         assert seq_len <= max_sequence_length
-        assert
-
-        )
+        assert (
+            reduce(lambda x, y: x * y, u.shape, 1) == channels
+        ), "the elements of u (time first) is not equal to hidden_size"

         r_dtype = dtypes.canonicalize_dtype(r.dtype)
         k_dtype = dtypes.canonicalize_dtype(k.dtype)
@@ -374,9 +374,9 @@ class RWKVKernelOperator:
         n_state = jnp.shape(init_state)[0]
         B = jnp.shape(r)[0]
         # print('ns:',n_state,'B:',B,r.shape,k.shape,v.shape)
-        assert
-
-        )
+        assert (
+            n_state == 1 or n_state == B
+        ), "我无法为您推断state_map的形状,请手动指定。"
         if n_state == 1:
             state_map = jnp.array([0] * B, dtype=jnp.int32)
         elif n_state == B:
@@ -392,12 +392,12 @@ class RWKVKernelOperator:
             jnp.int32,
         ], "state_map的数值类型必须为int32"
         state_map = jnp.astype(state_map, jnp.int32)
-        assert jnp.all(state_map >= 0) and jnp.add(
-
-        )
-        assert (init_state is None) == (
-
-        )
+        assert jnp.all(state_map >= 0) and jnp.add(
+            state_map < bz
+        ), f"state_map内为state的映射下标,因此范围为: [0,{bz})"
+        assert (init_state is None) == (
+            state_map is None
+        ), "init_state与state_map必须同时传入"

         if init_state is None:
             y, s = _rwkv_fwd_state_p.bind(r, k, v, w, u)
@@ -440,9 +440,9 @@ class RWKVKernelOperator:

         assert hd_sz % head_size == 0
         num_heads = hd_sz // head_size
-        assert
-
-        )
+        assert (
+            reduce(lambda x, y: x * y, u_type.shape, 1) == hd_sz
+        ), "the elements of u (time first) is not equal to hidden_size"
         input_type = r_type.element_type

         if input_type in [ir.F32Type.get(), ir.BF16Type.get()]:
@@ -452,25 +452,25 @@ class RWKVKernelOperator:
         state_shape = (bz, num_heads, head_size, head_size)

         if with_init_state:
-            assert
-
-            )
+            assert (
+                s_map is not None
+            ), "您必须同时传入init_state与state_map 或者都赋值为None."

             s_type = ir.RankedTensorType(s.type)
             sm_type = ir.RankedTensorType(s_map.type)
             # print(sm_type, ir.IntegerType.get_signless(64))
-            assert sm_type.element_type == ir.IntegerType.get_signless(
-
-            )
+            assert sm_type.element_type == ir.IntegerType.get_signless(
+                32
+            ), "state_map的数据类型必须为int32"
             # print(sm_type.shape,bz)
-            assert tuple(sm_type.shape) == (
-
-            )
+            assert tuple(sm_type.shape) == (
+                bz,
+            ), "state_map的shape 形状必须为(batch_size,)"

             assert s_type.element_type == output_type
-            assert
-
-            )
+            assert (
+                tuple(s_type.shape) == state_shape
+            ), "the shape of init state must be (batch_size,num_heads,head_size,head_size)"
             # assert s_type.shape[0] == bz and reduce(lambda x,y: x * y, s_type.shape[1:],1) == head_size * hd_sz,"the shape of init state must be (batch_size,num_heads,head_size,head_size)"

         opaque = rwkv_kernel.create_rwkv_descriptor(
@@ -535,9 +535,9 @@ class RWKVKernelOperator:
         bz, seq_len, channels = r.shape
         assert channels % head_size == 0
         assert seq_len <= max_sequence_length
-        assert
-
-        )
+        assert (
+            reduce(lambda x, y: x * y, u.shape, 1) == channels
+        ), "the elements of u (time first) is not equal to hidden_size"
         num_heads = channels // head_size
         r_dtype = dtypes.canonicalize_dtype(r.dtype)
         k_dtype = dtypes.canonicalize_dtype(k.dtype)
@@ -558,9 +558,9 @@ class RWKVKernelOperator:
         if s is not None:
             s_dtype = dtypes.canonicalize_dtype(s.dtype)
             assert s_dtype == output_dtype
-            assert
-
-            )
+            assert (
+                s.shape == state_shape
+            ), "the shape of init_state must be (batch_size, seq_len, num_heads, head_size, head_size)"

         return [
             ShapedArray(
@@ -586,22 +586,22 @@ class RWKVKernelOperator:
     def _load_or_build_kernel(head_size, max_sequence_length):
         assert head_size % 4 == 0, f"head size必须是4的倍数,而{head_size}显然不是."
         assert isinstance(head_size, int), "你是在搞笑吗? head_size肯定得是int类型的啊"
-        assert isinstance(
-
-        )
-        assert
-
-        )
-        assert
-
-        )
+        assert isinstance(
+            max_sequence_length, int
+        ), "你是在搞笑吗? max_sequence_length肯定得是int类型的啊"
+        assert (
+            head_size > 0 and max_sequence_length > 0
+        ), "难绷,head_size与max_sequence_length肯定得是大于0的正整数啊。"
+        assert (
+            os.path.exists(cuda_lib_dir) and len(os.listdir(cuda_lib_dir)) > 0
+        ), f"请检查{cuda_lib_dir}文件夹是否存在,这个文件本质是是您的cuda library的超链接。"
         kernel_dir = os.path.abspath(
             os.path.join(os.path.dirname(__file__), kernel_dir_name)
         )
         builds_dir = os.path.join(kernel_dir, "builds")
-        assert os.path.exists(
-
-        )
+        assert os.path.exists(
+            kernel_dir
+        ), f"找不到{kernel_dir_name}文件夹,请问您的文件是完整的吗?"
         if not os.path.exists(builds_dir):
             os.mkdir(builds_dir)
         target_dir_name = f"_N_{head_size}_T_{max_sequence_length}"
```
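All of the hunks above are the same mechanical change: each shape or type `assert` is rewritten in the parenthesized multi-line style with its message attached to the closing line; no check changes meaning. The most frequent check verifies that `u` (time first) holds exactly `hidden_size` elements regardless of rank. A minimal sketch of that check, with illustrative names not taken from the package:

```python
from functools import reduce

def check_u_elements(u_shape, hidden_size):
    # Total element count of u must equal hidden_size, whatever its rank.
    assert (
        reduce(lambda x, y: x * y, u_shape, 1) == hidden_size
    ), "the elements of u (time first) is not equal to hidden_size"

check_u_elements((16, 64), 1024)  # passes: 16 * 64 == 1024
```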
rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py
CHANGED

```diff
@@ -55,9 +55,9 @@ class RWKVKernelOperator:
             if isinstance(state_map, list):
                 state_map = ops.convert_to_tensor(state_map, dtype="int32")
             state_map = ops.cast(state_map, "int32")
-            assert (state_map >= 0).all() and (
-
-            )
+            assert (state_map >= 0).all() and (
+                state_map < state_kinds
+            ).all(), f"请确保state_map的值域为[0, {state_kinds})"
             s = ops.take(init_state, state_map, axis=0)

         else:
```
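The bounds check above protects the `ops.take` gather that follows it: `state_map[i]` selects which row of `init_state` batch element `i` starts from, so every entry must lie in `[0, state_kinds)`. A hedged keras.ops sketch with made-up shapes:

```python
import numpy as np
from keras import ops

state_kinds = 2  # two distinct initial states, shared across a batch of 4
init_state = ops.convert_to_tensor(
    np.zeros((state_kinds, 8, 64, 64), dtype="float32")
)
state_map = ops.convert_to_tensor([0, 0, 1, 1], dtype="int32")
# Each batch row i starts from init_state[state_map[i]].
s = ops.take(init_state, state_map, axis=0)  # shape (4, 8, 64, 64)
```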
rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py
CHANGED

```diff
@@ -154,9 +154,9 @@ class RWKVKernelOperator:
     @staticmethod
     def apply(B, T, C, H, S, s_map, r, k, v, w, u, s):
         with torch.no_grad():
-            assert
-
-            )
+            assert (
+                s_map.dtype == torch.int64
+            ), "s_map 必须为None 或者是长度为B的,int64类型的数组。"
             assert (s is None and s_map is None) or (
                 s is not None and s_map is not None
             ), "init_state与s_map必须同时为None 或者同时不为None"
@@ -240,17 +240,15 @@ class RWKVKernelOperator:
         is_custom_init = init_state is not None

         if init_state is not None:
-            assert
-
-            )
+            assert (
+                len(init_state.shape) in [3, 4]
+            ), "init_state 的形状必须为(state_kinds /*<= Batch_size*/,num_heads,head_size,head_size) 或者(num_heads,head_size,head_size)"
             if len(init_state.shape) == 3:
                 init_state = init_state[None, :]
             assert (
                 init_state.shape[1:] == (H, self.head_size, self.head_size)
                 and init_state.shape[0] <= B
-            ), (
-                "init_state 的形状必须为(state_kinds /*<= Batch_size*/,num_heads,head_size,head_size) 或者(num_heads,head_size,head_size)"
-            )
+            ), "init_state 的形状必须为(state_kinds /*<= Batch_size*/,num_heads,head_size,head_size) 或者(num_heads,head_size,head_size)"

             assert init_state.dtype == s_dtype, f"init_state的数值类型应为: {s_dtype}"
             assert init_state.device == r.device
@@ -269,17 +267,17 @@ class RWKVKernelOperator:

         if with_state:
             if init_state is None:
-                assert
-
-                )
+                assert (
+                    state_map is None
+                ), "您必须在指定了init_state的情况下才能使用state_map"
                 init_state = torch.zeros((0,), device=r.device, dtype=s_dtype)
                 state_map = torch.zeros((0,), device=r.device, dtype=torch.int64)
             else:
                 n_state = init_state.shape[0]
                 if state_map is None:
-                    assert
-
-                    )
+                    assert (
+                        n_state == 1 or n_state == B
+                    ), "我无法为您推断state_map的形状,请手动指定。"
                     if n_state == 1:
                         state_map = torch.tensor(
                             [0] * B, dtype=torch.int64, device=r.device
@@ -292,9 +290,9 @@ class RWKVKernelOperator:
                     assert False, "未实现"
                 else:
                     assert state_map.shape == (B,), "state_map的形状必须为(batch_size,)"
-                    assert (
-
-                    )
+                    assert (
+                        (state_map >= 0).all() and (state_map < n_state).all()
+                    ), f"state_map的取值范围为[0,{n_state})之间的整数,您的输入显然不满足。"
             # print('state map:',state_map)
             o, ys = self.kernel_with_state.apply(
                 B, T, C, H, is_custom_init, state_map, r, k, v, w, u, init_state
```
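The reworked asserts document the only two cases where the PyTorch path can infer `state_map` on its own: a single shared state, or exactly one state per batch row. A sketch of that inference with illustrative values:

```python
import torch

B = 4
n_state = 1  # number of initial states actually provided

if n_state == 1:
    # every batch row reads state 0
    state_map = torch.tensor([0] * B, dtype=torch.int64)
elif n_state == B:
    # one state per row: the identity mapping
    state_map = torch.arange(B, dtype=torch.int64)
else:
    raise ValueError("cannot infer state_map; pass it explicitly")
```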
rwkv_ops/rwkv7_kernel/__init__.py
CHANGED

```diff
@@ -15,9 +15,11 @@ def get_generalized_delta_rule(HEAD_SIZE=64, KERNEL_TYPE="native"):
     USE_TRITON_KERNEL = False
     if keras.config.backend() == "torch":
         import torch
+
         if not torch.cuda.is_available():
             from .native_keras_op import generalized_delta_rule
-
+
+            return generalized_delta_rule, False

         if KERNEL_TYPE.lower() == "triton":
             from .torch_op import generalized_delta_rule
@@ -157,8 +159,8 @@ def get_generalized_delta_rule(HEAD_SIZE=64, KERNEL_TYPE="native"):
     elif keras.config.backend() == "jax":
         import jax
         import os
-
-        if jax.devices()[0].platform
+
+        if jax.devices()[0].platform == "gpu":
             if KERNEL_TYPE.lower() == "triton":
                 os.environ["JAX_LOG_COMPUTATION"] = "0"
                 from .jax_op import generalized_delta_rule
@@ -196,6 +198,8 @@ def get_generalized_delta_rule(HEAD_SIZE=64, KERNEL_TYPE="native"):
                 from .native_keras_op import generalized_delta_rule
         else:
             from .native_keras_op import generalized_delta_rule
+    elif keras.config.backend() == "mlx" and KERNEL_TYPE.lower() == "cuda":
+        from .mlx_op import generalized_delta_rule
     else:
         from .native_keras_op import generalized_delta_rule
     return generalized_delta_rule, USE_TRITON_KERNEL
```
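The new `elif` branch routes the `mlx` Keras backend to the Metal kernel in `mlx_op.py` when `KERNEL_TYPE="cuda"` is requested; any other kernel type still falls through to the native op. A hedged usage sketch (the import path is assumed from the package layout in the RECORD below):

```python
import keras  # keras.config.backend() must report "mlx" for this branch
from rwkv_ops.rwkv7_kernel import get_generalized_delta_rule

kernel, use_triton = get_generalized_delta_rule(HEAD_SIZE=64, KERNEL_TYPE="cuda")
# use_triton stays False on this path; kernel is mlx_op.generalized_delta_rule
```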
rwkv_ops/rwkv7_kernel/get_torch_devices_info.py
CHANGED

```diff
@@ -74,9 +74,9 @@ if check_pytorch_version("2.4"):
     def custom_device_ctx(index: int):
         return device_torch_lib.device(index)
 else:
-    assert
-
-    )
+    assert (
+        device == "cuda"
+    ), "Only cuda device is supported for PyTorch version < 2.4.0."
     autocast_custom_fwd = device_torch_lib.amp.custom_fwd
     autocast_custom_bwd = device_torch_lib.amp.custom_bwd

```
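The tightened assert reflects that below PyTorch 2.4 only the CUDA-specific AMP decorators exist, so `device` must be `"cuda"` there. A hedged sketch of the same gate written against stock PyTorch (the 2.4 cutoff mirrors the package's `check_pytorch_version("2.4")`; `torch.amp.custom_fwd(device_type=...)` is the device-agnostic form newer code uses):

```python
import torch

major, minor = (int(x) for x in torch.__version__.split(".")[:2])
if (major, minor) >= (2, 4):
    # device-agnostic decorators, available in torch >= 2.4
    autocast_custom_fwd = torch.amp.custom_fwd(device_type="cuda")
    autocast_custom_bwd = torch.amp.custom_bwd(device_type="cuda")
else:
    # older releases only ship the CUDA-specific variants
    autocast_custom_fwd = torch.cuda.amp.custom_fwd
    autocast_custom_bwd = torch.cuda.amp.custom_bwd
```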
rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_bwd.py
CHANGED

```diff
@@ -25,9 +25,9 @@ def chunk_dplr_bwd_dhu(
     B, T, H, K, V = *qg.shape, do.shape[-1]
     BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
     BK = triton.next_power_of_2(K)
-    assert
-
-    )
+    assert (
+        BK <= 256
+    ), "current kernel does not support head dimension being larger than 256."
     # H100
     if check_shared_mem("hopper"):
         BV = 64
@@ -42,9 +42,9 @@ def chunk_dplr_bwd_dhu(
     N, NT = B, triton.cdiv(T, BT)
     BC = min(BT, BC)
     NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
-    assert
-
-    )
+    assert (
+        NK == 1
+    ), "NK > 1 is not supported because it involves time-consuming synchronization"
     dh_shape = (B, NT, H, K, V)
     out_shapes = [
         jax.ShapeDtypeStruct(dh_shape, dv.dtype),
```
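For reference, the two asserts bound the Triton block decomposition: `BK` is `K` rounded up to a power of two, so a head dimension above 256 trips the first check, and `NK = cdiv(K, BK)` counts the key blocks the kernel would otherwise have to synchronize across. A small worked example with illustrative values:

```python
import triton

K = 96                           # head dimension
BK = triton.next_power_of_2(K)   # -> 128
NK = triton.cdiv(K, BK)          # -> ceil(96 / 128) == 1
assert BK <= 256
assert NK == 1
```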
rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_fwd.py
CHANGED

```diff
@@ -43,9 +43,9 @@ def chunk_dplr_fwd_h(
     BC = min(BT, BC)
     NK = triton.cdiv(K, BK)
     NV = triton.cdiv(V, BV)
-    assert
-
-    )
+    assert (
+        NK == 1
+    ), "NK > 1 is not supported because it involves time-consuming synchronization"

     out_shapes = [
         jax.ShapeDtypeStruct((B, NT, H, K, V), kg.dtype),
```
rwkv_ops/rwkv7_kernel/mlx_op.py
ADDED

```diff
@@ -0,0 +1,132 @@
+# copy from https://github.com/ml-explore/mlx-lm/pull/580
+from dataclasses import dataclass
+from functools import partial
+from typing import Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+
+
+@partial(mx.compile, shapeless=True)
+def addcmul(x, y, z):
+    return x + y * z
+
+
+@partial(mx.compile, shapeless=True)
+def l2_norm(x):
+    return x / mx.maximum(mx.linalg.norm(x, axis=-1, keepdims=True), 1e-7)
+
+
+def _make_wkv7_kernel():
+    if not mx.metal.is_available():
+        return None
+    source = f"""
+        auto n = thread_position_in_grid.z;
+        auto b_idx = n / H;
+        auto h_idx = n % H;
+        constexpr int n_per_t = D / 32;
+        // [B, T, H, D]
+        auto r_ = r + b_idx * T * H * D + h_idx * D;
+        auto w_ = w + b_idx * T * H * D + h_idx * D;
+        auto k_ = k + b_idx * T * H * D + h_idx * D;
+        auto v_ = v + b_idx * T * H * D + h_idx * D;
+        auto a_ = a + b_idx * T * H * D + h_idx * D;
+        auto b_ = b + b_idx * T * H * D + h_idx * D;
+        y += b_idx * T * H * D + h_idx * D;
+        auto dk_idx = thread_position_in_threadgroup.x;
+        auto dv_idx = thread_position_in_grid.y;
+        // state_in, state_out: [B, H, D, D]
+        auto i_state = state_in + (n * D + dv_idx) * D;
+        auto o_state = state_out + (n * D + dv_idx) * D;
+        float state[n_per_t];
+        for (int i = 0; i < n_per_t; ++i) {{
+            auto s_idx = n_per_t * dk_idx + i;
+            state[i] = static_cast<float>(i_state[s_idx]);
+        }}
+        for (int t = 0; t < T; ++t) {{
+            float sa = 0.0f;
+            for (int i = 0; i < n_per_t; ++i) {{
+                auto s_idx = n_per_t * dk_idx + i;
+                sa += state[i] * a_[s_idx];
+                state[i] = state[i] * w_[s_idx];
+            }}
+            sa = simd_sum(sa);
+            float out = 0.0f;
+            for (int i = 0; i < n_per_t; ++i) {{
+                auto s_idx = n_per_t * dk_idx + i;
+                state[i] = state[i] + k_[s_idx] * v_[dv_idx] + sa * b_[s_idx];
+                out += state[i] * r_[s_idx];
+            }}
+            out = simd_sum(out);
+            if (thread_index_in_simdgroup == 0) {{
+                y[dv_idx] = static_cast<InT>(out);
+            }}
+            // Increment data pointers to next time step
+            r_ += H * D;
+            w_ += H * D;
+            k_ += H * D;
+            v_ += H * D;
+            a_ += H * D;
+            b_ += H * D;
+            y += H * D;
+        }}
+        for (int i = 0; i < n_per_t; ++i) {{
+            auto s_idx = n_per_t * dk_idx + i;
+            o_state[s_idx] = static_cast<InT>(state[i]);
+        }}
+    """
+    inputs = ["r", "w", "k", "v", "a", "b", "state_in", "T"]
+    return mx.fast.metal_kernel(
+        name="wkv7_kernel",
+        input_names=inputs,
+        output_names=["y", "state_out"],
+        source=source,
+    )
+
+
+_wkv7_kernel = _make_wkv7_kernel()
+
+
+def transpose_head(x, head_first: bool = True):
+    if head_first:
+        return mx.transpose(x, (0, 2, 1, 3))
+    return x
+
+
+def generalized_delta_rule(
+    r,
+    w,
+    k,
+    v,
+    a,
+    b,
+    initial_state=None,
+    output_final_state: bool = True,
+    head_first: bool = False,
+):
+    state = initial_state
+
+    r = transpose_head(r, head_first)
+    k = transpose_head(k, head_first)
+    v = transpose_head(v, head_first)
+    a = transpose_head(a, head_first)
+    b = transpose_head(b, head_first)
+
+    B, T, H, D = r.shape
+    input_dtype = r.dtype
+
+    y, out_state = _wkv7_kernel(
+        inputs=[r, w, k, v, a, b, state, T],
+        template=[
+            ("InT", input_dtype),
+            ("H", H),
+            ("D", D),
+        ],
+        grid=(32, D, B * H),
+        threadgroup=(32, 4, 1),
+        output_shapes=[(B, T, H, D), state.shape],
+        output_dtypes=[input_dtype, input_dtype],
+    )
+    if output_final_state:
+        return y, out_state
+    return y
```
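A hedged usage sketch for the new forward-only MLX op. Everything below is illustrative: the kernel requires Metal (`_make_wkv7_kernel` returns `None` otherwise), `D` should be a multiple of 32 since the Metal source computes `n_per_t = D / 32`, and an `initial_state` array must be supplied because the kernel call reads `state.shape` unconditionally:

```python
import mlx.core as mx
from rwkv_ops.rwkv7_kernel.mlx_op import generalized_delta_rule

B, T, H, D = 1, 8, 2, 64          # D must be a multiple of 32
r, w, k, v, a, b = (mx.random.normal((B, T, H, D)) for _ in range(6))
state = mx.zeros((B, H, D, D))    # required: the kernel reads state.shape

y, final_state = generalized_delta_rule(r, w, k, v, a, b, initial_state=state)
print(y.shape, final_state.shape)  # (1, 8, 2, 64) (1, 2, 64, 64)
```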
rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_bwd.py
CHANGED

```diff
@@ -24,9 +24,9 @@ def chunk_dplr_bwd_dhu(
     B, T, H, K, V = *qg.shape, do.shape[-1]
     BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
     BK = triton.next_power_of_2(K)
-    assert
-
-    )
+    assert (
+        BK <= 256
+    ), "current kernel does not support head dimension being larger than 256."
     # H100
     if check_shared_mem("hopper", qg.device.index):
         BV = 64
@@ -42,9 +42,9 @@ def chunk_dplr_bwd_dhu(

     BC = min(BT, BC)
     NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
-    assert
-
-    )
+    assert (
+        NK == 1
+    ), "NK > 1 is not supported because it involves time-consuming synchronization"

     dh = qg.new_empty(B, NT, H, K, V)
     dh0 = torch.empty_like(h0, dtype=torch.float32) if h0 is not None else None
```
rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_fwd.py
CHANGED

```diff
@@ -42,9 +42,9 @@ def chunk_dplr_fwd_h(
     BC = min(BT, BC)
     NK = triton.cdiv(K, BK)
     NV = triton.cdiv(V, BV)
-    assert
-
-    )
+    assert (
+        NK == 1
+    ), "NK > 1 is not supported because it involves time-consuming synchronization"

     h = kg.new_empty(B, NT, H, K, V)
     final_state = (
```
{rwkv_ops-0.3.2.dist-info → rwkv_ops-0.3.3.dist-info}/METADATA
CHANGED

```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: rwkv-ops
-Version: 0.3.2
+Version: 0.3.3
 Summary: RWKV operators for multiple backends (PyTorch, JAX, Keras)
 Project-URL: Homepage, https://github.com/pass-lin/rwkv_ops
 Author-email: pass-lin <qw_lin@qq.com>
@@ -125,12 +125,14 @@ if padding_mask is not None:
 | JAX | ✅ | ✅ | ✅ |
 | TensorFlow | ⚠️ | ❌ | ✅ |
 | NumPy | ❌ | ❌ | ✅ |
+| MLX | ⚠️ | ❌ | ❌ |

 ---
 > `native` 为原生算子,无 chunkwise,速度慢且显存高。
 > `triton` 使用的是chunkwise算法实现,速度快,并行度高,缺点是精度很差,介意勿用
 > `cuda` 为基于 CUDA 的原生算子,速度很快,并且kernel内部使用fp32实现,所以精度也很高。缺点就是长序列的时候比较吃亏跑不满。
 > tensorflow的CUDA实现只支持前向计算,是没有梯度的。并且这个是使用jax的cuda实现实现的,你需要保证你能够成功运行jax的cuda kernel。
+> 因为MLX还没合并到keras,所以原生算子暂不支持。但是我们提供了一个前向的算子。

 ## rwkv6op 使用方法

```
{rwkv_ops-0.3.2.dist-info → rwkv_ops-0.3.3.dist-info}/RECORD
CHANGED

```diff
@@ -1,8 +1,8 @@
-rwkv_ops/__init__.py,sha256=
+rwkv_ops/__init__.py,sha256=voJa6h1nvEua5iD2H-mxsZR3j6s1at8xrqnXj4Q-WYQ,855
 rwkv_ops/rwkv6_kernel/__init__.py,sha256=ktIzkK6EUc2nonLQnl2NAjJj9kMt02i9zqfjFcnM_NQ,3647
-rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py,sha256=
-rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py,sha256=
-rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py,sha256=
+rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py,sha256=DTYtT2v2WjOlndD-ESNmbmVj1ili03KOXOZek1V4DLw,29954
+rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py,sha256=XBsYAkbsyUCsxVPR-RfrboQkyM8TIzW1saOsh6vEjcM,3352
+rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py,sha256=tq5H-ndm9nq2nWTQYI-xpw5YK0j0E31yF552muSFHN4,12883
 rwkv_ops/rwkv6_kernel/jax_kernel_cuda/gpu_ops.cpp,sha256=oM13TCQi2GMIf3f-Z39WOL8M_8GmGI_Kdhiq3Y2keJw,1643
 rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernel_helpers.h,sha256=epwsW8OUIOvrlNuW3BAmAbgB8n8CKOFEYafBxQy3ptw,2209
 rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernels.h,sha256=KYJiWmmig0Wh-zpiWV96J_be8jlyc38Ztd1iqNoqVFI,1501
@@ -15,21 +15,22 @@ rwkv_ops/rwkv6_kernel/jax_kernel_hip/pybind11_kernel_helpers.h,sha256=CMQclcyHaD
 rwkv_ops/rwkv6_kernel/jax_kernel_hip/rwkv_kernels.hip,sha256=givSxPA7YfKGz75rOtN8TAjTxWWraVNgTGPZfAJsZsQ,20836
 rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_cuda.cu,sha256=tfRbMQBkl_LT7EVaJ6KoWYcQ902ApCrS6zkjXldFZXY,12770
 rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_op.cpp,sha256=cyCTiF--4SQiDJu7Dy_NuEhSe1vyki6JS4I2rsvT714,6659
-rwkv_ops/rwkv7_kernel/__init__.py,sha256=
+rwkv_ops/rwkv7_kernel/__init__.py,sha256=OJq9ZU1GPP5Si8LY66miTFFotxyGlBtPyUp49Cedl8k,8250
 rwkv_ops/rwkv7_kernel/get_jax_devices_info.py,sha256=cMIaNED7d1PvYNSyq8wNI3G7wNvcgdUj9HWRBLuSVM8,6004
-rwkv_ops/rwkv7_kernel/get_torch_devices_info.py,sha256=
+rwkv_ops/rwkv7_kernel/get_torch_devices_info.py,sha256=BR6IqwcBDKjLf-uRCh0LAzYtRl4KP43JO5fnd9jsr2c,7380
 rwkv_ops/rwkv7_kernel/jax_op.py,sha256=C7jOvJ-ZWTFfCZBQNzMbqgoVHuDS2QCGlBsGEMM4Fn0,9140
+rwkv_ops/rwkv7_kernel/mlx_op.py,sha256=Ss9i_1TdGPNnc1YpD7QBSdK1_sTQ2R5l8Mk5UIcWQJ0,3795
 rwkv_ops/rwkv7_kernel/native_keras_op.py,sha256=dCWdzuVZxAKHCBURZqgOLN3n_yKFFNX5uORlbvztH6w,2502
 rwkv_ops/rwkv7_kernel/tf_eager_kernel.py,sha256=2t2uf1iNznYpYFlqt9REY0GwGeycYuaJl-4QFk2rJHc,4357
 rwkv_ops/rwkv7_kernel/torch_op.py,sha256=jw_AvqshTAG4t9-MRqxFQNi_bTzxNbx3lwnMifPk8-8,14070
 rwkv_ops/rwkv7_kernel/jax_cuda_kernel/CMakeLists.txt,sha256=Dq4Ea8N2xOEej2jZpEw4MtFjUFgN0PUciejVOCSP-FM,1400
 rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_ffi.cu,sha256=WePveEdUixaQA51hJUK8Sr7Q7jDTstybEWZczdjuGSo,9690
-rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_jax.py,sha256=
+rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_jax.py,sha256=t2mQ_zGhMeBZClcaLJSwRG4n2MtRhvn9z-vHWK79F6w,7999
 rwkv_ops/rwkv7_kernel/jax_kernel/__init__.py,sha256=uHsf_1qrtRK62IvhLuzefHGPWpHXmw1p0tqmwlHcptk,346
 rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_bwd.py,sha256=2Voq1Bdzn0DFloiLvwINBk7akmxRWIqXIQeyafrJJGg,2138
 rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_fwd.py,sha256=rhmglqHIIww7yPzaSBEp9ISxhhxoUbMtV51AUDyhUd8,1425
-rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_bwd.py,sha256=
-rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_fwd.py,sha256=
+rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_bwd.py,sha256=U06dcacmND-y022mN4UmDunfRDxJYWthU_4V8z0HcSs,1888
+rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_fwd.py,sha256=Nl0migPjRmQopIsysSqt7ZMQ_X-vyblb7e2t-xghzlA,1992
 rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_bwd.py,sha256=gQnToi1e1GZCvjWsEdWx6WakUN4Lc0JfaBSsSXYdN84,3369
 rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_fwd.py,sha256=4SjQ_zTZvFxsBMeWOx0JGFg9EQ4vllvEx30EcvSZJzI,853
 rwkv_ops/rwkv7_kernel/jax_kernel/cumsum.py,sha256=NoOh2_hA_rdH5bmaNNMAdCgVPfWvQpf-Q8BqF926jrw,667
@@ -40,8 +41,8 @@ rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_op.cpp,sha256=Wk5QYvIM9m-YJdSEh6zSz
 rwkv_ops/rwkv7_kernel/torch_kernel/__init__.py,sha256=_u1srIATeoHKlVTVWbWXdpkjaggugl9y-Kx_Y4pYdIY,430
 rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_bwd.py,sha256=CWtotXkVvHz4-rkuOqWh6zKy95jwimS9If6SU45ylW0,2103
 rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_fwd.py,sha256=4RJbyUTO23OxwH1rGVxeBiBVZKNHpPL_tJ7MFoDCIts,1475
-rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_bwd.py,sha256=
-rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_fwd.py,sha256=
+rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_bwd.py,sha256=AdbgPd0JRfPZ_poK_XAQ5iV1GsBqDehiN0lf_-_CbUw,1795
+rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_fwd.py,sha256=o_EbLxqqnzW8_aNduqv_Brd_-SlUU3szfi8Lfn40rqc,1858
 rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_bwd.py,sha256=ioPrS0NYQhpFk1j8rAxqtbwpx1CwjJQnrJEBDqVy-As,3283
 rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_fwd.py,sha256=54yoa3NpV64H-koURt-hUWpFHhUjwXpGvXPp2_ETCnw,825
 rwkv_ops/rwkv7_kernel/torch_kernel/cumsum.py,sha256=hQkpyaa0eUyB4V3UVks7l1_dHwOrbump0FZILityBKw,611
@@ -58,7 +59,7 @@ rwkv_ops/rwkv7_kernel/triton_kernel/cumsum.py,sha256=pRp_z587PrnpgRVpi031IndyjVI
 rwkv_ops/rwkv7_kernel/triton_kernel/utils.py,sha256=TNGlkwGq4t-TOcdVBk_N_vHPLzMFTu_F0V-O1RprIO4,553
 rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_bwd.py,sha256=szaG11q_WmpyhXi6aVWwzizvflCh5wND8wGA_V8afzA,5479
 rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_fwd.py,sha256=jbb19DUTHENU2RIOv_T4m_W1eXMqdRqG0XevIkBOhI4,9438
-rwkv_ops-0.3.2.dist-info/METADATA,sha256=
-rwkv_ops-0.3.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-rwkv_ops-0.3.2.dist-info/licenses/LICENSE.txt,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
-rwkv_ops-0.3.2.dist-info/RECORD,,
+rwkv_ops-0.3.3.dist-info/METADATA,sha256=EU1tq3Ub9WqVpcNYRa1T_I9H2Nx1nNypX8fZWdu7bsM,9011
+rwkv_ops-0.3.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+rwkv_ops-0.3.3.dist-info/licenses/LICENSE.txt,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
+rwkv_ops-0.3.3.dist-info/RECORD,,
```
{rwkv_ops-0.3.2.dist-info → rwkv_ops-0.3.3.dist-info}/WHEEL
File without changes

{rwkv_ops-0.3.2.dist-info → rwkv_ops-0.3.3.dist-info}/licenses/LICENSE.txt
File without changes