rwkv-ops 0.1.1__py3-none-any.whl → 0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rwkv-ops might be problematic. Click here for more details.
- rwkv_ops/__init__.py +3 -1
- rwkv_ops/rwkv6_kernel/__init__.py +126 -0
- rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py +724 -0
- rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py +86 -0
- rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py +305 -0
- rwkv_ops/rwkv7_kernel/__init__.py +3 -7
- rwkv_ops/rwkv7_kernel/torch_op.py +67 -29
- rwkv_ops-0.2.dist-info/METADATA +258 -0
- {rwkv_ops-0.1.1.dist-info → rwkv_ops-0.2.dist-info}/RECORD +12 -8
- rwkv_ops-0.1.1.dist-info/METADATA +0 -119
- {rwkv_ops-0.1.1.dist-info → rwkv_ops-0.2.dist-info}/LICENSE.txt +0 -0
- {rwkv_ops-0.1.1.dist-info → rwkv_ops-0.2.dist-info}/WHEEL +0 -0
- {rwkv_ops-0.1.1.dist-info → rwkv_ops-0.2.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
from keras import ops
|
|
2
|
+
import keras
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class RWKVKernelOperator:
    """RWKV-6 WKV recurrence implemented with backend-agnostic ``keras.ops``.

    The operator walks the time axis with ``ops.while_loop`` and, at every
    step ``t``, computes::

        y_t = r_t @ (u * (k_t @ v_t) + s)
        s   = (k_t @ v_t) + exp(-exp(w_t)) * s

    where ``s`` is a per-head ``head_size x head_size`` state matrix.
    """

    def __init__(self, head_size, max_sequence_length):
        """Store the configuration.

        Args:
            head_size: Size of one attention head; the channel dimension of
                the inputs must be a multiple of it.
            max_sequence_length: Stored on the instance; not read by
                ``__call__`` in this implementation.
        """
        self.head_size = head_size
        self.max_sequence_length = max_sequence_length

    def __call__(
        self, r, k, v, w, u, with_state=False, init_state=None, state_map=None
    ):
        """Run the recurrence over a batch of sequences.

        Args:
            r, k, v, w: Tensors reshaped from ``(B, T, C)`` (the shape of
                ``r``) into per-head layouts.
            u: Tensor with ``H * head_size`` elements (reshaped to
                ``[1, H, head_size, 1]``).
            with_state: When true, the final state is returned as the second
                tuple element; otherwise ``None`` is returned in its place.
            init_state: Optional initial state of shape
                ``(H, head_size, head_size)`` or
                ``(state_kinds, H, head_size, head_size)``.
            state_map: Optional list/tensor of indices, one per batch entry,
                selecting which row of ``init_state`` each sequence starts
                from. Inferred when ``state_kinds`` is 1 or ``B``.

        Returns:
            ``(y, s)`` when ``with_state`` is true, else ``(y, None)``.
            ``y`` has shape ``[B, T, C]``; ``s`` has shape
            ``(B, H, head_size, head_size)``.

        Raises:
            AssertionError: On shape mismatches or an out-of-range
                ``state_map``.
            ValueError: When ``state_map`` is omitted and cannot be inferred.
        """
        B, T, C = ops.shape(r)
        assert C % self.head_size == 0
        H = C // self.head_size
        head_size = self.head_size

        # k and w are laid out as column vectors, v and r as row vectors, so
        # that k_t @ v_t is an outer product and r_t @ state is a row vector.
        w = ops.reshape(w, [B, T, H, head_size, 1])
        k = ops.reshape(k, [B, T, H, head_size, 1])
        v = ops.reshape(v, [B, T, H, 1, head_size])
        r = ops.reshape(r, [B, T, H, 1, head_size])
        u = ops.reshape(u, [1, H, head_size, 1])

        if init_state is not None:
            assert len(init_state.shape) in [3, 4], (
                "init_state的形状必须为(state_kinds,num_heads,head_size,head_size)"
            )
            if len(init_state.shape) == 3:
                assert init_state.shape == (H, head_size, head_size), (
                    "state_kinds的形状必须为(BatchSize,num_heads,head_size,head_size)"
                )
                init_state = init_state[None, :]
            else:
                assert init_state.shape[1:] == (H, head_size, head_size), (
                    "state_kinds的形状必须为(BatchSize,num_heads,head_size,head_size)"
                )
            # Computed once after normalisation so it is defined on every
            # path, including a 3-D init_state with an explicit state_map.
            state_kinds = init_state.shape[0]

            if state_map is None:
                if state_kinds == 1:
                    # One shared state: every sequence maps to row 0.
                    state_map = ops.zeros(shape=(B,), dtype="int32")
                elif state_kinds == B:
                    # One state per sequence: identity mapping.
                    state_map = ops.arange(B, dtype="int32")
                else:
                    raise ValueError(
                        "无法为您推断state_map的形状,请您手动指定state_map"
                    )
            else:
                if isinstance(state_map, list):
                    state_map = ops.convert_to_tensor(state_map, dtype="int32")
                state_map = ops.cast(state_map, "int32")
                # ops.all keeps the check independent of the backend tensor
                # type (not every backend tensor exposes an .all() method).
                assert ops.all(state_map >= 0) and ops.all(
                    state_map < state_kinds
                ), f"请确保state_map的值域为[0, {state_kinds})"
            s = ops.take(init_state, state_map, axis=0)
        else:
            assert state_map is None
            s = ops.zeros((B, H, head_size, head_size), dtype=u.dtype)

        # Decay factor applied to the state at every step.
        w = ops.exp(-ops.exp(w))

        def cond(i, k, v, w, r, s, y):
            return i < T

        def body(i, k, v, w, r, s, y):
            k_t = ops.take(k, i, 1)
            v_t = ops.take(v, i, 1)
            kv_t = k_t @ v_t
            w_t = ops.take(w, i, 1)
            r_t = ops.take(r, i, 1)

            # The output uses the state from *before* this step's update.
            y_t = r_t @ (u * kv_t + s)
            y_t = ops.reshape(y_t, (B, 1, C))
            s = kv_t + w_t * s

            y = ops.slice_update(y, [0, i, 0], y_t)
            return i + 1, k, v, w, r, s, y

        y = ops.zeros([B, T, C], r.dtype)
        # The trailing T is while_loop's maximum_iterations bound.
        i, k, v, w, r, s, y = ops.while_loop(cond, body, (0, k, v, w, r, s, y), T)
        if with_state:
            return y, s
        return y, None
|
|
@@ -0,0 +1,305 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import torch
|
|
3
|
+
from torch.utils.cpp_extension import load
|
|
4
|
+
from keras import ops
|
|
5
|
+
|
|
6
|
+
# Sub-directory (relative to this module) that holds the kernel sources.
kernel_dir_name = "torch_kernel"

# True only when the environment variable RWKV_USE_ROCM is exactly "1".
use_rocm = os.environ.get("RWKV_USE_ROCM") == "1"
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class RWKVKernelOperator:
    """RWKV-6 WKV operator backed by a JIT-compiled torch C++/CUDA extension.

    Constructing the operator compiles (or loads from torch's extension
    cache) the ``wkv6`` extension with ``head_size`` and
    ``max_sequence_length`` baked in as the ``_N_`` / ``_T_`` macros.
    """

    def __init__(self, head_size, max_sequence_length):
        """Build the extension and the two kernel wrappers.

        Args:
            head_size: Size of one attention head (passed as ``-D_N_``).
            max_sequence_length: Passed to the compiler as ``-D_T_``.
        """
        current_dir = os.path.dirname(__file__)
        sources = [
            os.path.join(current_dir, f"{kernel_dir_name}/wkv6_op.cpp"),
            os.path.join(current_dir, f"{kernel_dir_name}/wkv6_cuda.cu"),
        ]
        # Only the compiler flags differ between the ROCm and CUDA builds.
        if use_rocm:
            backend_flags = [
                "-fopenmp -ffast-math -munsafe-fp-atomics --gpu-max-threads-per-block=120 -enable-vectorize-compares",
            ]
        else:
            backend_flags = [
                "-res-usage",
                "--use_fast_math",
                "-O3",
                "-Xptxas -O3",
                "--extra-device-vectorization",
            ]
        wkv6_cuda = load(
            name="wkv6",
            sources=sources,
            verbose=True,
            extra_cuda_cflags=backend_flags
            + [f"-D_N_={head_size}", f"-D_T_={max_sequence_length}"],
        )

        class RWKV_6(torch.autograd.Function):
            """Differentiable stateless kernel (no initial/final state)."""

            @staticmethod
            def forward(ctx, B, T, C, H, r, k, v, w, u):
                if not isinstance(u, torch.Tensor):
                    # Non-tensor u is expected to expose its tensor as .value.
                    u = u.value
                with torch.no_grad():
                    assert r.dtype == k.dtype == v.dtype == w.dtype == u.dtype
                    assert r.dtype in [torch.float32, torch.bfloat16, torch.float16]

                    assert head_size == C // H
                    ctx.B = B
                    ctx.T = T
                    ctx.C = C
                    ctx.H = H
                    assert r.is_contiguous()
                    assert k.is_contiguous()
                    assert v.is_contiguous()
                    assert w.is_contiguous()
                    assert u.is_contiguous()
                    ctx.save_for_backward(r, k, v, w, u)

                    # fp16 inputs produce fp32 outputs; other dtypes are kept.
                    y_dtype = r.dtype if r.dtype != torch.float16 else torch.float32

                    y = torch.empty(
                        (B, T, C),
                        device=r.device,
                        dtype=y_dtype,
                        memory_format=torch.contiguous_format,
                    )

                    if r.dtype == torch.float32:
                        wkv6_cuda.forward_fp32(B, T, C, H, r, k, v, w, u, y)
                    elif r.dtype == torch.bfloat16:
                        wkv6_cuda.forward_bf16(B, T, C, H, r, k, v, w, u, y)
                    else:
                        wkv6_cuda.forward_fp16(B, T, C, H, r, k, v, w, u, y)
                    return y

            @staticmethod
            def backward(ctx, gy):
                assert gy.is_cuda
                with torch.no_grad():
                    assert gy.dtype in [torch.bfloat16, torch.float32]
                    B = ctx.B
                    T = ctx.T
                    C = ctx.C
                    H = ctx.H
                    assert gy.is_contiguous()
                    r, k, v, w, u = ctx.saved_tensors
                    y_dtype = r.dtype if r.dtype != torch.float16 else torch.float32

                    def new_grad(*shape):
                        # Uninitialised buffer that the kernel fills in.
                        return torch.empty(
                            shape,
                            device=gy.device,
                            requires_grad=False,
                            dtype=y_dtype,
                            memory_format=torch.contiguous_format,
                        )

                    gr = new_grad(B, T, C)
                    gk = new_grad(B, T, C)
                    gv = new_grad(B, T, C)
                    gw = new_grad(B, T, C)
                    # Per-batch gradient of u; reduced over the batch below.
                    gu = new_grad(B, C)

                    if r.dtype == torch.float32:
                        wkv6_cuda.backward_fp32(
                            B, T, C, H, r, k, v, w, u, gy, gr, gk, gv, gw, gu
                        )
                    elif r.dtype == torch.bfloat16:
                        wkv6_cuda.backward_bf16(
                            B, T, C, H, r, k, v, w, u, gy, gr, gk, gv, gw, gu
                        )
                    else:
                        wkv6_cuda.backward_fp16(
                            B, T, C, H, r, k, v, w, u, gy, gr, gk, gv, gw, gu
                        )

                    gu = torch.sum(gu, 0).view(H, C // H)

                    # No gradients for the four integer size arguments.
                    return (None, None, None, None, gr, gk, gv, gw, gu)

        class RWKV_6_with_state:
            """Forward-only kernel that accepts and returns a state."""

            @staticmethod
            def apply(B, T, C, H, S, s_map, r, k, v, w, u, s):
                with torch.no_grad():
                    # Check None-consistency first so a mismatch reports the
                    # intended message instead of an AttributeError.
                    assert (s is None and s_map is None) or (
                        s is not None and s_map is not None
                    ), "init_state与s_map必须同时为None 或者同时不为None"
                    assert s_map.dtype == torch.int64, (
                        "s_map 必须为None 或者是长度为B的,int64类型的数组。"
                    )
                    assert (
                        r.dtype == k.dtype == v.dtype == w.dtype == u.dtype
                        and r.dtype in [torch.float16, torch.float32, torch.bfloat16]
                    ), " r, k, v, w, u 必须为fp16 fp32 bf16中的一种 并且类型相同"
                    if r.dtype in [torch.float32, torch.bfloat16]:
                        o_dtype = r.dtype
                    else:
                        o_dtype = torch.float32
                    assert (
                        r.device
                        == k.device
                        == v.device
                        == w.device
                        == u.device
                        == s.device
                        == s_map.device
                    ), "what kan i say? 请确保r k v w u s s_map在同一设备上,快去检查!"

                    y = torch.empty(
                        (B, T, C),
                        device=r.device,
                        dtype=o_dtype,
                        memory_format=torch.contiguous_format,
                    )
                    ys = torch.empty(
                        (B, H, head_size, head_size),
                        device=r.device,
                        dtype=o_dtype,
                        memory_format=torch.contiguous_format,
                    )
                    if r.dtype == torch.bfloat16:
                        wkv6_cuda.forward_with_state_bf16(
                            B, T, C, H, S, s_map, r, k, v, w, u, s, y, ys
                        )
                    elif r.dtype == torch.float32:
                        wkv6_cuda.forward_with_state_fp32(
                            B, T, C, H, S, s_map, r, k, v, w, u, s, y, ys
                        )
                    else:
                        wkv6_cuda.forward_with_state_fp16(
                            B, T, C, H, S, s_map, r, k, v, w, u, s, y, ys
                        )

                    return y, ys

        self.head_size = head_size
        # Stored for parity with the keras.ops implementation of this class.
        self.max_sequence_length = max_sequence_length
        # "normal_kernenl" is the historical (misspelled) attribute name; it
        # is kept for compatibility and aliased under the correct spelling.
        self.normal_kernenl = RWKV_6
        self.normal_kernel = RWKV_6
        self.kernel_with_state = RWKV_6_with_state

    def __call__(
        self, r, k, v, w, u, with_state=False, init_state=None, state_map=None
    ):
        """Run the WKV kernel.

        Args:
            r, k, v, w: CUDA tensors of shape ``(B, T, C)`` (taken from
                ``r.shape``) sharing one dtype and device.
            u: CUDA tensor, or an object exposing the tensor as ``.value``.
            with_state: When true, use the stateful forward-only kernel and
                return the final state; otherwise use the differentiable
                kernel. ``init_state``/``state_map`` are validated but not
                passed to the kernel when this is false.
            init_state: Optional state of shape ``(H, head_size, head_size)``
                or ``(state_kinds <= B, H, head_size, head_size)``; its dtype
                must be the input dtype, or float32 for float16 inputs.
            state_map: Optional list or int32/int64 tensor of shape ``(B,)``
                mapping each batch entry to a row of ``init_state``.

        Returns:
            ``(o, ys)`` when ``with_state`` is true, else ``(o, None)``.

        Raises:
            AssertionError: On any failed shape/dtype/device check.
        """
        B, T, C = r.shape
        assert C % self.head_size == 0
        H = C // self.head_size
        if not isinstance(u, torch.Tensor):
            u = u.value

        assert r.is_cuda
        assert k.is_cuda
        assert v.is_cuda
        assert w.is_cuda
        assert u.is_cuda

        if isinstance(r, torch.Tensor):
            assert r.device == k.device == v.device == w.device == u.device
        else:
            # The original compared the devices but discarded the result.
            assert (
                r.get_device()
                == k.get_device()
                == v.get_device()
                == w.get_device()
                == u.get_device()
            )

        assert r.dtype == k.dtype == v.dtype == w.dtype == u.dtype

        # States are kept in the input dtype, except float16 inputs which
        # use a float32 state.
        if r.dtype in [torch.float32, torch.bfloat16]:
            s_dtype = r.dtype
        else:
            s_dtype = torch.float32

        is_custom_init = init_state is not None

        if init_state is not None:
            assert len(init_state.shape) in [3, 4], (
                "init_state 的形状必须为(state_kinds /*<= Batch_size*/,num_heads,head_size,head_size) 或者(num_heads,head_size,head_size)"
            )
            if len(init_state.shape) == 3:
                init_state = init_state[None, :]
            assert (
                init_state.shape[1:] == (H, self.head_size, self.head_size)
                and init_state.shape[0] <= B
            ), (
                "init_state 的形状必须为(state_kinds /*<= Batch_size*/,num_heads,head_size,head_size) 或者(num_heads,head_size,head_size)"
            )

            assert init_state.dtype == s_dtype, f"init_state的数值类型应为: {s_dtype}"
            assert init_state.device == r.device

        if state_map is not None:
            if isinstance(state_map, list):
                # Build the map on the inputs' device; a CPU tensor would
                # fail the device check below.
                state_map = torch.tensor(
                    state_map, dtype=torch.int64, device=r.device
                )
            elif isinstance(state_map, torch.Tensor):
                assert state_map.dtype in [torch.int32, torch.int64], (
                    "state_map是一个长度为Batch_Size的int64类型的映射数组"
                )
                state_map = state_map.to(torch.int64)
            assert state_map.shape == (B,), "state_map的shape必须为(Batch_Size,)"
            assert state_map.device == r.device

        if with_state:
            if init_state is None:
                assert state_map is None, (
                    "您必须在指定了init_state的情况下才能使用state_map"
                )
                # Empty placeholders are passed when no initial state is
                # given; the kernel also receives is_custom_init=False.
                init_state = torch.zeros((0,), device=r.device, dtype=s_dtype)
                state_map = torch.zeros((0,), device=r.device, dtype=torch.int64)
            else:
                n_state = init_state.shape[0]
                if state_map is None:
                    assert n_state == 1 or n_state == B, (
                        "我无法为您推断state_map的形状,请手动指定。"
                    )
                    if n_state == 1:
                        # One shared state: every sequence maps to row 0.
                        state_map = torch.zeros(
                            B, dtype=torch.int64, device=r.device
                        )
                    else:
                        # One state per sequence: identity mapping.
                        state_map = torch.arange(
                            B, dtype=torch.int64, device=r.device
                        )
                else:
                    # The shape of state_map was already validated above.
                    assert (state_map >= 0).all() and (state_map < n_state).all(), (
                        f"state_map的取值范围为[0,{n_state})之间的整数,您的输入显然不满足。"
                    )
            o, ys = self.kernel_with_state.apply(
                B, T, C, H, is_custom_init, state_map, r, k, v, w, u, init_state
            )
            return o, ys
        else:
            o = self.normal_kernenl.apply(B, T, C, H, r, k, v, w, u)
            return o, None
|
|
@@ -25,6 +25,7 @@ def get_generalized_delta_rule(HEAD_SIZE=64, KERNEL_TYPE="native"):
|
|
|
25
25
|
CHUNK_LEN = 16
|
|
26
26
|
USE_KERNEL = True
|
|
27
27
|
from torch.utils.cpp_extension import load
|
|
28
|
+
import os
|
|
28
29
|
|
|
29
30
|
flags = [
|
|
30
31
|
"-res-usage",
|
|
@@ -40,16 +41,11 @@ def get_generalized_delta_rule(HEAD_SIZE=64, KERNEL_TYPE="native"):
|
|
|
40
41
|
|
|
41
42
|
# 获取当前文件的目录路径
|
|
42
43
|
current_dir_path = os.path.dirname(current_file_path)
|
|
43
|
-
|
|
44
|
-
# 获取上一级目录的路径
|
|
45
|
-
parent_dir_path = os.path.abspath(
|
|
46
|
-
os.path.join(current_dir_path, os.path.pardir)
|
|
47
|
-
)
|
|
48
44
|
load(
|
|
49
45
|
name="wind_backstepping",
|
|
50
46
|
sources=[
|
|
51
|
-
os.path.join(
|
|
52
|
-
os.path.join(
|
|
47
|
+
os.path.join(current_dir_path, "cuda_kernel/wkv7_cuda.cu"),
|
|
48
|
+
os.path.join(current_dir_path, "cuda_kernel/wkv7_op.cpp"),
|
|
53
49
|
],
|
|
54
50
|
is_python_module=False,
|
|
55
51
|
verbose=True,
|
|
@@ -1,3 +1,15 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
"""
|
|
3
|
+
This file implements the forward and backward pass of a chunked delta rule attention mechanism,
|
|
4
|
+
optimized with Triton kernels for GPU acceleration. It includes functions for forward propagation,
|
|
5
|
+
backward gradient computation, and integration with PyTorch's autograd system.
|
|
6
|
+
|
|
7
|
+
该文件实现了分块 Delta Rule 注意力机制的前向与反向传播,
|
|
8
|
+
使用 Triton 内核进行 GPU 加速优化。包括前向传播、梯度反向传播函数,
|
|
9
|
+
并集成了 PyTorch 的自动求导系统。
|
|
10
|
+
|
|
11
|
+
"""
|
|
12
|
+
|
|
1
13
|
import warnings
|
|
2
14
|
from typing import Optional
|
|
3
15
|
|
|
@@ -43,6 +55,27 @@ def chunk_dplr_fwd(
|
|
|
43
55
|
output_final_state: bool = True,
|
|
44
56
|
chunk_size: int = 16,
|
|
45
57
|
):
|
|
58
|
+
"""
|
|
59
|
+
Forward pass of chunked delta rule attention.
|
|
60
|
+
|
|
61
|
+
分块 Delta Rule 注意力机制的前向传播。
|
|
62
|
+
|
|
63
|
+
Args:
|
|
64
|
+
q (torch.Tensor): Queries tensor [B, T, H, K]
|
|
65
|
+
k (torch.Tensor): Keys tensor [B, T, H, K]
|
|
66
|
+
v (torch.Tensor): Values tensor [B, T, H, V]
|
|
67
|
+
a (torch.Tensor): Activations tensor [B, T, H, K]
|
|
68
|
+
b (torch.Tensor): Betas tensor [B, T, H, K]
|
|
69
|
+
gk (torch.Tensor): Log decay tensor [B, T, H, K]
|
|
70
|
+
scale (float): Scale factor for attention scores
|
|
71
|
+
initial_state (Optional[torch.Tensor]): Initial state for recurrent processing
|
|
72
|
+
output_final_state (bool): Whether to return final state
|
|
73
|
+
chunk_size (int): Chunk size for processing
|
|
74
|
+
|
|
75
|
+
Returns:
|
|
76
|
+
o (torch.Tensor): Output tensor [B, T, H, V]
|
|
77
|
+
final_state (Optional[torch.Tensor]): Final state if requested
|
|
78
|
+
"""
|
|
46
79
|
T = q.shape[1]
|
|
47
80
|
BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
|
|
48
81
|
gi, ge = chunk_rwkv6_fwd_cumsum(gk, BT)
|
|
@@ -100,6 +133,33 @@ def chunk_dplr_bwd(
|
|
|
100
133
|
dht,
|
|
101
134
|
BT: int = 16,
|
|
102
135
|
):
|
|
136
|
+
"""
|
|
137
|
+
Backward pass of chunked delta rule attention.
|
|
138
|
+
|
|
139
|
+
分块 Delta Rule 注意力机制的反向传播。
|
|
140
|
+
|
|
141
|
+
Args:
|
|
142
|
+
q (torch.Tensor): Queries tensor [B, T, H, K]
|
|
143
|
+
k (torch.Tensor): Keys tensor [B, T, H, K]
|
|
144
|
+
v (torch.Tensor): Values tensor [B, T, H, V]
|
|
145
|
+
a (torch.Tensor): Activations tensor [B, T, H, K]
|
|
146
|
+
b (torch.Tensor): Betas tensor [B, T, H, K]
|
|
147
|
+
gk (torch.Tensor): Log decay tensor [B, T, H, K]
|
|
148
|
+
initial_state (torch.Tensor): Initial state for recurrent processing
|
|
149
|
+
scale (float): Scale factor for attention scores
|
|
150
|
+
do (torch.Tensor): Gradient of outputs
|
|
151
|
+
dht (torch.Tensor): Gradient of final hidden state
|
|
152
|
+
BT (int): Chunk size for processing
|
|
153
|
+
|
|
154
|
+
Returns:
|
|
155
|
+
dq (torch.Tensor): Gradient of queries
|
|
156
|
+
dk (torch.Tensor): Gradient of keys
|
|
157
|
+
dv (torch.Tensor): Gradient of values
|
|
158
|
+
da (torch.Tensor): Gradient of activations
|
|
159
|
+
db (torch.Tensor): Gradient of betas
|
|
160
|
+
dgk (torch.Tensor): Gradient of log decays
|
|
161
|
+
dh0 (torch.Tensor): Gradient of initial state
|
|
162
|
+
"""
|
|
103
163
|
# ******* start recomputing everything, otherwise i believe the gpu memory will be exhausted *******
|
|
104
164
|
gi, ge = chunk_rwkv6_fwd_cumsum(gk, BT)
|
|
105
165
|
A_ab, A_qk, A_ak, A_qb, qg, kg, ag, bg = chunk_dplr_fwd_intra(
|
|
@@ -279,6 +339,10 @@ def chunk_dplr_delta_rule(
|
|
|
279
339
|
cu_seqlens: Optional[torch.LongTensor] = None,
|
|
280
340
|
):
|
|
281
341
|
r"""
|
|
342
|
+
Main interface function for chunked delta rule attention.
|
|
343
|
+
|
|
344
|
+
分块 Delta Rule 注意力机制的主要接口函数。
|
|
345
|
+
|
|
282
346
|
Args:
|
|
283
347
|
q (torch.Tensor):
|
|
284
348
|
queries of shape `[B, T, H, K]`
|
|
@@ -361,35 +425,9 @@ def chunk_rwkv7(
|
|
|
361
425
|
output_final_state: bool = True,
|
|
362
426
|
):
|
|
363
427
|
"""
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
k (torch.Tensor):
|
|
368
|
-
k of shape `[B, H, T, K]` .
|
|
369
|
-
v (torch.Tensor):
|
|
370
|
-
v of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
|
|
371
|
-
a (torch.Tensor):
|
|
372
|
-
a of shape `[B, H, T, K]` .
|
|
373
|
-
b (torch.Tensor):
|
|
374
|
-
b of shape `[B, H, T, K]` .
|
|
375
|
-
w (torch.Tensor):
|
|
376
|
-
decay of shape `[B, H, T, K]` , kernel
|
|
377
|
-
will apply log_w = -torch.exp(w)
|
|
378
|
-
log_w (torch.Tensor):
|
|
379
|
-
log decay of shape `[B, H, T, K]` .
|
|
380
|
-
scale (float):
|
|
381
|
-
scale of the attention.
|
|
382
|
-
initial_state (Optional[torch.Tensor]):
|
|
383
|
-
Initial state of shape `[N, H, K, V]` for `N` input sequences.
|
|
384
|
-
For equal-length input sequences, `N` equals the batch size `B`.
|
|
385
|
-
Default: `None`.
|
|
386
|
-
output_final_state (Optional[bool]):
|
|
387
|
-
Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
|
|
388
|
-
cu_seqlens (torch.LongTensor):
|
|
389
|
-
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
|
|
390
|
-
consistent with the FlashAttention API.
|
|
391
|
-
head_first (bool):
|
|
392
|
-
whether to use head first. Recommended to be False to avoid extra transposes.
|
|
428
|
+
Interface function for RWKV-7 attention.
|
|
429
|
+
|
|
430
|
+
RWKV-7 注意力机制的接口函数。
|
|
393
431
|
"""
|
|
394
432
|
|
|
395
433
|
if w is not None:
|