rwkv-ops 0.1.0__py3-none-any.whl → 0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rwkv-ops might be problematic. Click here for more details.
- rwkv_ops/__init__.py +3 -0
- rwkv_ops/rwkv6_kernel/__init__.py +126 -0
- rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py +724 -0
- rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py +86 -0
- rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py +305 -0
- rwkv_ops/rwkv7_kernel/__init__.py +10 -7
- rwkv_ops/rwkv7_kernel/get_jax_devices_info.py +2 -3
- rwkv_ops/rwkv7_kernel/get_torch_devices_info.py +2 -2
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_bwd.py +3 -3
- rwkv_ops/rwkv7_kernel/jax_op.py +0 -2
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_bwd.py +1 -1
- rwkv_ops/rwkv7_kernel/torch_op.py +20 -53
- rwkv_ops-0.2.dist-info/METADATA +258 -0
- {rwkv_ops-0.1.0.dist-info → rwkv_ops-0.2.dist-info}/RECORD +17 -13
- rwkv_ops-0.1.0.dist-info/METADATA +0 -118
- {rwkv_ops-0.1.0.dist-info → rwkv_ops-0.2.dist-info}/LICENSE.txt +0 -0
- {rwkv_ops-0.1.0.dist-info → rwkv_ops-0.2.dist-info}/WHEEL +0 -0
- {rwkv_ops-0.1.0.dist-info → rwkv_ops-0.2.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
from keras import ops
|
|
2
|
+
import keras
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class RWKVKernelOperator:
    """Pure ``keras.ops`` implementation of the RWKV-6 WKV recurrence.

    The sequence is processed one time step at a time inside
    ``ops.while_loop``. The call signature matches the compiled-kernel
    operator so the two can be swapped.
    """

    def __init__(self, head_size, max_sequence_length):
        """Store the operator configuration.

        Args:
            head_size: Channels per head; the channel dimension ``C`` of the
                inputs must be divisible by it.
            max_sequence_length: Stored on the instance; not read by this
                implementation.
        """
        self.head_size = head_size
        self.max_sequence_length = max_sequence_length

    def __call__(
        self, r, k, v, w, u, with_state=False, init_state=None, state_map=None
    ):
        """Run the WKV6 recurrence over the whole sequence.

        Args:
            r, k, v, w: Tensors whose shape unpacks as ``(B, T, C)`` (read
                from ``r``); each is reshaped per head below. ``w`` is the
                raw decay: the per-step decay applied is ``exp(-exp(w))``.
            u: Per-head "bonus" term, reshaped to ``(1, H, head_size, 1)``.
            with_state: If True, also return the final recurrent state.
            init_state: Optional initial state of shape
                ``(H, head_size, head_size)`` or
                ``(state_kinds, H, head_size, head_size)``.
            state_map: Optional length-``B`` integer sequence/tensor; entry
                ``b`` selects which ``init_state`` slot batch row ``b`` starts
                from. Inferred when ``state_kinds`` is 1 or ``B``.

        Returns:
            ``(y, s)`` where ``y`` has shape ``(B, T, C)``. ``s`` is the final
            ``(B, H, head_size, head_size)`` state when ``with_state`` is
            True, otherwise ``None``.

        Raises:
            ValueError: If ``state_map`` is omitted and cannot be inferred.
            AssertionError: On inconsistent shapes or out-of-range
                ``state_map`` values.
        """
        B, T, C = ops.shape(r)
        assert C % self.head_size == 0
        H = C // self.head_size
        N = self.head_size

        # Column vectors for k/w and row vectors for r/v, so that k_t @ v_t
        # is the per-head outer product (N x N) and r_t @ S is a row vector.
        w = ops.reshape(w, [B, T, H, N, 1])
        k = ops.reshape(k, [B, T, H, N, 1])

        v = ops.reshape(v, [B, T, H, 1, N])
        r = ops.reshape(r, [B, T, H, 1, N])
        u = ops.reshape(u, [1, H, N, 1])

        if init_state is not None:
            assert len(init_state.shape) in [3, 4], (
                "init_state的形状必须为(state_kinds,num_heads,head_size,head_size)"
            )
            if len(init_state.shape) == 3:
                assert init_state.shape == (H, N, N), (
                    "state_kinds的形状必须为(BatchSize,num_heads,head_size,head_size)"
                )
                # Promote a single shared state to a one-slot state table.
                init_state = init_state[None, :]
            else:
                assert init_state.shape[1:] == (H, N, N), (
                    "state_kinds的形状必须为(BatchSize,num_heads,head_size,head_size)"
                )
            state_kinds = init_state.shape[0]

            if state_map is None:
                if state_kinds == 1:
                    # Every batch row starts from the single shared state.
                    state_map = ops.zeros(shape=(B,), dtype="int32")
                elif state_kinds == B:
                    # One state per batch row, in order.
                    state_map = ops.arange(B, dtype="int32")
                else:
                    raise ValueError(
                        "无法为您推断state_map的形状,请您手动指定state_map"
                    )
            else:
                if isinstance(state_map, list):
                    state_map = ops.convert_to_tensor(state_map, dtype="int32")
                state_map = ops.cast(state_map, "int32")
                assert tuple(ops.shape(state_map)) == (B,), (
                    "state_map的shape必须为(Batch_Size,)"
                )
                # Use ops.all rather than the tensor ``.all()`` method, which
                # is not available on every keras backend (e.g. TensorFlow).
                # NOTE(review): bool() on these results needs concrete values,
                # so this check only works eagerly (same as before).
                assert bool(ops.all(ops.greater_equal(state_map, 0))) and bool(
                    ops.all(ops.less(state_map, state_kinds))
                ), f"请确保state_map的值域为[0, {state_kinds})"
            # Gather each batch row's starting state: (B, H, N, N).
            s = ops.take(init_state, state_map, axis=0)

        else:
            assert state_map is None
            s = ops.zeros((B, H, N, N), dtype=u.dtype)

        # Turn the raw decay parameter into a multiplicative decay factor.
        w = ops.exp(-ops.exp(w))

        def cond(i, k, v, w, r, s, y):
            return i < T

        def body(i, k, v, w, r, s, y):
            # ops.take with a scalar index drops the time axis.
            k_t = ops.take(k, i, 1)
            v_t = ops.take(v, i, 1)
            kv_t = k_t @ v_t
            w_t = ops.take(w, i, 1)

            r_t = ops.take(r, i, 1)
            # The current token contributes through the bonus ``u``; history
            # contributes through the carried state ``s``.
            y_t = r_t @ (u * kv_t + s)
            y_t = ops.reshape(y_t, (B, 1, C))
            s = kv_t + w_t * s

            y = ops.slice_update(y, [0, i, 0], y_t)
            return i + 1, k, v, w, r, s, y

        y = ops.zeros([B, T, C], r.dtype)
        i, k, v, w, r, s, y = ops.while_loop(cond, body, (0, k, v, w, r, s, y), T)
        if with_state:
            return y, s
        return y, None
|
|
@@ -0,0 +1,305 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import torch
|
|
3
|
+
from torch.utils.cpp_extension import load
|
|
4
|
+
from keras import ops
|
|
5
|
+
|
|
6
|
+
# Sub-directory (next to this module) holding the wkv6 C++/CUDA sources.
kernel_dir_name = "torch_kernel"

# Build with the ROCm compiler flags only when RWKV_USE_ROCM is exactly "1".
use_rocm = os.environ.get("RWKV_USE_ROCM") == "1"
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class RWKVKernelOperator:
    """Wrapper around the compiled RWKV-6 ``wkv6`` CUDA/ROCm extension.

    The constructor JIT-compiles ``wkv6_op.cpp`` and ``wkv6_cuda.cu`` from
    ``kernel_dir_name`` with ``torch.utils.cpp_extension.load``. It bakes
    ``head_size`` and ``max_sequence_length`` in as the ``_N_`` / ``_T_``
    macros.

    Calling the operator either runs the stateless, differentiable kernel
    (``RWKV_6``) or, with ``with_state=True``, a forward-only kernel that also
    returns the final state (``RWKV_6_with_state``).
    """

    def __init__(self, head_size, max_sequence_length):
        """Compile the wkv6 extension and build the kernel wrappers.

        Args:
            head_size: Channels per head; passed to the kernel as ``-D_N_``.
            max_sequence_length: Passed to the kernel as ``-D_T_``.
        """
        current_dir = os.path.dirname(__file__)
        sources = [
            os.path.join(current_dir, f"{kernel_dir_name}/wkv6_op.cpp"),
            os.path.join(current_dir, f"{kernel_dir_name}/wkv6_cuda.cu"),
        ]
        if use_rocm:
            backend_flags = [
                "-fopenmp -ffast-math -munsafe-fp-atomics --gpu-max-threads-per-block=120 -enable-vectorize-compares",
            ]
        else:
            backend_flags = [
                "-res-usage",
                "--use_fast_math",
                "-O3",
                "-Xptxas -O3",
                "--extra-device-vectorization",
            ]
        size_flags = [f"-D_N_={head_size}", f"-D_T_={max_sequence_length}"]
        wkv6_cuda = load(
            name="wkv6",
            sources=sources,
            verbose=True,
            extra_cuda_cflags=backend_flags + size_flags,
        )

        class RWKV_6(torch.autograd.Function):
            """Stateless WKV6 recurrence with a custom CUDA backward pass."""

            @staticmethod
            def forward(ctx, B, T, C, H, r, k, v, w, u):
                if not isinstance(u, torch.Tensor):
                    # Unwrap parameter objects exposing ``.value``
                    # (presumably keras Variables — TODO confirm).
                    u = u.value
                with torch.no_grad():
                    assert r.dtype == k.dtype == v.dtype == w.dtype == u.dtype
                    assert r.dtype in [torch.float32, torch.bfloat16, torch.float16]

                    # The kernel was compiled for exactly this head size.
                    assert head_size == C // H
                    ctx.B = B
                    ctx.T = T
                    ctx.C = C
                    ctx.H = H
                    assert r.is_contiguous()
                    assert k.is_contiguous()
                    assert v.is_contiguous()
                    assert w.is_contiguous()
                    assert u.is_contiguous()
                    ctx.save_for_backward(r, k, v, w, u)

                    # fp16 inputs produce an fp32 output.
                    y_dtype = r.dtype if r.dtype != torch.float16 else torch.float32

                    y = torch.empty(
                        (B, T, C),
                        device=r.device,
                        dtype=y_dtype,
                        memory_format=torch.contiguous_format,
                    )

                    if r.dtype == torch.float32:
                        wkv6_cuda.forward_fp32(B, T, C, H, r, k, v, w, u, y)
                    elif r.dtype == torch.bfloat16:
                        wkv6_cuda.forward_bf16(B, T, C, H, r, k, v, w, u, y)
                    else:
                        wkv6_cuda.forward_fp16(B, T, C, H, r, k, v, w, u, y)
                    return y

            @staticmethod
            def backward(ctx, gy):
                assert gy.is_cuda
                with torch.no_grad():
                    assert gy.dtype in [torch.bfloat16, torch.float32]
                    B = ctx.B
                    T = ctx.T
                    C = ctx.C
                    H = ctx.H
                    assert gy.is_contiguous()
                    r, k, v, w, u = ctx.saved_tensors
                    y_dtype = r.dtype if r.dtype != torch.float16 else torch.float32

                    def _grad_buffer(shape):
                        # Uninitialised output buffer; the kernel writes it.
                        return torch.empty(
                            shape,
                            device=gy.device,
                            requires_grad=False,
                            dtype=y_dtype,
                            memory_format=torch.contiguous_format,
                        )

                    gr = _grad_buffer((B, T, C))
                    gk = _grad_buffer((B, T, C))
                    gv = _grad_buffer((B, T, C))
                    gw = _grad_buffer((B, T, C))
                    # u's gradient is produced per batch row, then reduced.
                    gu = _grad_buffer((B, C))

                    if r.dtype == torch.float32:
                        wkv6_cuda.backward_fp32(
                            B, T, C, H, r, k, v, w, u, gy, gr, gk, gv, gw, gu
                        )
                    elif r.dtype == torch.bfloat16:
                        wkv6_cuda.backward_bf16(
                            B, T, C, H, r, k, v, w, u, gy, gr, gk, gv, gw, gu
                        )
                    else:
                        wkv6_cuda.backward_fp16(
                            B, T, C, H, r, k, v, w, u, gy, gr, gk, gv, gw, gu
                        )

                    gu = torch.sum(gu, 0).view(H, C // H)

                    # No gradients for the integer size arguments B, T, C, H.
                    return (None, None, None, None, gr, gk, gv, gw, gu)

        class RWKV_6_with_state:
            """Forward-only WKV6 kernel that also returns the final state.

            This is not an autograd.Function and runs under ``no_grad``, so
            its outputs carry no gradient.
            """

            @staticmethod
            def apply(B, T, C, H, S, s_map, r, k, v, w, u, s):
                with torch.no_grad():
                    # Check None-ness before dereferencing s_map.dtype.
                    assert (s is None and s_map is None) or (
                        s is not None and s_map is not None
                    ), "init_state与s_map必须同时为None 或者同时不为None"
                    assert s_map.dtype == torch.int64, (
                        "s_map 必须为None 或者是长度为B的,int64类型的数组。"
                    )
                    assert (
                        r.dtype == k.dtype == v.dtype == w.dtype == u.dtype
                        and r.dtype in [torch.float16, torch.float32, torch.bfloat16]
                    ), " r, k, v, w, u 必须为fp16 fp32 bf16中的一种 并且类型相同"
                    # fp16 inputs produce fp32 outputs and state.
                    if r.dtype in [torch.float32, torch.bfloat16]:
                        o_dtype = r.dtype
                    else:
                        o_dtype = torch.float32
                    assert (
                        r.device
                        == k.device
                        == v.device
                        == w.device
                        == u.device
                        == s.device
                        == s_map.device
                    ), "what kan i say? 请确保r k v w u s s_map在同一设备上,快去检查!"

                    y = torch.empty(
                        (B, T, C),
                        device=r.device,
                        dtype=o_dtype,
                        memory_format=torch.contiguous_format,
                    )
                    ys = torch.empty(
                        (B, H, head_size, head_size),
                        device=r.device,
                        dtype=o_dtype,
                        memory_format=torch.contiguous_format,
                    )
                    # S is the "custom initial state supplied" flag from the caller.
                    if r.dtype == torch.bfloat16:
                        wkv6_cuda.forward_with_state_bf16(
                            B, T, C, H, S, s_map, r, k, v, w, u, s, y, ys
                        )
                    elif r.dtype == torch.float32:
                        wkv6_cuda.forward_with_state_fp32(
                            B, T, C, H, S, s_map, r, k, v, w, u, s, y, ys
                        )
                    else:
                        wkv6_cuda.forward_with_state_fp16(
                            B, T, C, H, S, s_map, r, k, v, w, u, s, y, ys
                        )

                    return y, ys

        self.head_size = head_size
        # Attribute name (including its spelling) kept for existing callers.
        self.normal_kernenl = RWKV_6
        self.kernel_with_state = RWKV_6_with_state

    def __call__(
        self, r, k, v, w, u, with_state=False, init_state=None, state_map=None
    ):
        """Run the WKV6 kernel.

        Args:
            r, k, v, w, u: CUDA tensors with a shared dtype. ``r.shape`` is
                unpacked as ``(B, T, C)``; ``C`` must be divisible by
                ``head_size``.
            with_state: If True, use the stateful forward-only kernel and
                return its final state.
            init_state: Optional initial state of shape
                ``(H, head_size, head_size)`` or
                ``(state_kinds <= B, H, head_size, head_size)``. Its dtype must
                match the state dtype (input dtype, or fp32 for fp16 inputs).
            state_map: Optional length-``B`` list/array/int tensor choosing an
                ``init_state`` slot for each batch row. It is moved to ``r``'s
                device.

        Returns:
            ``(y, state)``. ``state`` is ``None`` unless ``with_state`` is
            True.

        NOTE(review): when ``with_state`` is False, ``init_state`` and
        ``state_map`` are validated but not used.
        """
        B, T, C = r.shape
        assert C % self.head_size == 0
        H = C // self.head_size
        if not isinstance(u, torch.Tensor):
            u = u.value

        assert r.is_cuda
        assert k.is_cuda
        assert v.is_cuda
        assert w.is_cuda
        assert u.is_cuda

        if isinstance(r, torch.Tensor):
            assert r.device == k.device == v.device == w.device == u.device
        else:
            # Previously this comparison was evaluated but never asserted.
            assert (
                r.get_device()
                == k.get_device()
                == v.get_device()
                == w.get_device()
                == u.get_device()
            )

        assert r.dtype == k.dtype == v.dtype == w.dtype == u.dtype

        # State is kept in the input dtype, except fp16, which uses fp32.
        if r.dtype in [torch.float32, torch.bfloat16]:
            s_dtype = r.dtype
        else:
            s_dtype = torch.float32

        is_custom_init = init_state is not None

        if init_state is not None:
            assert len(init_state.shape) in [3, 4], (
                "init_state 的形状必须为(state_kinds /*<= Batch_size*/,num_heads,head_size,head_size) 或者(num_heads,head_size,head_size)"
            )
            if len(init_state.shape) == 3:
                # A single shared state becomes a one-slot state table.
                init_state = init_state[None, :]
            assert (
                init_state.shape[1:] == (H, self.head_size, self.head_size)
                and init_state.shape[0] <= B
            ), (
                "init_state 的形状必须为(state_kinds /*<= Batch_size*/,num_heads,head_size,head_size) 或者(num_heads,head_size,head_size)"
            )

            assert init_state.dtype == s_dtype, f"init_state的数值类型应为: {s_dtype}"
            assert init_state.device == r.device

        if state_map is not None:
            if isinstance(state_map, torch.Tensor):
                assert state_map.dtype in [torch.int32, torch.int64], (
                    "state_map是一个长度为Batch_Size的int64类型的映射数组"
                )
                # Normalise dtype and place on r's device. The former check
                # referenced the misspelled ``r.deivec`` and always raised.
                state_map = state_map.to(device=r.device, dtype=torch.int64)
            else:
                # Lists (and other array-likes) were previously built on the
                # CPU, which could never match a CUDA ``r``.
                state_map = torch.as_tensor(
                    state_map, dtype=torch.int64, device=r.device
                )
            assert state_map.shape == (B,), "state_map的shape必须为(Batch_Size,)"

        if with_state:
            if init_state is None:
                assert state_map is None, (
                    "您必须在指定了init_state的情况下才能使用state_map"
                )
                # Empty placeholders; is_custom_init=False tells the kernel
                # there is no user-supplied initial state.
                init_state = torch.zeros((0,), device=r.device, dtype=s_dtype)
                state_map = torch.zeros((0,), device=r.device, dtype=torch.int64)
            else:
                n_state = init_state.shape[0]
                if state_map is None:
                    assert n_state == 1 or n_state == B, (
                        "我无法为您推断state_map的形状,请手动指定。"
                    )
                    if n_state == 1:
                        # Every batch row starts from the single shared state.
                        state_map = torch.zeros(B, dtype=torch.int64, device=r.device)
                    else:
                        # One state per batch row, in order.
                        state_map = torch.arange(B, dtype=torch.int64, device=r.device)
                else:
                    assert state_map.shape == (B,), "state_map的形状必须为(batch_size,)"
                    assert (state_map >= 0).all() and (state_map < n_state).all(), (
                        f"state_map的取值范围为[0,{n_state})之间的整数,您的输入显然不满足。"
                    )
            o, ys = self.kernel_with_state.apply(
                B, T, C, H, is_custom_init, state_map, r, k, v, w, u, init_state
            )
            return o, ys
        else:
            o = self.normal_kernenl.apply(B, T, C, H, r, k, v, w, u)
            return o, None
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import keras
|
|
2
2
|
from distutils.util import strtobool
|
|
3
|
+
import os
|
|
3
4
|
from keras import ops
|
|
4
5
|
|
|
5
6
|
|
|
@@ -19,6 +20,7 @@ def get_generalized_delta_rule(HEAD_SIZE=64, KERNEL_TYPE="native"):
|
|
|
19
20
|
from .torch_op import generalized_delta_rule
|
|
20
21
|
|
|
21
22
|
USE_KERNEL = True
|
|
23
|
+
|
|
22
24
|
elif KERNEL_TYPE.lower() == "cuda":
|
|
23
25
|
CHUNK_LEN = 16
|
|
24
26
|
USE_KERNEL = True
|
|
@@ -39,16 +41,11 @@ def get_generalized_delta_rule(HEAD_SIZE=64, KERNEL_TYPE="native"):
|
|
|
39
41
|
|
|
40
42
|
# 获取当前文件的目录路径
|
|
41
43
|
current_dir_path = os.path.dirname(current_file_path)
|
|
42
|
-
|
|
43
|
-
# 获取上一级目录的路径
|
|
44
|
-
parent_dir_path = os.path.abspath(
|
|
45
|
-
os.path.join(current_dir_path, os.path.pardir)
|
|
46
|
-
)
|
|
47
44
|
load(
|
|
48
45
|
name="wind_backstepping",
|
|
49
46
|
sources=[
|
|
50
|
-
os.path.join(
|
|
51
|
-
os.path.join(
|
|
47
|
+
os.path.join(current_dir_path, "cuda_kernel/wkv7_cuda.cu"),
|
|
48
|
+
os.path.join(current_dir_path, "cuda_kernel/wkv7_op.cpp"),
|
|
52
49
|
],
|
|
53
50
|
is_python_module=False,
|
|
54
51
|
verbose=True,
|
|
@@ -137,11 +134,17 @@ def get_generalized_delta_rule(HEAD_SIZE=64, KERNEL_TYPE="native"):
|
|
|
137
134
|
from jax.lib import xla_bridge
|
|
138
135
|
import jax
|
|
139
136
|
import os
|
|
137
|
+
import logging
|
|
140
138
|
|
|
139
|
+
logging.basicConfig(level=logging.ERROR)
|
|
140
|
+
os.environ["TRITON_LOG_LEVEL"] = "ERROR" # 只显示错误级别的日志
|
|
141
|
+
os.environ["TRITON_DISABLE_AUTOTUNE"] = "1" # 禁用自动调优日志
|
|
142
|
+
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # 禁用自动调优日志
|
|
141
143
|
if (
|
|
142
144
|
xla_bridge.get_backend().platform == "gpu"
|
|
143
145
|
and KERNEL_TYPE.lower() == "triton"
|
|
144
146
|
):
|
|
147
|
+
os.environ["JAX_LOG_COMPUTATION"] = "0"
|
|
145
148
|
from .jax_op import generalized_delta_rule
|
|
146
149
|
|
|
147
150
|
USE_KERNEL = True
|
|
@@ -5,6 +5,8 @@ import functools
|
|
|
5
5
|
import triton
|
|
6
6
|
import jax
|
|
7
7
|
import jax.numpy as jnp
|
|
8
|
+
from enum import Enum
|
|
9
|
+
import contextlib
|
|
8
10
|
|
|
9
11
|
|
|
10
12
|
@lru_cache(maxsize=None)
|
|
@@ -82,9 +84,6 @@ def is_triton_shared_mem_enough(
|
|
|
82
84
|
|
|
83
85
|
device_capacity = is_triton_shared_mem_enough()
|
|
84
86
|
|
|
85
|
-
from enum import Enum
|
|
86
|
-
import contextlib
|
|
87
|
-
|
|
88
87
|
|
|
89
88
|
def _cpu_device_warning():
|
|
90
89
|
import warnings
|
|
@@ -6,6 +6,8 @@ from typing import Literal
|
|
|
6
6
|
import triton
|
|
7
7
|
from packaging import version
|
|
8
8
|
import torch
|
|
9
|
+
from enum import Enum
|
|
10
|
+
import contextlib
|
|
9
11
|
|
|
10
12
|
|
|
11
13
|
@lru_cache(maxsize=None)
|
|
@@ -105,8 +107,6 @@ def is_triton_shared_mem_enough(
|
|
|
105
107
|
|
|
106
108
|
|
|
107
109
|
device_capacity = is_triton_shared_mem_enough()
|
|
108
|
-
from enum import Enum
|
|
109
|
-
import contextlib
|
|
110
110
|
|
|
111
111
|
|
|
112
112
|
def _cpu_device_warning():
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
# -*- coding: utf-8 -*-
|
|
2
|
-
# Copyright (c) 2023-2025,
|
|
2
|
+
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
|
3
3
|
|
|
4
4
|
from typing import Tuple
|
|
5
5
|
|
|
@@ -7,7 +7,7 @@ import jax_triton as jt
|
|
|
7
7
|
import jax
|
|
8
8
|
import triton
|
|
9
9
|
|
|
10
|
-
from ..
|
|
10
|
+
from ..get_jax_devices_info import check_shared_mem
|
|
11
11
|
from ..triton_kernel.chunk_o_bwd import *
|
|
12
12
|
|
|
13
13
|
|
|
@@ -104,7 +104,7 @@ def chunk_dplr_bwd_o(
|
|
|
104
104
|
out_shape=out_shapes,
|
|
105
105
|
grid=grid,
|
|
106
106
|
)
|
|
107
|
-
return dq, dk, dw, db, dgk_last
|
|
107
|
+
return (dq, dk, dw, db, dgk_last)
|
|
108
108
|
|
|
109
109
|
|
|
110
110
|
def chunk_dplr_bwd_dAu(
|
rwkv_ops/rwkv7_kernel/jax_op.py
CHANGED
|
@@ -223,7 +223,6 @@ def chunk_dplr_bwd(
|
|
|
223
223
|
|
|
224
224
|
dv = chunk_dplr_bwd_dv(A_qk=A_qk, kg=kg, do=do, dh=dh, chunk_size=BT)
|
|
225
225
|
del A_qk
|
|
226
|
-
|
|
227
226
|
dqg, dkg, dw, dbg, dgk_last = chunk_dplr_bwd_o(
|
|
228
227
|
k=kg,
|
|
229
228
|
b=bg,
|
|
@@ -239,7 +238,6 @@ def chunk_dplr_bwd(
|
|
|
239
238
|
scale=scale,
|
|
240
239
|
)
|
|
241
240
|
del v_new
|
|
242
|
-
|
|
243
241
|
dA_ab, dA_ak, dv, dag = chunk_dplr_bwd_wy(
|
|
244
242
|
A_ab_inv=A_ab_inv,
|
|
245
243
|
A_ak=A_ak,
|