rwkv-ops 0.1.1__py3-none-any.whl → 0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

@@ -0,0 +1,85 @@
+ from keras import ops
+ import keras
+
+
+ class RWKVKernelOperator:
+     def __init__(self, head_size, max_sequence_length):
+         self.head_size = head_size
+         self.max_sequence_length = max_sequence_length
+
+     def __call__(
+         self, r, k, v, w, u, with_state=False, init_state=None, state_map=None
+     ):
+         B, T, C = ops.shape(r)
+         assert C % self.head_size == 0
+         H = C // self.head_size
+         w = ops.reshape(w, [B, T, H, self.head_size, 1])
+         k = ops.reshape(k, [B, T, H, self.head_size, 1])
+
+         v = ops.reshape(v, [B, T, H, 1, self.head_size])
+         r = ops.reshape(r, [B, T, H, 1, self.head_size])
+         u = ops.reshape(u, [1, H, self.head_size, 1])
+
+         if init_state is not None:
+             assert len(init_state.shape) in [3, 4], (
+                 "init_state must have shape (state_kinds, num_heads, head_size, head_size)"
+                 " or (num_heads, head_size, head_size)"
+             )
+             if len(init_state.shape) == 3:
+                 assert init_state.shape == (H, self.head_size, self.head_size), (
+                     "init_state must have shape (num_heads, head_size, head_size)"
+                 )
+                 init_state = init_state[None, :]
+             else:
+                 assert init_state.shape[1:] == (H, self.head_size, self.head_size), (
+                     "init_state must have shape (state_kinds, num_heads, head_size, head_size)"
+                 )
+             state_kinds = init_state.shape[0]
+             if state_map is None:
+                 if state_kinds == 1:
+                     state_map = ops.zeros(shape=(B,), dtype="int32")
+                 elif state_kinds == B:
+                     state_map = ops.arange(B, dtype="int32")
+                 else:
+                     raise ValueError(
+                         "Cannot infer state_map; please specify state_map manually."
+                     )
+             else:
+                 if isinstance(state_map, list):
+                     state_map = ops.convert_to_tensor(state_map, dtype="int32")
+                 state_map = ops.cast(state_map, "int32")
+                 assert ops.all(state_map >= 0) and ops.all(state_map < state_kinds), (
+                     f"state_map values must lie in [0, {state_kinds})"
+                 )
+             s = ops.take(init_state, state_map, axis=0)
+
+         else:
+             assert state_map is None
+             s = ops.zeros((B, H, self.head_size, self.head_size), dtype=u.dtype)
+
+         # w holds decay logits; the effective per-step decay exp(-exp(w)) lies in (0, 1)
+         w = ops.exp(-ops.exp(w))
+
+         # sequential scan over the T time steps; s is the per-head state matrix
+         def cond(i, k, v, w, r, s, y):
+             return i < T
+
+         def body(i, k, v, w, r, s, y):
+             k_t = ops.take(k, i, 1)
+             v_t = ops.take(v, i, 1)
+             kv_t = k_t @ v_t
+             w_t = ops.take(w, i, 1)
+
+             r_t = ops.take(r, i, 1)
+             y_t = r_t @ (u * kv_t + s)
+             y_t = ops.reshape(y_t, (B, 1, C))
+             s = kv_t + w_t * s
+
+             y = ops.slice_update(y, [0, i, 0], y_t)
+             return i + 1, k, v, w, r, s, y
+
+         y = ops.zeros([B, T, C], r.dtype)
+         i, k, v, w, r, s, y = ops.while_loop(cond, body, (0, k, v, w, r, s, y), T)
+         if with_state:
+             return y, s
+         return y, None
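
For orientation, here is a minimal smoke test of the pure-Keras operator above. The module path is not visible in this diff, so the bare class reference is an assumption; the shapes follow the reshape logic in `__call__`.

```python
# Minimal smoke test for the Keras-backend RWKVKernelOperator above.
# Assumes the class has been imported from wherever rwkv-ops exposes it;
# the import path is not shown in this diff.
import keras
from keras import ops

head_size, B, T, C = 64, 2, 8, 128   # C must be a multiple of head_size
H = C // head_size

operator = RWKVKernelOperator(head_size=head_size, max_sequence_length=T)

r = keras.random.normal((B, T, C))
k = keras.random.normal((B, T, C))
v = keras.random.normal((B, T, C))
w = keras.random.normal((B, T, C))   # decay logits; the kernel applies exp(-exp(w))
u = keras.random.normal((C,))        # per-head bonus, reshaped internally to (1, H, head_size, 1)

y, state = operator(r, k, v, w, u, with_state=True)
print(ops.shape(y))      # (2, 8, 128)
print(ops.shape(state))  # (2, 2, 64, 64)
```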
@@ -0,0 +1,292 @@
+ import os
+ import torch
+ from torch.utils.cpp_extension import load
+ from keras import ops
+
+ kernel_dir_name = "torch_kernel"
+
+ use_rocm = "RWKV_USE_ROCM" in os.environ and os.environ["RWKV_USE_ROCM"] == "1"
+
+
+ class RWKVKernelOperator:
+     def __init__(self, head_size, max_sequence_length):
+         current_dir = os.path.dirname(__file__)
+         if use_rocm:
+             wkv6_cuda = load(
+                 name="wkv6",
+                 sources=[
+                     os.path.join(current_dir, f"{kernel_dir_name}/wkv6_op.cpp"),
+                     os.path.join(current_dir, f"{kernel_dir_name}/wkv6_cuda.cu"),
+                 ],
+                 verbose=True,
+                 extra_cuda_cflags=[
+                     "-fopenmp -ffast-math -munsafe-fp-atomics --gpu-max-threads-per-block=120 -enable-vectorize-compares",
+                     f"-D_N_={head_size}",
+                     f"-D_T_={max_sequence_length}",
+                 ],
+             )
+         else:
+             wkv6_cuda = load(
+                 name="wkv6",
+                 sources=[
+                     os.path.join(current_dir, f"{kernel_dir_name}/wkv6_op.cpp"),
+                     os.path.join(current_dir, f"{kernel_dir_name}/wkv6_cuda.cu"),
+                 ],
+                 verbose=True,
+                 extra_cuda_cflags=[
+                     "-res-usage",
+                     "--use_fast_math",
+                     "-O3",
+                     "-Xptxas -O3",
+                     "--extra-device-vectorization",
+                     f"-D_N_={head_size}",
+                     f"-D_T_={max_sequence_length}",
+                 ],
+             )
+
+         class RWKV_6(torch.autograd.Function):
+             @staticmethod
+             def forward(ctx, B, T, C, H, r, k, v, w, u):
+                 if not isinstance(u, torch.Tensor):
+                     u = u.value  # unwrap a Keras variable to its torch tensor
+                 with torch.no_grad():
+                     assert r.dtype == k.dtype == v.dtype == w.dtype == u.dtype
+                     assert r.dtype in [torch.float32, torch.bfloat16, torch.float16]
+
+                     assert head_size == C // H
+                     ctx.B = B
+                     ctx.T = T
+                     ctx.C = C
+                     ctx.H = H
+                     assert r.is_contiguous()
+                     assert k.is_contiguous()
+                     assert v.is_contiguous()
+                     assert w.is_contiguous()
+                     assert u.is_contiguous()
+                     ctx.save_for_backward(r, k, v, w, u)
+
+                     # fp16 inputs accumulate into an fp32 output buffer
+                     y_dtype = r.dtype if r.dtype != torch.float16 else torch.float32
+
+                     y = torch.empty(
+                         (B, T, C),
+                         device=r.device,
+                         dtype=y_dtype,
+                         memory_format=torch.contiguous_format,
+                     )
+
+                     if r.dtype == torch.float32:
+                         wkv6_cuda.forward_fp32(B, T, C, H, r, k, v, w, u, y)
+                     elif r.dtype == torch.bfloat16:
+                         wkv6_cuda.forward_bf16(B, T, C, H, r, k, v, w, u, y)
+                     else:
+                         wkv6_cuda.forward_fp16(B, T, C, H, r, k, v, w, u, y)
+                     return y
+
+             @staticmethod
+             def backward(ctx, gy):
+                 assert gy.is_cuda
+                 with torch.no_grad():
+                     assert gy.dtype in [torch.bfloat16, torch.float32]
+                     B = ctx.B
+                     T = ctx.T
+                     C = ctx.C
+                     H = ctx.H
+                     assert gy.is_contiguous()
+                     r, k, v, w, u = ctx.saved_tensors
+                     y_dtype = r.dtype if r.dtype != torch.float16 else torch.float32
+
+                     gr = torch.empty(
+                         (B, T, C),
+                         device=gy.device,
+                         requires_grad=False,
+                         dtype=y_dtype,
+                         memory_format=torch.contiguous_format,
+                     )
+                     gk = torch.empty(
+                         (B, T, C),
+                         device=gy.device,
+                         requires_grad=False,
+                         dtype=y_dtype,
+                         memory_format=torch.contiguous_format,
+                     )
+                     gv = torch.empty(
+                         (B, T, C),
+                         device=gy.device,
+                         requires_grad=False,
+                         dtype=y_dtype,
+                         memory_format=torch.contiguous_format,
+                     )
+                     gw = torch.empty(
+                         (B, T, C),
+                         device=gy.device,
+                         requires_grad=False,
+                         dtype=y_dtype,
+                         memory_format=torch.contiguous_format,
+                     )
+                     gu = torch.empty(
+                         (B, C),
+                         device=gy.device,
+                         requires_grad=False,
+                         dtype=y_dtype,
+                         memory_format=torch.contiguous_format,
+                     )
+
+                     if r.dtype == torch.float32:
+                         wkv6_cuda.backward_fp32(
+                             B, T, C, H, r, k, v, w, u, gy, gr, gk, gv, gw, gu
+                         )
+                     elif r.dtype == torch.bfloat16:
+                         wkv6_cuda.backward_bf16(
+                             B, T, C, H, r, k, v, w, u, gy, gr, gk, gv, gw, gu
+                         )
+                     else:
+                         wkv6_cuda.backward_fp16(
+                             B, T, C, H, r, k, v, w, u, gy, gr, gk, gv, gw, gu
+                         )
+
+                     gu = torch.sum(gu, 0).view(H, C // H)
+
+                     return (None, None, None, None, gr, gk, gv, gw, gu)
+
+         class RWKV_6_with_state:
+             @staticmethod
+             def apply(B, T, C, H, S, s_map, r, k, v, w, u, s):
+                 with torch.no_grad():
+                     assert (s is None and s_map is None) or (
+                         s is not None and s_map is not None
+                     ), "init_state and s_map must both be None or both be provided"
+                     assert s_map is None or s_map.dtype == torch.int64, (
+                         "s_map must be None or an int64 array of length B"
+                     )
+                     assert (
+                         r.dtype == k.dtype == v.dtype == w.dtype == u.dtype
+                         and r.dtype in [torch.float16, torch.float32, torch.bfloat16]
+                     ), "r, k, v, w, u must all share one dtype: fp16, fp32 or bf16"
+                     if r.dtype in [torch.float32, torch.bfloat16]:
+                         o_dtype = r.dtype
+                     else:
+                         o_dtype = torch.float32
+                     assert (
+                         r.device
+                         == k.device
+                         == v.device
+                         == w.device
+                         == u.device
+                         == s.device
+                         == s_map.device
+                     ), "r, k, v, w, u, s and s_map must all be on the same device"
+
+                     y = torch.empty(
+                         (B, T, C),
+                         device=r.device,
+                         dtype=o_dtype,
+                         memory_format=torch.contiguous_format,
+                     )
+                     ys = torch.empty(
+                         (B, H, head_size, head_size),
+                         device=r.device,
+                         dtype=o_dtype,
+                         memory_format=torch.contiguous_format,
+                     )
+                     if r.dtype == torch.bfloat16:
+                         wkv6_cuda.forward_with_state_bf16(
+                             B, T, C, H, S, s_map, r, k, v, w, u, s, y, ys
+                         )
+                     elif r.dtype == torch.float32:
+                         wkv6_cuda.forward_with_state_fp32(
+                             B, T, C, H, S, s_map, r, k, v, w, u, s, y, ys
+                         )
+                     else:
+                         wkv6_cuda.forward_with_state_fp16(
+                             B, T, C, H, S, s_map, r, k, v, w, u, s, y, ys
+                         )
+
+                     return y, ys
+
+         self.head_size = head_size
+         self.normal_kernel = RWKV_6
+         self.kernel_with_state = RWKV_6_with_state
+
+     def __call__(
+         self, r, k, v, w, u, with_state=False, init_state=None, state_map=None
+     ):
+         B, T, C = r.shape
+         assert C % self.head_size == 0
+         H = C // self.head_size
+         if not isinstance(u, torch.Tensor):
+             u = u.value  # unwrap a Keras variable to its torch tensor
+
+         assert r.is_cuda
+         assert k.is_cuda
+         assert v.is_cuda
+         assert w.is_cuda
+         assert u.is_cuda
+
+         assert r.device == k.device == v.device == w.device == u.device
+
+         assert r.dtype == k.dtype == v.dtype == w.dtype == u.dtype
+
+         if r.dtype in [torch.float32, torch.bfloat16]:
+             s_dtype = r.dtype
+         else:
+             s_dtype = torch.float32
+
+         is_custom_init = init_state is not None
+
+         if init_state is not None:
+             assert len(init_state.shape) in [3, 4], (
+                 "init_state must have shape (state_kinds /*<= batch_size*/, num_heads, head_size, head_size) or (num_heads, head_size, head_size)"
+             )
+             if len(init_state.shape) == 3:
+                 init_state = init_state[None, :]
+             assert (
+                 init_state.shape[1:] == (H, self.head_size, self.head_size)
+                 and init_state.shape[0] <= B
+             ), (
+                 "init_state must have shape (state_kinds /*<= batch_size*/, num_heads, head_size, head_size) or (num_heads, head_size, head_size)"
+             )
+
+             assert init_state.dtype == s_dtype, f"init_state must have dtype {s_dtype}"
+             assert init_state.device == r.device
+
+         if state_map is not None:
+             if isinstance(state_map, list):
+                 state_map = torch.tensor(state_map, dtype=torch.int64, device=r.device)
+             elif isinstance(state_map, torch.Tensor):
+                 assert state_map.dtype in [torch.int32, torch.int64], (
+                     "state_map must be an integer mapping array of length batch_size"
+                 )
+                 state_map = state_map.to(torch.int64)
+             assert state_map.shape == (B,), "state_map must have shape (batch_size,)"
+             assert state_map.device == r.device
+
+         if with_state:
+             if init_state is None:
+                 assert state_map is None, (
+                     "state_map may only be given together with init_state"
+                 )
+                 init_state = torch.zeros((0,), device=r.device, dtype=s_dtype)
+                 state_map = torch.zeros((0,), device=r.device, dtype=torch.int64)
+             else:
+                 n_state = init_state.shape[0]
+                 if state_map is None:
+                     assert n_state == 1 or n_state == B, (
+                         "Cannot infer state_map; please specify it manually."
+                     )
+                     if n_state == 1:
+                         state_map = torch.zeros(B, dtype=torch.int64, device=r.device)
+                     else:
+                         state_map = torch.arange(B, dtype=torch.int64, device=r.device)
+                 else:
+                     assert state_map.shape == (B,), "state_map must have shape (batch_size,)"
+                     assert (state_map >= 0).all() and (state_map < n_state).all(), (
+                         f"state_map values must be integers in [0, {n_state})"
+                     )
+             o, ys = self.kernel_with_state.apply(
+                 B, T, C, H, is_custom_init, state_map, r, k, v, w, u, init_state
+             )
+             return o, ys
+         else:
+             o = self.normal_kernel.apply(B, T, C, H, r, k, v, w, u)
+             return o, None
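
A hedged sketch of driving this CUDA-backed operator with explicit states follows. It assumes a CUDA device and that `torch.utils.cpp_extension.load` can compile the bundled wkv6 kernels on first use; the `state_map` semantics mirror the asserts in `__call__` above.

```python
# Sketch of the CUDA-backed RWKVKernelOperator with explicit states.
# Requires a CUDA device; the wkv6 kernels are JIT-compiled in __init__.
import torch

head_size, B, T, C = 64, 4, 32, 128
H = C // head_size
dev, dt = torch.device("cuda"), torch.bfloat16

operator = RWKVKernelOperator(head_size=head_size, max_sequence_length=T)

def rand(*shape):
    return torch.randn(*shape, device=dev, dtype=dt).contiguous()

r, k, v, w = rand(B, T, C), rand(B, T, C), rand(B, T, C), rand(B, T, C)
u = rand(H, head_size)

# Two state kinds shared across the batch: rows 0-1 start from state 0,
# rows 2-3 from state 1 (state_kinds <= batch_size, routed via state_map).
init_state = torch.zeros(2, H, head_size, head_size, device=dev, dtype=dt)
state_map = torch.tensor([0, 0, 1, 1], device=dev)

y, final_state = operator(
    r, k, v, w, u, with_state=True, init_state=init_state, state_map=state_map
)
print(y.shape, final_state.shape)  # (4, 32, 128), (4, 2, 64, 64)
```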
@@ -25,6 +25,7 @@ def get_generalized_delta_rule(HEAD_SIZE=64, KERNEL_TYPE="native"):
      CHUNK_LEN = 16
      USE_KERNEL = True
      from torch.utils.cpp_extension import load
+     import os
 
      flags = [
          "-res-usage",
@@ -40,16 +41,11 @@ def get_generalized_delta_rule(HEAD_SIZE=64, KERNEL_TYPE="native"):
 
      # get the directory containing the current file
      current_dir_path = os.path.dirname(current_file_path)
-
-     # get the parent directory path
-     parent_dir_path = os.path.abspath(
-         os.path.join(current_dir_path, os.path.pardir)
-     )
      load(
          name="wind_backstepping",
          sources=[
-             os.path.join(parent_dir_path, "cuda_kernel/wkv7_cuda.cu"),
-             os.path.join(parent_dir_path, "cuda_kernel/wkv7_op.cpp"),
+             os.path.join(current_dir_path, "cuda_kernel/wkv7_cuda.cu"),
+             os.path.join(current_dir_path, "cuda_kernel/wkv7_op.cpp"),
          ],
          is_python_module=False,
          verbose=True,
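
The behavioral change in this hunk is only where the wkv7 kernel sources are resolved. A sketch of the old and new resolution, assuming a hypothetical install location, makes the fix concrete:

```python
# Before/after of the path fix above, assuming (hypothetically) that the
# module lives at .../site-packages/rwkv_ops/wkv7.py.
import os

current_file_path = os.path.abspath(__file__)
current_dir_path = os.path.dirname(current_file_path)

# 0.1.1 resolved one directory up, so the sources were expected beside the parent:
parent_dir_path = os.path.abspath(os.path.join(current_dir_path, os.path.pardir))
old_cu = os.path.join(parent_dir_path, "cuda_kernel/wkv7_cuda.cu")

# 0.2 resolves them next to the module itself, presumably where the wheel
# actually ships the cuda_kernel/ directory:
new_cu = os.path.join(current_dir_path, "cuda_kernel/wkv7_cuda.cu")
```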
@@ -1,3 +1,10 @@
+ # -*- coding: utf-8 -*-
+ """
+ This file implements the forward and backward pass of a chunked delta-rule attention mechanism,
+ optimized with Triton kernels for GPU acceleration. It includes functions for forward propagation,
+ backward gradient computation, and integration with PyTorch's autograd system.
+ """
+
  import warnings
  from typing import Optional
 
@@ -43,6 +50,25 @@ def chunk_dplr_fwd(
      output_final_state: bool = True,
      chunk_size: int = 16,
  ):
+     """
+     Forward pass of chunked delta-rule attention.
+
+     Args:
+         q (torch.Tensor): Queries tensor [B, T, H, K]
+         k (torch.Tensor): Keys tensor [B, T, H, K]
+         v (torch.Tensor): Values tensor [B, T, H, V]
+         a (torch.Tensor): Activations tensor [B, T, H, K]
+         b (torch.Tensor): Betas tensor [B, T, H, K]
+         gk (torch.Tensor): Log decay tensor [B, T, H, K]
+         scale (float): Scale factor for attention scores
+         initial_state (Optional[torch.Tensor]): Initial state for recurrent processing
+         output_final_state (bool): Whether to return the final state
+         chunk_size (int): Chunk size for processing
+
+     Returns:
+         o (torch.Tensor): Output tensor [B, T, H, V]
+         final_state (Optional[torch.Tensor]): Final state if requested
+     """
      T = q.shape[1]
      BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
      gi, ge = chunk_rwkv6_fwd_cumsum(gk, BT)
@@ -100,6 +126,31 @@ def chunk_dplr_bwd(
      dht,
      BT: int = 16,
  ):
+     """
+     Backward pass of chunked delta-rule attention.
+
+     Args:
+         q (torch.Tensor): Queries tensor [B, T, H, K]
+         k (torch.Tensor): Keys tensor [B, T, H, K]
+         v (torch.Tensor): Values tensor [B, T, H, V]
+         a (torch.Tensor): Activations tensor [B, T, H, K]
+         b (torch.Tensor): Betas tensor [B, T, H, K]
+         gk (torch.Tensor): Log decay tensor [B, T, H, K]
+         initial_state (torch.Tensor): Initial state for recurrent processing
+         scale (float): Scale factor for attention scores
+         do (torch.Tensor): Gradient of outputs
+         dht (torch.Tensor): Gradient of the final hidden state
+         BT (int): Chunk size for processing
+
+     Returns:
+         dq (torch.Tensor): Gradient of queries
+         dk (torch.Tensor): Gradient of keys
+         dv (torch.Tensor): Gradient of values
+         da (torch.Tensor): Gradient of activations
+         db (torch.Tensor): Gradient of betas
+         dgk (torch.Tensor): Gradient of log decays
+         dh0 (torch.Tensor): Gradient of the initial state
+     """
      # recompute the forward intermediates here; keeping them alive from the
      # forward pass would exhaust GPU memory
      gi, ge = chunk_rwkv6_fwd_cumsum(gk, BT)
      A_ab, A_qk, A_ak, A_qb, qg, kg, ag, bg = chunk_dplr_fwd_intra(
@@ -279,6 +330,8 @@ def chunk_dplr_delta_rule(
      cu_seqlens: Optional[torch.LongTensor] = None,
  ):
      r"""
+     Main interface function for chunked delta-rule attention.
+
      Args:
          q (torch.Tensor):
              queries of shape `[B, T, H, K]`
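
The docstring above declares `[B, T, H, K]` layouts, so a hedged call sketch is straightforward. The tensor contents here are placeholders; in a real RWKV-7 model the `a`, `b`, and `gk` inputs come from the model's projections, not random data.

```python
# Hedged call sketch for chunk_dplr_delta_rule, following the shapes its
# docstring declares. Dims and contents are illustrative only.
import torch

B, T, H, K, V = 2, 64, 4, 64, 64
dev, dt = torch.device("cuda"), torch.bfloat16

q = torch.randn(B, T, H, K, device=dev, dtype=dt)
k = torch.randn(B, T, H, K, device=dev, dtype=dt)
v = torch.randn(B, T, H, V, device=dev, dtype=dt)
a = torch.randn(B, T, H, K, device=dev, dtype=dt)
b = torch.randn(B, T, H, K, device=dev, dtype=dt)
gk = -torch.rand(B, T, H, K, device=dev, dtype=dt)  # log decays, <= 0

o, final_state = chunk_dplr_delta_rule(
    q=q, k=k, v=v, a=a, b=b, gk=gk,
    scale=K ** -0.5,
    initial_state=None,
    output_final_state=True,
)
print(o.shape)  # (2, 64, 4, 64)
```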
@@ -361,35 +414,7 @@ def chunk_rwkv7(
      output_final_state: bool = True,
  ):
      """
-     Args:
-         r (torch.Tensor):
-             r of shape `[B, H, T, K]` .
-         k (torch.Tensor):
-             k of shape `[B, H, T, K]` .
-         v (torch.Tensor):
-             v of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
-         a (torch.Tensor):
-             a of shape `[B, H, T, K]` .
-         b (torch.Tensor):
-             b of shape `[B, H, T, K]` .
-         w (torch.Tensor):
-             decay of shape `[B, H, T, K]` , kernel
-             will apply log_w = -torch.exp(w)
-         log_w (torch.Tensor):
-             log decay of shape `[B, H, T, K]` .
-         scale (float):
-             scale of the attention.
-         initial_state (Optional[torch.Tensor]):
-             Initial state of shape `[N, H, K, V]` for `N` input sequences.
-             For equal-length input sequences, `N` equals the batch size `B`.
-             Default: `None`.
-         output_final_state (Optional[bool]):
-             Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
-         cu_seqlens (torch.LongTensor):
-             Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
-             consistent with the FlashAttention API.
-         head_first (bool):
-             whether to use head first. Recommended to be False to avoid extra transposes.
+     Interface function for RWKV-7 attention.
      """
 
      if w is not None:
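
Since 0.2 strips the detailed docstring, a hedged usage sketch is worth recording here. The keyword names follow the 0.1.1 docstring that this hunk removes, so treat them as assumptions; per the surviving `if w is not None:` path, passing `w` makes the wrapper derive `log_w = -torch.exp(w)` before delegating.

```python
# Hedged sketch of chunk_rwkv7; parameter names are taken from the removed
# 0.1.1 docstring and may have drifted. Supply exactly one of w / log_w.
import torch

B, T, H, K, V = 2, 64, 4, 64, 64
dev, dt = torch.device("cuda"), torch.bfloat16

r, k, a, b, w = (torch.randn(B, T, H, K, device=dev, dtype=dt) for _ in range(5))
v = torch.randn(B, T, H, V, device=dev, dtype=dt)

o, final_state = chunk_rwkv7(
    r=r, k=k, v=v, a=a, b=b,
    w=w,                      # the kernel applies log_w = -torch.exp(w)
    initial_state=None,
    output_final_state=True,
)
```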