rwkv-ops 0.3.2__py3-none-any.whl → 0.3.3__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.

Potentially problematic release: this version of rwkv-ops has been flagged as potentially problematic.

rwkv_ops/__init__.py CHANGED
@@ -1,4 +1,4 @@
- __version__ = "0.3.2"
+ __version__ = "0.3.3"
  import os

  KERNEL_TYPE = os.environ.get("KERNEL_TYPE", "cuda").lower()
@@ -103,9 +103,9 @@ class RWKVKernelOperator:
  bz, seq_len, hd_sz = r_type.shape

  assert hd_sz % head_size == 0
- assert reduce(lambda x, y: x * y, u_type.shape, 1) == hd_sz, (
- "the elements of u (time first) is not equal to hidden_size"
- )
+ assert (
+ reduce(lambda x, y: x * y, u_type.shape, 1) == hd_sz
+ ), "the elements of u (time first) is not equal to hidden_size"
  input_type = r_type.element_type

  if input_type in [ir.F32Type.get(), ir.BF16Type.get()]:
@@ -159,9 +159,9 @@ class RWKVKernelOperator:
  bz, seq_len, channels = r.shape
  assert channels % head_size == 0
  assert seq_len <= max_sequence_length
- assert reduce(lambda x, y: x * y, u.shape, 1) == channels, (
- "the elements of u (time first) is not equal to hidden_size"
- )
+ assert (
+ reduce(lambda x, y: x * y, u.shape, 1) == channels
+ ), "the elements of u (time first) is not equal to hidden_size"

  r_dtype = dtypes.canonicalize_dtype(r.dtype)
  k_dtype = dtypes.canonicalize_dtype(k.dtype)
@@ -237,9 +237,9 @@ class RWKVKernelOperator:
  bz, seq_len, hd_sz = r_type.shape

  assert hd_sz % head_size == 0
- assert reduce(lambda x, y: x * y, u_type.shape, 1) == hd_sz, (
- "the elements of u (time first) is not equal to hidden_size"
- )
+ assert (
+ reduce(lambda x, y: x * y, u_type.shape, 1) == hd_sz
+ ), "the elements of u (time first) is not equal to hidden_size"
  input_type = r_type.element_type

  if input_type in [ir.F32Type.get(), ir.BF16Type.get()]:
@@ -304,9 +304,9 @@ class RWKVKernelOperator:
  bz, seq_len, channels = r.shape
  assert channels % head_size == 0
  assert seq_len <= max_sequence_length
- assert reduce(lambda x, y: x * y, u.shape, 1) == channels, (
- "the elements of u (time first) is not equal to hidden_size"
- )
+ assert (
+ reduce(lambda x, y: x * y, u.shape, 1) == channels
+ ), "the elements of u (time first) is not equal to hidden_size"

  r_dtype = dtypes.canonicalize_dtype(r.dtype)
  k_dtype = dtypes.canonicalize_dtype(k.dtype)
@@ -374,9 +374,9 @@ class RWKVKernelOperator:
  n_state = jnp.shape(init_state)[0]
  B = jnp.shape(r)[0]
  # print('ns:',n_state,'B:',B,r.shape,k.shape,v.shape)
- assert n_state == 1 or n_state == B, (
- "我无法为您推断state_map的形状,请手动指定。"
- )
+ assert (
+ n_state == 1 or n_state == B
+ ), "我无法为您推断state_map的形状,请手动指定。"
  if n_state == 1:
  state_map = jnp.array([0] * B, dtype=jnp.int32)
  elif n_state == B:
@@ -392,12 +392,12 @@ class RWKVKernelOperator:
  jnp.int32,
  ], "state_map的数值类型必须为int32"
  state_map = jnp.astype(state_map, jnp.int32)
- assert jnp.all(state_map >= 0) and jnp.add(state_map < bz), (
- f"state_map内为state的映射下标,因此范围为: [0,{bz})"
- )
- assert (init_state is None) == (state_map is None), (
- "init_state与state_map必须同时传入"
- )
+ assert jnp.all(state_map >= 0) and jnp.add(
+ state_map < bz
+ ), f"state_map内为state的映射下标,因此范围为: [0,{bz})"
+ assert (init_state is None) == (
+ state_map is None
+ ), "init_state与state_map必须同时传入"

  if init_state is None:
  y, s = _rwkv_fwd_state_p.bind(r, k, v, w, u)
@@ -440,9 +440,9 @@ class RWKVKernelOperator:

  assert hd_sz % head_size == 0
  num_heads = hd_sz // head_size
- assert reduce(lambda x, y: x * y, u_type.shape, 1) == hd_sz, (
- "the elements of u (time first) is not equal to hidden_size"
- )
+ assert (
+ reduce(lambda x, y: x * y, u_type.shape, 1) == hd_sz
+ ), "the elements of u (time first) is not equal to hidden_size"
  input_type = r_type.element_type

  if input_type in [ir.F32Type.get(), ir.BF16Type.get()]:
@@ -452,25 +452,25 @@ class RWKVKernelOperator:
  state_shape = (bz, num_heads, head_size, head_size)

  if with_init_state:
- assert s_map is not None, (
- "您必须同时传入init_state与state_map 或者都赋值为None."
- )
+ assert (
+ s_map is not None
+ ), "您必须同时传入init_state与state_map 或者都赋值为None."

  s_type = ir.RankedTensorType(s.type)
  sm_type = ir.RankedTensorType(s_map.type)
  # print(sm_type, ir.IntegerType.get_signless(64))
- assert sm_type.element_type == ir.IntegerType.get_signless(32), (
- "state_map的数据类型必须为int32"
- )
+ assert sm_type.element_type == ir.IntegerType.get_signless(
+ 32
+ ), "state_map的数据类型必须为int32"
  # print(sm_type.shape,bz)
- assert tuple(sm_type.shape) == (bz,), (
- "state_map的shape 形状必须为(batch_size,)"
- )
+ assert tuple(sm_type.shape) == (
+ bz,
+ ), "state_map的shape 形状必须为(batch_size,)"

  assert s_type.element_type == output_type
- assert tuple(s_type.shape) == state_shape, (
- "the shape of init state must be (batch_size,num_heads,head_size,head_size)"
- )
+ assert (
+ tuple(s_type.shape) == state_shape
+ ), "the shape of init state must be (batch_size,num_heads,head_size,head_size)"
  # assert s_type.shape[0] == bz and reduce(lambda x,y: x * y, s_type.shape[1:],1) == head_size * hd_sz,"the shape of init state must be (batch_size,num_heads,head_size,head_size)"

  opaque = rwkv_kernel.create_rwkv_descriptor(
@@ -535,9 +535,9 @@ class RWKVKernelOperator:
  bz, seq_len, channels = r.shape
  assert channels % head_size == 0
  assert seq_len <= max_sequence_length
- assert reduce(lambda x, y: x * y, u.shape, 1) == channels, (
- "the elements of u (time first) is not equal to hidden_size"
- )
+ assert (
+ reduce(lambda x, y: x * y, u.shape, 1) == channels
+ ), "the elements of u (time first) is not equal to hidden_size"
  num_heads = channels // head_size
  r_dtype = dtypes.canonicalize_dtype(r.dtype)
  k_dtype = dtypes.canonicalize_dtype(k.dtype)
@@ -558,9 +558,9 @@ class RWKVKernelOperator:
  if s is not None:
  s_dtype = dtypes.canonicalize_dtype(s.dtype)
  assert s_dtype == output_dtype
- assert s.shape == state_shape, (
- "the shape of init_state must be (batch_size, seq_len, num_heads, head_size, head_size)"
- )
+ assert (
+ s.shape == state_shape
+ ), "the shape of init_state must be (batch_size, seq_len, num_heads, head_size, head_size)"

  return [
  ShapedArray(
@@ -586,22 +586,22 @@ class RWKVKernelOperator:
  def _load_or_build_kernel(head_size, max_sequence_length):
  assert head_size % 4 == 0, f"head size必须是4的倍数,而{head_size}显然不是."
  assert isinstance(head_size, int), "你是在搞笑吗? head_size肯定得是int类型的啊"
- assert isinstance(max_sequence_length, int), (
- "你是在搞笑吗? max_sequence_length肯定得是int类型的啊"
- )
- assert head_size > 0 and max_sequence_length > 0, (
- "难绷,head_sizemax_sequence_length肯定得是大于0的正整数啊。"
- )
- assert os.path.exists(cuda_lib_dir) and len(os.listdir(cuda_lib_dir)) > 0, (
- f"请检查{cuda_lib_dir}文件夹是否存在,这个文件本质是是您的cuda library的超链接。"
- )
+ assert isinstance(
+ max_sequence_length, int
+ ), "你是在搞笑吗? max_sequence_length肯定得是int类型的啊"
+ assert (
+ head_size > 0 and max_sequence_length > 0
+ ), "难绷,head_size与max_sequence_length肯定得是大于0的正整数啊。"
+ assert (
+ os.path.exists(cuda_lib_dir) and len(os.listdir(cuda_lib_dir)) > 0
+ ), f"请检查{cuda_lib_dir}文件夹是否存在,这个文件本质是是您的cuda library的超链接。"
  kernel_dir = os.path.abspath(
  os.path.join(os.path.dirname(__file__), kernel_dir_name)
  )
  builds_dir = os.path.join(kernel_dir, "builds")
- assert os.path.exists(kernel_dir), (
- f"找不到{kernel_dir_name}文件夹,请问您的文件是完整的吗?"
- )
+ assert os.path.exists(
+ kernel_dir
+ ), f"找不到{kernel_dir_name}文件夹,请问您的文件是完整的吗?"
  if not os.path.exists(builds_dir):
  os.mkdir(builds_dir)
  target_dir_name = f"_N_{head_size}_T_{max_sequence_length}"
@@ -55,9 +55,9 @@ class RWKVKernelOperator:
  if isinstance(state_map, list):
  state_map = ops.convert_to_tensor(state_map, dtype="int32")
  state_map = ops.cast(state_map, "int32")
- assert (state_map >= 0).all() and (state_map < state_kinds).all(), (
- f"请确保state_map的值域为[0, {state_kinds})"
- )
+ assert (state_map >= 0).all() and (
+ state_map < state_kinds
+ ).all(), f"请确保state_map的值域为[0, {state_kinds})"
  s = ops.take(init_state, state_map, axis=0)

  else:
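The hunks above only restyle the `state_map` checks, but the `ops.take(init_state, state_map, axis=0)` context line shows the underlying convention: `state_map[i]` selects which row of `init_state` is used for batch element `i`, so every value must lie in `[0, state_kinds)`. A minimal sketch of that convention, using hypothetical shapes and Keras 3 `ops`:

```python
# Hedged illustration of the state_map convention; shapes are hypothetical.
from keras import ops

batch_size, num_heads, head_size = 4, 2, 64
init_state = ops.zeros((2, num_heads, head_size, head_size))    # state_kinds = 2 distinct states
state_map = ops.convert_to_tensor([0, 0, 1, 1], dtype="int32")  # items 0-1 share state 0, items 2-3 share state 1
s = ops.take(init_state, state_map, axis=0)                     # -> (batch_size, num_heads, head_size, head_size)
```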
@@ -154,9 +154,9 @@ class RWKVKernelOperator:
  @staticmethod
  def apply(B, T, C, H, S, s_map, r, k, v, w, u, s):
  with torch.no_grad():
- assert s_map.dtype == torch.int64, (
- "s_map 必须为None 或者是长度为B的,int64类型的数组。"
- )
+ assert (
+ s_map.dtype == torch.int64
+ ), "s_map 必须为None 或者是长度为B的,int64类型的数组。"
  assert (s is None and s_map is None) or (
  s is not None and s_map is not None
  ), "init_state与s_map必须同时为None 或者同时不为None"
@@ -240,17 +240,15 @@ class RWKVKernelOperator:
  is_custom_init = init_state is not None

  if init_state is not None:
- assert len(init_state.shape) in [3, 4], (
- "init_state 的形状必须为(state_kinds /*<= Batch_size*/,num_heads,head_size,head_size) 或者(num_heads,head_size,head_size)"
- )
+ assert (
+ len(init_state.shape) in [3, 4]
+ ), "init_state 的形状必须为(state_kinds /*<= Batch_size*/,num_heads,head_size,head_size) 或者(num_heads,head_size,head_size)"
  if len(init_state.shape) == 3:
  init_state = init_state[None, :]
  assert (
  init_state.shape[1:] == (H, self.head_size, self.head_size)
  and init_state.shape[0] <= B
- ), (
- "init_state 的形状必须为(state_kinds /*<= Batch_size*/,num_heads,head_size,head_size) 或者(num_heads,head_size,head_size)"
- )
+ ), "init_state 的形状必须为(state_kinds /*<= Batch_size*/,num_heads,head_size,head_size) 或者(num_heads,head_size,head_size)"

  assert init_state.dtype == s_dtype, f"init_state的数值类型应为: {s_dtype}"
  assert init_state.device == r.device
@@ -269,17 +267,17 @@ class RWKVKernelOperator:

  if with_state:
  if init_state is None:
- assert state_map is None, (
- "您必须在指定了init_state的情况下才能使用state_map"
- )
+ assert (
+ state_map is None
+ ), "您必须在指定了init_state的情况下才能使用state_map"
  init_state = torch.zeros((0,), device=r.device, dtype=s_dtype)
  state_map = torch.zeros((0,), device=r.device, dtype=torch.int64)
  else:
  n_state = init_state.shape[0]
  if state_map is None:
- assert n_state == 1 or n_state == B, (
- "我无法为您推断state_map的形状,请手动指定。"
- )
+ assert (
+ n_state == 1 or n_state == B
+ ), "我无法为您推断state_map的形状,请手动指定。"
  if n_state == 1:
  state_map = torch.tensor(
  [0] * B, dtype=torch.int64, device=r.device
@@ -292,9 +290,9 @@ class RWKVKernelOperator:
  assert False, "未实现"
  else:
  assert state_map.shape == (B,), "state_map的形状必须为(batch_size,)"
- assert (state_map >= 0).all() and (state_map < n_state).all(), (
- f"state_map的取值范围为[0,{n_state})之间的整数,您的输入显然不满足。"
- )
+ assert (
+ (state_map >= 0).all() and (state_map < n_state).all()
+ ), f"state_map的取值范围为[0,{n_state})之间的整数,您的输入显然不满足。"
  # print('state map:',state_map)
  o, ys = self.kernel_with_state.apply(
  B, T, C, H, is_custom_init, state_map, r, k, v, w, u, init_state
@@ -15,9 +15,11 @@ def get_generalized_delta_rule(HEAD_SIZE=64, KERNEL_TYPE="native"):
  USE_TRITON_KERNEL = False
  if keras.config.backend() == "torch":
  import torch
+
  if not torch.cuda.is_available():
  from .native_keras_op import generalized_delta_rule
- return generalized_delta_rule,False
+
+ return generalized_delta_rule, False

  if KERNEL_TYPE.lower() == "triton":
  from .torch_op import generalized_delta_rule
@@ -157,8 +159,8 @@ def get_generalized_delta_rule(HEAD_SIZE=64, KERNEL_TYPE="native"):
  elif keras.config.backend() == "jax":
  import jax
  import os
-
- if jax.devices()[0].platform == "gpu":
+
+ if jax.devices()[0].platform == "gpu":
  if KERNEL_TYPE.lower() == "triton":
  os.environ["JAX_LOG_COMPUTATION"] = "0"
  from .jax_op import generalized_delta_rule
@@ -196,6 +198,8 @@ def get_generalized_delta_rule(HEAD_SIZE=64, KERNEL_TYPE="native"):
  from .native_keras_op import generalized_delta_rule
  else:
  from .native_keras_op import generalized_delta_rule
+ elif keras.config.backend() == "mlx" and KERNEL_TYPE.lower() == "cuda":
+ from .mlx_op import generalized_delta_rule
  else:
  from .native_keras_op import generalized_delta_rule
  return generalized_delta_rule, USE_TRITON_KERNEL
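The hunk above adds an MLX dispatch branch to `get_generalized_delta_rule`. A minimal selection sketch follows; the signature and the `(kernel, USE_TRITON_KERNEL)` return value match the hunks above, while the top-level import path is an assumption (the diff only shows the function defined in `rwkv_ops/rwkv7_kernel/__init__.py`):

```python
# Hedged sketch of kernel selection; import path assumed, signature from the diff.
from rwkv_ops import get_generalized_delta_rule

# Returns the backend-specific op plus a flag indicating whether the Triton kernel is used.
generalized_delta_rule, use_triton_kernel = get_generalized_delta_rule(
    HEAD_SIZE=64,        # head dimension
    KERNEL_TYPE="cuda",  # "native", "triton", or "cuda"; on the "mlx" backend only "cuda" reaches the new mlx_op
)
```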
@@ -74,9 +74,9 @@ if check_pytorch_version("2.4"):
  def custom_device_ctx(index: int):
  return device_torch_lib.device(index)
  else:
- assert device == "cuda", (
- "Only cuda device is supported for PyTorch version < 2.4.0."
- )
+ assert (
+ device == "cuda"
+ ), "Only cuda device is supported for PyTorch version < 2.4.0."
  autocast_custom_fwd = device_torch_lib.amp.custom_fwd
  autocast_custom_bwd = device_torch_lib.amp.custom_bwd

@@ -11,6 +11,7 @@ import jax
  import jax.numpy as jnp
  from typing import Optional, Tuple, Union
  from jax.ad_checkpoint import checkpoint_policies as cp
+
  CHUNK_LEN = 16 # 这是一个常数
  # ---------- 延迟编译(改到当前目录) ----------
  _CURRENT_DIR = pathlib.Path(
@@ -25,9 +25,9 @@ def chunk_dplr_bwd_dhu(
  B, T, H, K, V = *qg.shape, do.shape[-1]
  BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
  BK = triton.next_power_of_2(K)
- assert BK <= 256, (
- "current kernel does not support head dimension being larger than 256."
- )
+ assert (
+ BK <= 256
+ ), "current kernel does not support head dimension being larger than 256."
  # H100
  if check_shared_mem("hopper"):
  BV = 64
@@ -42,9 +42,9 @@ def chunk_dplr_bwd_dhu(
  N, NT = B, triton.cdiv(T, BT)
  BC = min(BT, BC)
  NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
- assert NK == 1, (
- "NK > 1 is not supported because it involves time-consuming synchronization"
- )
+ assert (
+ NK == 1
+ ), "NK > 1 is not supported because it involves time-consuming synchronization"
  dh_shape = (B, NT, H, K, V)
  out_shapes = [
  jax.ShapeDtypeStruct(dh_shape, dv.dtype),
@@ -43,9 +43,9 @@ def chunk_dplr_fwd_h(
  BC = min(BT, BC)
  NK = triton.cdiv(K, BK)
  NV = triton.cdiv(V, BV)
- assert NK == 1, (
- "NK > 1 is not supported because it involves time-consuming synchronization"
- )
+ assert (
+ NK == 1
+ ), "NK > 1 is not supported because it involves time-consuming synchronization"

  out_shapes = [
  jax.ShapeDtypeStruct((B, NT, H, K, V), kg.dtype),
@@ -0,0 +1,132 @@
+ # copy from https://github.com/ml-explore/mlx-lm/pull/580
+ from dataclasses import dataclass
+ from functools import partial
+ from typing import Optional
+
+ import mlx.core as mx
+ import mlx.nn as nn
+
+
+ @partial(mx.compile, shapeless=True)
+ def addcmul(x, y, z):
+ return x + y * z
+
+
+ @partial(mx.compile, shapeless=True)
+ def l2_norm(x):
+ return x / mx.maximum(mx.linalg.norm(x, axis=-1, keepdims=True), 1e-7)
+
+
+ def _make_wkv7_kernel():
+ if not mx.metal.is_available():
+ return None
+ source = f"""
+ auto n = thread_position_in_grid.z;
+ auto b_idx = n / H;
+ auto h_idx = n % H;
+ constexpr int n_per_t = D / 32;
+ // [B, T, H, D]
+ auto r_ = r + b_idx * T * H * D + h_idx * D;
+ auto w_ = w + b_idx * T * H * D + h_idx * D;
+ auto k_ = k + b_idx * T * H * D + h_idx * D;
+ auto v_ = v + b_idx * T * H * D + h_idx * D;
+ auto a_ = a + b_idx * T * H * D + h_idx * D;
+ auto b_ = b + b_idx * T * H * D + h_idx * D;
+ y += b_idx * T * H * D + h_idx * D;
+ auto dk_idx = thread_position_in_threadgroup.x;
+ auto dv_idx = thread_position_in_grid.y;
+ // state_in, state_out: [B, H, D, D]
+ auto i_state = state_in + (n * D + dv_idx) * D;
+ auto o_state = state_out + (n * D + dv_idx) * D;
+ float state[n_per_t];
+ for (int i = 0; i < n_per_t; ++i) {{
+ auto s_idx = n_per_t * dk_idx + i;
+ state[i] = static_cast<float>(i_state[s_idx]);
+ }}
+ for (int t = 0; t < T; ++t) {{
+ float sa = 0.0f;
+ for (int i = 0; i < n_per_t; ++i) {{
+ auto s_idx = n_per_t * dk_idx + i;
+ sa += state[i] * a_[s_idx];
+ state[i] = state[i] * w_[s_idx];
+ }}
+ sa = simd_sum(sa);
+ float out = 0.0f;
+ for (int i = 0; i < n_per_t; ++i) {{
+ auto s_idx = n_per_t * dk_idx + i;
+ state[i] = state[i] + k_[s_idx] * v_[dv_idx] + sa * b_[s_idx];
+ out += state[i] * r_[s_idx];
+ }}
+ out = simd_sum(out);
+ if (thread_index_in_simdgroup == 0) {{
+ y[dv_idx] = static_cast<InT>(out);
+ }}
+ // Increment data pointers to next time step
+ r_ += H * D;
+ w_ += H * D;
+ k_ += H * D;
+ v_ += H * D;
+ a_ += H * D;
+ b_ += H * D;
+ y += H * D;
+ }}
+ for (int i = 0; i < n_per_t; ++i) {{
+ auto s_idx = n_per_t * dk_idx + i;
+ o_state[s_idx] = static_cast<InT>(state[i]);
+ }}
+ """
+ inputs = ["r", "w", "k", "v", "a", "b", "state_in", "T"]
+ return mx.fast.metal_kernel(
+ name="wkv7_kernel",
+ input_names=inputs,
+ output_names=["y", "state_out"],
+ source=source,
+ )
+
+
+ _wkv7_kernel = _make_wkv7_kernel()
+
+
+ def transpose_head(x, head_first: bool = True):
+ if head_first:
+ return mx.transpose(x, (0, 2, 1, 3))
+ return x
+
+
+ def generalized_delta_rule(
+ r,
+ w,
+ k,
+ v,
+ a,
+ b,
+ initial_state=None,
+ output_final_state: bool = True,
+ head_first: bool = False,
+ ):
+ state = initial_state
+
+ r = transpose_head(r, head_first)
+ k = transpose_head(k, head_first)
+ v = transpose_head(v, head_first)
+ a = transpose_head(a, head_first)
+ b = transpose_head(b, head_first)
+
+ B, T, H, D = r.shape
+ input_dtype = r.dtype
+
+ y, out_state = _wkv7_kernel(
+ inputs=[r, w, k, v, a, b, state, T],
+ template=[
+ ("InT", input_dtype),
+ ("H", H),
+ ("D", D),
+ ],
+ grid=(32, D, B * H),
+ threadgroup=(32, 4, 1),
+ output_shapes=[(B, T, H, D), state.shape],
+ output_dtypes=[input_dtype, input_dtype],
+ )
+ if output_final_state:
+ return y, out_state
+ return y
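The new `mlx_op.py` above exposes a forward-only Metal kernel for the RWKV-7 generalized delta rule. A minimal call sketch follows; the input layout `[B, T, H, D]` and state layout `[B, H, D, D]` come from the kernel's own comments, the import path mirrors the file location listed in RECORD below, and the tensor values are purely illustrative:

```python
# Hedged usage sketch of the new MLX op (forward pass only, Metal required).
import mlx.core as mx
from rwkv_ops.rwkv7_kernel.mlx_op import generalized_delta_rule  # assumed import path

B, T, H, D = 1, 8, 2, 64            # D should be a multiple of 32 (the kernel uses D / 32 elements per lane)
r, w, k, v, a, b = (mx.random.normal((B, T, H, D)) for _ in range(6))  # illustrative values, not a trained layer
state = mx.zeros((B, H, D, D))      # an initial state is required: the kernel reads state_in directly

y, final_state = generalized_delta_rule(
    r, w, k, v, a, b,
    initial_state=state,
    output_final_state=True,
    head_first=False,               # inputs are already [B, T, H, D]
)
```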
@@ -24,9 +24,9 @@ def chunk_dplr_bwd_dhu(
  B, T, H, K, V = *qg.shape, do.shape[-1]
  BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
  BK = triton.next_power_of_2(K)
- assert BK <= 256, (
- "current kernel does not support head dimension being larger than 256."
- )
+ assert (
+ BK <= 256
+ ), "current kernel does not support head dimension being larger than 256."
  # H100
  if check_shared_mem("hopper", qg.device.index):
  BV = 64
@@ -42,9 +42,9 @@ def chunk_dplr_bwd_dhu(

  BC = min(BT, BC)
  NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
- assert NK == 1, (
- "NK > 1 is not supported because it involves time-consuming synchronization"
- )
+ assert (
+ NK == 1
+ ), "NK > 1 is not supported because it involves time-consuming synchronization"

  dh = qg.new_empty(B, NT, H, K, V)
  dh0 = torch.empty_like(h0, dtype=torch.float32) if h0 is not None else None
@@ -42,9 +42,9 @@ def chunk_dplr_fwd_h(
  BC = min(BT, BC)
  NK = triton.cdiv(K, BK)
  NV = triton.cdiv(V, BV)
- assert NK == 1, (
- "NK > 1 is not supported because it involves time-consuming synchronization"
- )
+ assert (
+ NK == 1
+ ), "NK > 1 is not supported because it involves time-consuming synchronization"

  h = kg.new_empty(B, NT, H, K, V)
  final_state = (
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: rwkv-ops
- Version: 0.3.2
+ Version: 0.3.3
  Summary: RWKV operators for multiple backends (PyTorch, JAX, Keras)
  Project-URL: Homepage, https://github.com/pass-lin/rwkv_ops
  Author-email: pass-lin <qw_lin@qq.com>
@@ -125,12 +125,14 @@ if padding_mask is not None:
  | JAX | ✅ | ✅ | ✅ |
  | TensorFlow | ⚠️ | ❌ | ✅ |
  | NumPy | ❌ | ❌ | ✅ |
+ | MLX | ⚠️ | ❌ | ❌ |

  ---
  > `native` 为原生算子,无 chunkwise,速度慢且显存高。
  > `triton` 使用的是chunkwise算法实现,速度快,并行度高,缺点是精度很差,介意勿用
  > `cuda` 为基于 CUDA 的原生算子,速度很快,并且kernel内部使用fp32实现,所以精度也很高。缺点就是长序列的时候比较吃亏跑不满。
  > tensorflow的CUDA实现只支持前向计算,是没有梯度的。并且这个是使用jax的cuda实现实现的,你需要保证你能够成功运行jax的cuda kernel。
+ > 因为MLX还没合并到keras,所以原生算子暂不支持。但是我们提供了一个前向的算子。

  ## rwkv6op 使用方法

@@ -1,8 +1,8 @@
- rwkv_ops/__init__.py,sha256=ojPQmkz3yWNqqJwIyjAfsxWB_h3TowtBrJtuRqssEvA,855
+ rwkv_ops/__init__.py,sha256=voJa6h1nvEua5iD2H-mxsZR3j6s1at8xrqnXj4Q-WYQ,855
  rwkv_ops/rwkv6_kernel/__init__.py,sha256=ktIzkK6EUc2nonLQnl2NAjJj9kMt02i9zqfjFcnM_NQ,3647
- rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py,sha256=4SL93Z4mmuQldHtmwqTKcP7M-outTU5Rge2qgDGzwBg,29966
- rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py,sha256=c3ZSJ9xC6-PKr88pOhjmBximdhwmP1_i7UOcIdKB43c,3354
- rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py,sha256=Pv0WsBp5byTSwkYrYkHcJa3wftSsHHzfRzleKdmJayY,12915
+ rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py,sha256=DTYtT2v2WjOlndD-ESNmbmVj1ili03KOXOZek1V4DLw,29954
+ rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py,sha256=XBsYAkbsyUCsxVPR-RfrboQkyM8TIzW1saOsh6vEjcM,3352
+ rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py,sha256=tq5H-ndm9nq2nWTQYI-xpw5YK0j0E31yF552muSFHN4,12883
  rwkv_ops/rwkv6_kernel/jax_kernel_cuda/gpu_ops.cpp,sha256=oM13TCQi2GMIf3f-Z39WOL8M_8GmGI_Kdhiq3Y2keJw,1643
  rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernel_helpers.h,sha256=epwsW8OUIOvrlNuW3BAmAbgB8n8CKOFEYafBxQy3ptw,2209
  rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernels.h,sha256=KYJiWmmig0Wh-zpiWV96J_be8jlyc38Ztd1iqNoqVFI,1501
@@ -15,21 +15,22 @@ rwkv_ops/rwkv6_kernel/jax_kernel_hip/pybind11_kernel_helpers.h,sha256=CMQclcyHaD
  rwkv_ops/rwkv6_kernel/jax_kernel_hip/rwkv_kernels.hip,sha256=givSxPA7YfKGz75rOtN8TAjTxWWraVNgTGPZfAJsZsQ,20836
  rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_cuda.cu,sha256=tfRbMQBkl_LT7EVaJ6KoWYcQ902ApCrS6zkjXldFZXY,12770
  rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_op.cpp,sha256=cyCTiF--4SQiDJu7Dy_NuEhSe1vyki6JS4I2rsvT714,6659
- rwkv_ops/rwkv7_kernel/__init__.py,sha256=HfoB043qxcIyljNcSd_XtH2UKB6wF2qQlOq9VvXwWRI,8129
+ rwkv_ops/rwkv7_kernel/__init__.py,sha256=OJq9ZU1GPP5Si8LY66miTFFotxyGlBtPyUp49Cedl8k,8250
  rwkv_ops/rwkv7_kernel/get_jax_devices_info.py,sha256=cMIaNED7d1PvYNSyq8wNI3G7wNvcgdUj9HWRBLuSVM8,6004
- rwkv_ops/rwkv7_kernel/get_torch_devices_info.py,sha256=ZL_rAM6lHB4nTOOU28Xm08qptfuIoijOMi_xwJG3KCo,7380
+ rwkv_ops/rwkv7_kernel/get_torch_devices_info.py,sha256=BR6IqwcBDKjLf-uRCh0LAzYtRl4KP43JO5fnd9jsr2c,7380
  rwkv_ops/rwkv7_kernel/jax_op.py,sha256=C7jOvJ-ZWTFfCZBQNzMbqgoVHuDS2QCGlBsGEMM4Fn0,9140
+ rwkv_ops/rwkv7_kernel/mlx_op.py,sha256=Ss9i_1TdGPNnc1YpD7QBSdK1_sTQ2R5l8Mk5UIcWQJ0,3795
  rwkv_ops/rwkv7_kernel/native_keras_op.py,sha256=dCWdzuVZxAKHCBURZqgOLN3n_yKFFNX5uORlbvztH6w,2502
  rwkv_ops/rwkv7_kernel/tf_eager_kernel.py,sha256=2t2uf1iNznYpYFlqt9REY0GwGeycYuaJl-4QFk2rJHc,4357
  rwkv_ops/rwkv7_kernel/torch_op.py,sha256=jw_AvqshTAG4t9-MRqxFQNi_bTzxNbx3lwnMifPk8-8,14070
  rwkv_ops/rwkv7_kernel/jax_cuda_kernel/CMakeLists.txt,sha256=Dq4Ea8N2xOEej2jZpEw4MtFjUFgN0PUciejVOCSP-FM,1400
  rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_ffi.cu,sha256=WePveEdUixaQA51hJUK8Sr7Q7jDTstybEWZczdjuGSo,9690
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_jax.py,sha256=3lvCKIa9DO7MY3aZNyJM0AyHlQUvDKGsnYVr8MLl7Vg,7998
+ rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_jax.py,sha256=t2mQ_zGhMeBZClcaLJSwRG4n2MtRhvn9z-vHWK79F6w,7999
  rwkv_ops/rwkv7_kernel/jax_kernel/__init__.py,sha256=uHsf_1qrtRK62IvhLuzefHGPWpHXmw1p0tqmwlHcptk,346
  rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_bwd.py,sha256=2Voq1Bdzn0DFloiLvwINBk7akmxRWIqXIQeyafrJJGg,2138
  rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_fwd.py,sha256=rhmglqHIIww7yPzaSBEp9ISxhhxoUbMtV51AUDyhUd8,1425
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_bwd.py,sha256=JDfVZsMb8yMlMN3sKT3i3l3y1YQiQkyUjnSNyan5Fqc,1888
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_fwd.py,sha256=g8b_81rIIjxeknYiklRGnox24rAvEvfKRKT-5nI0Euo,1992
+ rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_bwd.py,sha256=U06dcacmND-y022mN4UmDunfRDxJYWthU_4V8z0HcSs,1888
+ rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_fwd.py,sha256=Nl0migPjRmQopIsysSqt7ZMQ_X-vyblb7e2t-xghzlA,1992
  rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_bwd.py,sha256=gQnToi1e1GZCvjWsEdWx6WakUN4Lc0JfaBSsSXYdN84,3369
  rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_fwd.py,sha256=4SjQ_zTZvFxsBMeWOx0JGFg9EQ4vllvEx30EcvSZJzI,853
  rwkv_ops/rwkv7_kernel/jax_kernel/cumsum.py,sha256=NoOh2_hA_rdH5bmaNNMAdCgVPfWvQpf-Q8BqF926jrw,667
@@ -40,8 +41,8 @@ rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_op.cpp,sha256=Wk5QYvIM9m-YJdSEh6zSz
  rwkv_ops/rwkv7_kernel/torch_kernel/__init__.py,sha256=_u1srIATeoHKlVTVWbWXdpkjaggugl9y-Kx_Y4pYdIY,430
  rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_bwd.py,sha256=CWtotXkVvHz4-rkuOqWh6zKy95jwimS9If6SU45ylW0,2103
  rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_fwd.py,sha256=4RJbyUTO23OxwH1rGVxeBiBVZKNHpPL_tJ7MFoDCIts,1475
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_bwd.py,sha256=zo6l0ZZUhXFu8wEFD76I0zSqFT9IXFKUKtyeaSwk380,1795
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_fwd.py,sha256=0ucN1U0EDTDqcyTPLLcsAX6FLTf2E_3toOY9p81gWYE,1858
+ rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_bwd.py,sha256=AdbgPd0JRfPZ_poK_XAQ5iV1GsBqDehiN0lf_-_CbUw,1795
+ rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_fwd.py,sha256=o_EbLxqqnzW8_aNduqv_Brd_-SlUU3szfi8Lfn40rqc,1858
  rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_bwd.py,sha256=ioPrS0NYQhpFk1j8rAxqtbwpx1CwjJQnrJEBDqVy-As,3283
  rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_fwd.py,sha256=54yoa3NpV64H-koURt-hUWpFHhUjwXpGvXPp2_ETCnw,825
  rwkv_ops/rwkv7_kernel/torch_kernel/cumsum.py,sha256=hQkpyaa0eUyB4V3UVks7l1_dHwOrbump0FZILityBKw,611
@@ -58,7 +59,7 @@ rwkv_ops/rwkv7_kernel/triton_kernel/cumsum.py,sha256=pRp_z587PrnpgRVpi031IndyjVI
  rwkv_ops/rwkv7_kernel/triton_kernel/utils.py,sha256=TNGlkwGq4t-TOcdVBk_N_vHPLzMFTu_F0V-O1RprIO4,553
  rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_bwd.py,sha256=szaG11q_WmpyhXi6aVWwzizvflCh5wND8wGA_V8afzA,5479
  rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_fwd.py,sha256=jbb19DUTHENU2RIOv_T4m_W1eXMqdRqG0XevIkBOhI4,9438
- rwkv_ops-0.3.2.dist-info/METADATA,sha256=lkSey3fiZxPrVO05sSb7Q4Q2cAHFgo8-f8RZjmLAWL4,8853
- rwkv_ops-0.3.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- rwkv_ops-0.3.2.dist-info/licenses/LICENSE.txt,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
- rwkv_ops-0.3.2.dist-info/RECORD,,
+ rwkv_ops-0.3.3.dist-info/METADATA,sha256=EU1tq3Ub9WqVpcNYRa1T_I9H2Nx1nNypX8fZWdu7bsM,9011
+ rwkv_ops-0.3.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ rwkv_ops-0.3.3.dist-info/licenses/LICENSE.txt,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
+ rwkv_ops-0.3.3.dist-info/RECORD,,