rwkv-ops 0.1.0__py3-none-any.whl → 0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of rwkv-ops might be problematic.

@@ -0,0 +1,86 @@
+ from keras import ops
+ import keras
+
+
+ class RWKVKernelOperator:
+     def __init__(self, head_size, max_sequence_length):
+         self.head_size = head_size
+         self.max_sequence_length = max_sequence_length
+
+     def __call__(
+         self, r, k, v, w, u, with_state=False, init_state=None, state_map=None
+     ):
+         B, T, C = ops.shape(r)
+         assert C % self.head_size == 0
+         H = C // self.head_size
+         w = ops.reshape(w, [B, T, H, self.head_size, 1])
+         k = ops.reshape(k, [B, T, H, self.head_size, 1])
+
+         v = ops.reshape(v, [B, T, H, 1, self.head_size])
+         r = ops.reshape(r, [B, T, H, 1, self.head_size])
+         u = ops.reshape(u, [1, H, self.head_size, 1])
+
+         if init_state is not None:
+             assert len(init_state.shape) in [3, 4], (
+                 "init_state must have shape (state_kinds, num_heads, head_size, head_size)"
+             )
+             if len(init_state.shape) == 3:
+                 assert init_state.shape == (H, self.head_size, self.head_size), (
+                     "init_state must have shape (num_heads, head_size, head_size)"
+                 )
+                 init_state = init_state[None, :]
+             else:
+                 assert init_state.shape[1:] == (H, self.head_size, self.head_size), (
+                     "init_state must have shape (state_kinds, num_heads, head_size, head_size)"
+                 )
+             state_kinds = init_state.shape[0]
+             if state_map is None:
+                 state_kinds = init_state.shape[0]
+                 if state_kinds == 1:
+                     state_map = ops.zeros(shape=(B,), dtype="int32")
+                 elif state_kinds == B:
+                     state_map = ops.convert_to_tensor(
+                         [i for i in range(B)], dtype="int32"
+                     )
+                 else:
+                     raise ValueError(
+                         "Cannot infer state_map automatically; please specify state_map explicitly"
+                     )
+
+             else:
+                 if isinstance(state_map, list):
+                     state_map = ops.convert_to_tensor(state_map, dtype="int32")
+                 state_map = ops.cast(state_map, "int32")
+                 assert ops.all(state_map >= 0) and ops.all(state_map < state_kinds), (
+                     f"state_map values must lie in [0, {state_kinds})"
+                 )
+             s = ops.take(init_state, state_map, axis=0)
+
+         else:
+             assert state_map is None
+             s = ops.zeros((B, H, self.head_size, self.head_size), dtype=u.dtype)
+
+         w = ops.exp(-ops.exp(w))
+
+         def cond(i, k, v, w, r, s, y):
+             return i < T
+
+         def body(i, k, v, w, r, s, y):
+             k_t = ops.take(k, i, 1)
+             v_t = ops.take(v, i, 1)
+             kv_t = k_t @ v_t
+             w_t = ops.take(w, i, 1)
+
+             r_t = ops.take(r, i, 1)
+             y_t = r_t @ (u * kv_t + s)
+             y_t = ops.reshape(y_t, (B, 1, C))
+             s = kv_t + w_t * s
+
+             y = ops.slice_update(y, [0, i, 0], y_t)
+             return i + 1, k, v, w, r, s, y
+
+         y = ops.zeros([B, T, C], r.dtype)
+         i, k, v, w, r, s, y = ops.while_loop(cond, body, (0, k, v, w, r, s, y), T)
+         if with_state:
+             return y, s
+         return y, None
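
For context, the pure-Keras operator above scans the sequence one timestep at a time with `ops.while_loop`, so it runs on any Keras 3 backend but without a fused kernel. Below is a minimal usage sketch, assuming illustrative batch/sequence/head sizes, random inputs, and that `RWKVKernelOperator` is importable from the installed package (the exact module path is not shown in this diff):

```python
import numpy as np
from keras import ops

# Illustrative sizes only (assumptions, not package defaults).
B, T, H, N = 2, 8, 4, 64      # batch, time, heads, head_size
C = H * N

op = RWKVKernelOperator(head_size=N, max_sequence_length=T)


def rand(*shape):
    # Random float32 inputs purely for shape demonstration.
    return ops.convert_to_tensor(np.random.randn(*shape).astype("float32"))


r, k, v, w = rand(B, T, C), rand(B, T, C), rand(B, T, C), rand(B, T, C)
u = rand(H, N)

# Stateless call: y has shape (B, T, C); the second return value is None.
y, _ = op(r, k, v, w, u)

# Stateful call: also returns the final per-head state of shape (B, H, N, N),
# which can be passed back as init_state when processing the next chunk.
y1, s1 = op(r, k, v, w, u, with_state=True)
y2, s2 = op(r, k, v, w, u, with_state=True, init_state=s1)
```
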
@@ -0,0 +1,305 @@
+ import os
+ import torch
+ from torch.utils.cpp_extension import load
+ from keras import ops
+
+ kernel_dir_name = "torch_kernel"
+
+ use_rocm = "RWKV_USE_ROCM" in os.environ and os.environ["RWKV_USE_ROCM"] == "1"
+
+
+ class RWKVKernelOperator:
+     def __init__(self, head_size, max_sequence_length):
+         current_dir = os.path.dirname(__file__)
+         # current_dir = os.pat
+         if use_rocm:
+             wkv6_cuda = load(
+                 name="wkv6",
+                 sources=[
+                     os.path.join(current_dir, f"{kernel_dir_name}/wkv6_op.cpp"),
+                     os.path.join(current_dir, f"{kernel_dir_name}/wkv6_cuda.cu"),
+                 ],
+                 # verbose=True, extra_cuda_cflags=[f"-D_N_={head_size}", f"-D_T_={max_sequence_length}"])
+                 verbose=True,
+                 extra_cuda_cflags=[
+                     "-fopenmp -ffast-math -munsafe-fp-atomics --gpu-max-threads-per-block=120 -enable-vectorize-compares",
+                     f"-D_N_={head_size}",
+                     f"-D_T_={max_sequence_length}",
+                 ],
+             )
+         else:
+             wkv6_cuda = load(
+                 name="wkv6",
+                 sources=[
+                     os.path.join(current_dir, f"{kernel_dir_name}/wkv6_op.cpp"),
+                     os.path.join(current_dir, f"{kernel_dir_name}/wkv6_cuda.cu"),
+                 ],
+                 # verbose=True, extra_cuda_cflags=[f"-D_N_={head_size}", f"-D_T_={max_sequence_length}"])
+                 verbose=True,
+                 extra_cuda_cflags=[
+                     "-res-usage",
+                     "--use_fast_math",
+                     "-O3",
+                     "-Xptxas -O3",
+                     "--extra-device-vectorization",
+                     f"-D_N_={head_size}",
+                     f"-D_T_={max_sequence_length}",
+                 ],
+             )
+
+         class RWKV_6(torch.autograd.Function):
+             @staticmethod
+             def forward(ctx, B, T, C, H, r, k, v, w, u):
+                 if not isinstance(u, torch.Tensor):
+                     u = u.value
+                 with torch.no_grad():
+                     assert r.dtype == k.dtype == v.dtype == w.dtype == u.dtype
+                     assert r.dtype in [torch.float32, torch.bfloat16, torch.float16]
+
+                     assert head_size == C // H
+                     ctx.B = B
+                     ctx.T = T
+                     ctx.C = C
+                     ctx.H = H
+                     assert r.is_contiguous()
+                     assert k.is_contiguous()
+                     assert v.is_contiguous()
+                     assert w.is_contiguous()
+                     assert u.is_contiguous()
+                     ctx.save_for_backward(r, k, v, w, u)
+
+                     y_dtype = r.dtype if r.dtype != torch.float16 else torch.float32
+
+                     y = torch.empty(
+                         (B, T, C),
+                         device=r.device,
+                         dtype=y_dtype,
+                         memory_format=torch.contiguous_format,
+                     )  # .uniform_(-100, 100)
+
+                     if r.dtype == torch.float32:
+                         wkv6_cuda.forward_fp32(B, T, C, H, r, k, v, w, u, y)
+                     elif r.dtype == torch.bfloat16:
+                         wkv6_cuda.forward_bf16(B, T, C, H, r, k, v, w, u, y)
+                     else:
+                         wkv6_cuda.forward_fp16(B, T, C, H, r, k, v, w, u, y)
+                     return y
+
+             @staticmethod
+             def backward(ctx, gy):
+                 assert gy.is_cuda
+                 with torch.no_grad():
+                     assert gy.dtype in [torch.bfloat16, torch.float32]
+                     B = ctx.B
+                     T = ctx.T
+                     C = ctx.C
+                     H = ctx.H
+                     assert gy.is_contiguous()
+                     r, k, v, w, u = ctx.saved_tensors
+                     y_dtype = r.dtype if r.dtype != torch.float16 else torch.float32
+
+                     gr = torch.empty(
+                         (B, T, C),
+                         device=gy.device,
+                         requires_grad=False,
+                         dtype=y_dtype,
+                         memory_format=torch.contiguous_format,
+                     )  # .uniform_(-100, 100)
+                     gk = torch.empty(
+                         (B, T, C),
+                         device=gy.device,
+                         requires_grad=False,
+                         dtype=y_dtype,
+                         memory_format=torch.contiguous_format,
+                     )  # .uniform_(-100, 100)
+                     gv = torch.empty(
+                         (B, T, C),
+                         device=gy.device,
+                         requires_grad=False,
+                         dtype=y_dtype,
+                         memory_format=torch.contiguous_format,
+                     )  # .uniform_(-100, 100)
+                     gw = torch.empty(
+                         (B, T, C),
+                         device=gy.device,
+                         requires_grad=False,
+                         dtype=y_dtype,
+                         memory_format=torch.contiguous_format,
+                     )  # .uniform_(-100, 100)
+                     gu = torch.empty(
+                         (B, C),
+                         device=gy.device,
+                         requires_grad=False,
+                         dtype=y_dtype,
+                         memory_format=torch.contiguous_format,
+                     )  # .uniform_(-100, 100)
+
+                     if r.dtype == torch.float32:
+                         wkv6_cuda.backward_fp32(
+                             B, T, C, H, r, k, v, w, u, gy, gr, gk, gv, gw, gu
+                         )
+                     elif r.dtype == torch.bfloat16:
+                         wkv6_cuda.backward_bf16(
+                             B, T, C, H, r, k, v, w, u, gy, gr, gk, gv, gw, gu
+                         )
+                     else:
+                         wkv6_cuda.backward_fp16(
+                             B, T, C, H, r, k, v, w, u, gy, gr, gk, gv, gw, gu
+                         )
+
+                     gu = torch.sum(gu, 0).view(H, C // H)
+
+                     return (None, None, None, None, gr, gk, gv, gw, gu)
+
+         class RWKV_6_with_state:
+             @staticmethod
+             def apply(B, T, C, H, S, s_map, r, k, v, w, u, s):
+                 with torch.no_grad():
+                     assert s_map.dtype == torch.int64, (
+                         "s_map must be None or an int64 array of length B."
+                     )
+                     assert (s is None and s_map is None) or (
+                         s is not None and s_map is not None
+                     ), "init_state and s_map must either both be None or both be given"
+                     assert (
+                         r.dtype == k.dtype == v.dtype == w.dtype == u.dtype
+                         and r.dtype in [torch.float16, torch.float32, torch.bfloat16]
+                     ), "r, k, v, w, u must share the same dtype, one of fp16, fp32 or bf16"
+                     if r.dtype in [torch.float32, torch.bfloat16]:
+                         o_dtype = r.dtype
+                     else:
+                         o_dtype = torch.float32
+                     assert (
+                         r.device
+                         == k.device
+                         == v.device
+                         == w.device
+                         == u.device
+                         == s.device
+                         == s_map.device
+                     ), "please make sure r, k, v, w, u, s and s_map are on the same device"
+
+                     y = torch.empty(
+                         (B, T, C),
+                         device=r.device,
+                         dtype=o_dtype,
+                         memory_format=torch.contiguous_format,
+                     )
+                     ys = torch.empty(
+                         (B, H, head_size, head_size),
+                         device=r.device,
+                         dtype=o_dtype,
+                         memory_format=torch.contiguous_format,
+                     )
+                     # print(ys)
+                     if r.dtype == torch.bfloat16:
+                         wkv6_cuda.forward_with_state_bf16(
+                             B, T, C, H, S, s_map, r, k, v, w, u, s, y, ys
+                         )
+                     elif r.dtype == torch.float32:
+                         wkv6_cuda.forward_with_state_fp32(
+                             B, T, C, H, S, s_map, r, k, v, w, u, s, y, ys
+                         )
+                     else:
+                         wkv6_cuda.forward_with_state_fp16(
+                             B, T, C, H, S, s_map, r, k, v, w, u, s, y, ys
+                         )
+
+                     return y, ys
+
+         self.head_size = head_size
+         self.normal_kernel = RWKV_6
+         self.kernel_with_state = RWKV_6_with_state
+
+     def __call__(
+         self, r, k, v, w, u, with_state=False, init_state=None, state_map=None
+     ):
+         B, T, C = r.shape
+         assert C % self.head_size == 0
+         H = C // self.head_size
+         if not isinstance(u, torch.Tensor):
+             u = u.value
+
+         assert r.is_cuda
+         assert k.is_cuda
+         assert v.is_cuda
+         assert w.is_cuda
+         assert u.is_cuda
+
+         if isinstance(r, torch.Tensor):
+             assert r.device == k.device == v.device == w.device == u.device
+         else:
+             assert r.get_device() == k.get_device() == v.get_device() == w.get_device() == u.get_device()
+
+         assert r.dtype == k.dtype == v.dtype == w.dtype == u.dtype
+
+         if r.dtype in [torch.float32, torch.bfloat16]:
+             s_dtype = r.dtype
+         else:
+             s_dtype = torch.float32
+
+         is_custom_init = init_state is not None
+
+         if init_state is not None:
+             assert len(init_state.shape) in [3, 4], (
+                 "init_state must have shape (state_kinds /*<= Batch_size*/, num_heads, head_size, head_size) or (num_heads, head_size, head_size)"
+             )
+             if len(init_state.shape) == 3:
+                 init_state = init_state[None, :]
+             assert (
+                 init_state.shape[1:] == (H, self.head_size, self.head_size)
+                 and init_state.shape[0] <= B
+             ), (
+                 "init_state must have shape (state_kinds /*<= Batch_size*/, num_heads, head_size, head_size) or (num_heads, head_size, head_size)"
+             )
+
+             assert init_state.dtype == s_dtype, f"init_state must have dtype {s_dtype}"
+             assert init_state.device == r.device
+
+         if state_map is not None:
+             if isinstance(state_map, list):
+                 state_map = torch.tensor(state_map, dtype=torch.int64)
+             elif isinstance(state_map, torch.Tensor):
+                 assert state_map.dtype in [torch.int32, torch.int64], (
+                     "state_map must be an int64 mapping array of length Batch_Size"
+                 )
+             state_map = state_map.to(torch.int64)
+             assert state_map.shape == (B,), "state_map must have shape (Batch_Size,)"
+             assert state_map.device == r.device
+
+         if with_state:
+             if init_state is None:
+                 assert state_map is None, (
+                     "state_map can only be used when init_state is specified"
+                 )
+                 init_state = torch.zeros((0,), device=r.device, dtype=s_dtype)
+                 state_map = torch.zeros((0,), device=r.device, dtype=torch.int64)
+             else:
+                 n_state = init_state.shape[0]
+                 if state_map is None:
+                     assert n_state == 1 or n_state == B, (
+                         "Cannot infer state_map automatically; please specify it manually."
+                     )
+                     if n_state == 1:
+                         state_map = torch.tensor(
+                             [0] * B, dtype=torch.int64, device=r.device
+                         )
+                     elif n_state == B:
+                         state_map = torch.tensor(
+                             [i for i in range(B)], dtype=torch.int64, device=r.device
+                         )
+                     else:
+                         assert False, "not implemented"
+                 else:
+                     assert state_map.shape == (B,), "state_map must have shape (batch_size,)"
+                     assert (state_map >= 0).all() and (state_map < n_state).all(), (
+                         f"state_map values must be integers in [0, {n_state})"
+                     )
+             # print('state map:',state_map)
+             o, ys = self.kernel_with_state.apply(
+                 B, T, C, H, is_custom_init, state_map, r, k, v, w, u, init_state
+             )
+             return o, ys
+         else:
+             o = self.normal_kernel.apply(B, T, C, H, r, k, v, w, u)
+             return o, None
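
For context, the operator above JIT-compiles the wkv6 CUDA kernels (or HIP kernels, when `RWKV_USE_ROCM=1` is set before construction) via `torch.utils.cpp_extension.load`, baking `head_size` and `max_sequence_length` into the build as `-D_N_` and `-D_T_`. A minimal usage sketch under those assumptions, with illustrative sizes, a CUDA device, and `RWKVKernelOperator` imported from the installed package:

```python
import torch

# Illustrative sizes (assumptions); head_size must match the compiled -D_N_ value
# and T must not exceed max_sequence_length (-D_T_).
B, T, H, N = 2, 16, 4, 64
C = H * N

op = RWKVKernelOperator(head_size=N, max_sequence_length=64)  # compiles wkv6 on first use


def randn(*shape):
    # bf16 CUDA tensors; the kernels require contiguous inputs on one device.
    return torch.randn(*shape, device="cuda", dtype=torch.bfloat16).contiguous()


r, k, v, w = randn(B, T, C), randn(B, T, C), randn(B, T, C), randn(B, T, C)
u = randn(H, N)

# Stateless path: dispatches to the RWKV_6 autograd Function (forward + backward kernels).
y, _ = op(r, k, v, w, u)

# Stateful path: init_state has shape (state_kinds, H, N, N); a single state is broadcast
# across the batch through the automatically inferred int64 state_map.
init_state = torch.zeros(1, H, N, N, device="cuda", dtype=torch.bfloat16)
y, final_state = op(r, k, v, w, u, with_state=True, init_state=init_state)
```
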
@@ -1,5 +1,6 @@
  import keras
  from distutils.util import strtobool
+ import os
  from keras import ops


@@ -19,6 +20,7 @@ def get_generalized_delta_rule(HEAD_SIZE=64, KERNEL_TYPE="native"):
          from .torch_op import generalized_delta_rule

          USE_KERNEL = True
+
      elif KERNEL_TYPE.lower() == "cuda":
          CHUNK_LEN = 16
          USE_KERNEL = True
@@ -39,16 +41,11 @@ def get_generalized_delta_rule(HEAD_SIZE=64, KERNEL_TYPE="native"):

          # Get the directory of the current file
          current_dir_path = os.path.dirname(current_file_path)
-
-         # Get the path of the parent directory
-         parent_dir_path = os.path.abspath(
-             os.path.join(current_dir_path, os.path.pardir)
-         )
          load(
              name="wind_backstepping",
              sources=[
-                 os.path.join(parent_dir_path, "cuda_kernel/wkv7_cuda.cu"),
-                 os.path.join(parent_dir_path, "cuda_kernel/wkv7_op.cpp"),
+                 os.path.join(current_dir_path, "cuda_kernel/wkv7_cuda.cu"),
+                 os.path.join(current_dir_path, "cuda_kernel/wkv7_op.cpp"),
              ],
              is_python_module=False,
              verbose=True,
@@ -137,11 +134,17 @@ def get_generalized_delta_rule(HEAD_SIZE=64, KERNEL_TYPE="native"):
          from jax.lib import xla_bridge
          import jax
          import os
+         import logging

+         logging.basicConfig(level=logging.ERROR)
+         os.environ["TRITON_LOG_LEVEL"] = "ERROR"  # show only error-level logs
+         os.environ["TRITON_DISABLE_AUTOTUNE"] = "1"  # disable autotuning logs
+         os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"  # silence TensorFlow C++ logging
          if (
              xla_bridge.get_backend().platform == "gpu"
              and KERNEL_TYPE.lower() == "triton"
          ):
+             os.environ["JAX_LOG_COMPUTATION"] = "0"
              from .jax_op import generalized_delta_rule

              USE_KERNEL = True
@@ -5,6 +5,8 @@ import functools
  import triton
  import jax
  import jax.numpy as jnp
+ from enum import Enum
+ import contextlib


  @lru_cache(maxsize=None)
@@ -82,9 +84,6 @@ def is_triton_shared_mem_enough(

  device_capacity = is_triton_shared_mem_enough()

- from enum import Enum
- import contextlib
-

  def _cpu_device_warning():
      import warnings
@@ -6,6 +6,8 @@ from typing import Literal
  import triton
  from packaging import version
  import torch
+ from enum import Enum
+ import contextlib


  @lru_cache(maxsize=None)
@@ -105,8 +107,6 @@ def is_triton_shared_mem_enough(


  device_capacity = is_triton_shared_mem_enough()
- from enum import Enum
- import contextlib


  def _cpu_device_warning():
@@ -1,5 +1,5 @@
  # -*- coding: utf-8 -*-
- # Copyright (c) 2023-2025,Qingwen Lin
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

  from typing import Tuple

@@ -7,7 +7,7 @@ import jax_triton as jt
  import jax
  import triton

- from ..get_torch_devices_info import check_shared_mem
+ from ..get_jax_devices_info import check_shared_mem
  from ..triton_kernel.chunk_o_bwd import *


@@ -104,7 +104,7 @@ def chunk_dplr_bwd_o(
          out_shape=out_shapes,
          grid=grid,
      )
-     return dq, dk, dw, db, dgk_last
+     return (dq, dk, dw, db, dgk_last)


  def chunk_dplr_bwd_dAu(
@@ -223,7 +223,6 @@ def chunk_dplr_bwd(

      dv = chunk_dplr_bwd_dv(A_qk=A_qk, kg=kg, do=do, dh=dh, chunk_size=BT)
      del A_qk
-
      dqg, dkg, dw, dbg, dgk_last = chunk_dplr_bwd_o(
          k=kg,
          b=bg,
@@ -239,7 +238,6 @@ def chunk_dplr_bwd(
          scale=scale,
      )
      del v_new
-
      dA_ab, dA_ak, dv, dag = chunk_dplr_bwd_wy(
          A_ab_inv=A_ab_inv,
          A_ak=A_ak,
@@ -104,7 +104,7 @@ def chunk_dplr_bwd_o(
          BK=BK,
          BV=BV,
      )
-     return dq, dk, dw, db, dgk_last
+     return (dq, dk, dw, db, dgk_last)


  def chunk_dplr_bwd_dAu(