rwkv_ops-0.1.0-py3-none-any.whl
This diff shows the content of publicly released package versions and reflects changes between versions as they appear in their respective public registries. It is provided for informational purposes only.
Potentially problematic release.
This version of rwkv-ops might be problematic.
- rwkv_ops/__init__.py +26 -0
- rwkv_ops/rwkv7_kernel/__init__.py +153 -0
- rwkv_ops/rwkv7_kernel/get_jax_devices_info.py +221 -0
- rwkv_ops/rwkv7_kernel/get_torch_devices_info.py +250 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/__init__.py +9 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_bwd.py +95 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_fwd.py +60 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_bwd.py +78 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_fwd.py +80 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_bwd.py +150 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_fwd.py +45 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/cumsum.py +34 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_bwd.py +61 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_fwd.py +86 -0
- rwkv_ops/rwkv7_kernel/jax_op.py +382 -0
- rwkv_ops/rwkv7_kernel/native_keras_op.py +95 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/__init__.py +13 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_bwd.py +96 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_fwd.py +64 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_bwd.py +74 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_fwd.py +75 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_bwd.py +148 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_fwd.py +44 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/cumsum.py +31 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_bwd.py +63 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_fwd.py +79 -0
- rwkv_ops/rwkv7_kernel/torch_op.py +523 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/__init__.py +34 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_bwd.py +328 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_fwd.py +186 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_bwd.py +157 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_fwd.py +160 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_bwd.py +382 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_fwd.py +137 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/cumsum.py +86 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/utils.py +20 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_bwd.py +193 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_fwd.py +326 -0
- rwkv_ops-0.1.0.dist-info/LICENSE.txt +201 -0
- rwkv_ops-0.1.0.dist-info/METADATA +118 -0
- rwkv_ops-0.1.0.dist-info/RECORD +43 -0
- rwkv_ops-0.1.0.dist-info/WHEEL +5 -0
- rwkv_ops-0.1.0.dist-info/top_level.txt +1 -0
rwkv_ops/rwkv7_kernel/jax_op.py
@@ -0,0 +1,382 @@
import jax
import jax.numpy as jnp
import triton
from .jax_kernel.chunk_A_bwd import chunk_dplr_bwd_dqk_intra
from .jax_kernel.chunk_A_fwd import chunk_dplr_fwd_intra
from .jax_kernel.chunk_h_bwd import chunk_dplr_bwd_dhu
from .jax_kernel.chunk_h_fwd import chunk_dplr_fwd_h
from .jax_kernel.chunk_o_bwd import (
    chunk_dplr_bwd_dAu,
    chunk_dplr_bwd_dv,
    chunk_dplr_bwd_o,
)
from .jax_kernel.chunk_o_fwd import chunk_dplr_fwd_o
from .jax_kernel.wy_fast_bwd import chunk_dplr_bwd_wy
from .jax_kernel.wy_fast_fwd import prepare_wy_repr_fwd
from .jax_kernel.cumsum import chunk_rwkv6_fwd_cumsum

CHUNKSIZE = 16


def chunk_dplr_fwd(
    q: jax.Array,
    k: jax.Array,
    v: jax.Array,
    a: jax.Array,
    b: jax.Array,
    gk: jax.Array,
    scale: float,
    initial_state: jax.Array,
    output_final_state: bool,
    chunk_size: int = 16,
):
    T = q.shape[1]
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))

    gi, ge = chunk_rwkv6_fwd_cumsum(gk, BT)

    A_ab, A_qk, A_ak, A_qb, qg, kg, ag, bg = chunk_dplr_fwd_intra(
        q=q,
        k=k,
        a=a,
        b=b,
        gi=gi,
        ge=ge,
        scale=scale,
        chunk_size=BT,
    )

    del ge

    # A_ab, A_ak, gi, ge torch.float32
    # A_qk, A_qb, qg, kg, ag, bg, dtype=q.dtype, eg: bf16
    w, u, _ = prepare_wy_repr_fwd(ag=ag, A_ab=A_ab, A_ak=A_ak, v=v, chunk_size=BT)

    del A_ab, A_ak
    h, v_new, final_state = chunk_dplr_fwd_h(
        kg=kg,
        bg=bg,
        v=v,
        w=w,
        u=u,
        gk=gi,
        initial_state=initial_state,
        output_final_state=output_final_state,
        chunk_size=BT,
    )

    del u, kg, bg, gi

    o = chunk_dplr_fwd_o(
        qg=qg, v=v, v_new=v_new, A_qk=A_qk, A_qb=A_qb, h=h, chunk_size=BT
    )
    del v_new, h, A_qk, A_qb

    return o, final_state


def chunk_dplr_delta_rule_fwd(
    q: jax.Array,
    k: jax.Array,
    v: jax.Array,
    a: jax.Array,
    b: jax.Array,
    gk: jax.Array,
    scale=None,
    initial_state=None,
    output_final_state: bool = True,
):
    assert q.dtype == k.dtype == v.dtype
    # assert q.dtype != torch.float32, "ChunkDeltaRuleFunction does not support float32. Please use bfloat16."
    # gk = gk.float()

    scale = k.shape[-1] ** -0.5 if scale is None else scale
    chunk_size = CHUNKSIZE

    o, final_state = chunk_dplr_fwd(
        q=q,
        k=k,
        v=v,
        a=a,
        b=b,
        gk=gk,
        scale=scale,
        initial_state=initial_state,
        output_final_state=output_final_state,
        chunk_size=chunk_size,
    )
    return o, final_state


def cal_log_w(w: jax.Array) -> jax.Array:
    return -jnp.exp(w)


@jax.custom_vjp
def chunk_dplr(
    r: jax.Array,
    k: jax.Array,
    v: jax.Array,
    a: jax.Array,
    b: jax.Array,
    gk: jax.Array,
    initial_state: jax.Array = None,
):
    return chunk_dplr_delta_rule_fwd(
        q=r,
        k=k,
        v=v,
        a=a,
        b=b,
        gk=gk,
        scale=1,
        initial_state=initial_state,
        output_final_state=True,
    )


def chunk_dplr_fwd_jax(
    r: jax.Array,
    k: jax.Array,
    v: jax.Array,
    a: jax.Array,
    b: jax.Array,
    gk: jax.Array,
    initial_state: jax.Array = None,
):
    o, state = chunk_dplr_delta_rule_fwd(
        q=r,
        k=k,
        v=v,
        a=a,
        b=b,
        gk=gk,
        scale=1,
        initial_state=initial_state,
        output_final_state=True,
    )
    cache = (r, k, v, a, b, gk, initial_state)
    return [o, state], cache


def chunk_dplr_bwd(
    q: jax.Array,
    k: jax.Array,
    v: jax.Array,
    a: jax.Array,
    b: jax.Array,
    gk: jax.Array,
    initial_state,
    scale,
    do: jax.Array,
    dht: jax.Array,
    chunk_size: int = CHUNKSIZE,
):
    # DTYPE = do.dtype
    BT = chunk_size
    scale = scale
    # if do != None:
    #     do = do, q.dtype)
    # if dht != None:
    #     dht = dht, q.dtype)

    # ******* start recomputing everything, otherwise i believe the gpu memory will be exhausted *******
    gi, ge = chunk_rwkv6_fwd_cumsum(gk, BT)

    A_ab, A_qk, A_ak, A_qb, qg, kg, ag, bg = chunk_dplr_fwd_intra(
        q=q,
        k=k,
        a=a,
        b=b,
        gi=gi,
        ge=ge,
        scale=scale,
        chunk_size=BT,
    )
    w, u, A_ab_inv = prepare_wy_repr_fwd(
        ag=ag, A_ab=A_ab, A_ak=A_ak, v=v, chunk_size=BT
    )
    del A_ab
    h, v_new, _ = chunk_dplr_fwd_h(
        kg=kg, bg=bg, v=v, w=w, u=u, gk=gi, initial_state=initial_state, chunk_size=BT
    )
    del u
    # ******* end of recomputation *******
    # A_ak, A_ab_inv, gi, ge torch.float32
    # A_qk, A_qb, qg, kg, ag, bg, v_new dtype=q.dtype, eg: bf16

    dv_new_intra, dA_qk, dA_qb = chunk_dplr_bwd_dAu(
        v=v, v_new=v_new, do=do, A_qb=A_qb, scale=scale, chunk_size=BT
    )

    dh, dh0, dv_new = chunk_dplr_bwd_dhu(
        qg=qg,
        bg=bg,
        w=w,
        gk=gi,
        h0=initial_state,
        dht=dht,
        do=do,
        dv=dv_new_intra,
        chunk_size=BT,
    )

    dv = chunk_dplr_bwd_dv(A_qk=A_qk, kg=kg, do=do, dh=dh, chunk_size=BT)
    del A_qk

    dqg, dkg, dw, dbg, dgk_last = chunk_dplr_bwd_o(
        k=kg,
        b=bg,
        v=v,
        v_new=v_new,
        do=do,
        h=h,
        dh=dh,
        dv=dv_new,
        w=w,
        gk=gi,
        chunk_size=BT,
        scale=scale,
    )
    del v_new

    dA_ab, dA_ak, dv, dag = chunk_dplr_bwd_wy(
        A_ab_inv=A_ab_inv,
        A_ak=A_ak,
        v=v,
        ag=ag,
        dw=dw,
        du=dv_new,
        dv0=dv,
        chunk_size=BT,
    )
    del A_ak

    dq, dk, da, db, dgk = chunk_dplr_bwd_dqk_intra(
        q=q,
        k=k,
        a=a,
        b=b,
        gi=gi,
        ge=ge,
        dAqk=dA_qk,
        dAqb=dA_qb,
        dAak=dA_ak,
        dAab=dA_ab,
        dgk_last=dgk_last,
        dqg=dqg,
        dkg=dkg,
        dag=dag,
        dbg=dbg,
        chunk_size=BT,
        scale=scale,
    )
    return (
        jnp.asarray(dq, q.dtype),
        jnp.asarray(dk, k.dtype),
        jnp.asarray(dv, v.dtype),
        jnp.asarray(da, a.dtype),
        jnp.asarray(db, b.dtype),
        jnp.asarray(dgk, gk.dtype),
        None if initial_state is None else jnp.asarray(dh0, initial_state.dtype),
    )


def chunk_dplr_bwd_jax(res, g):
    q, k, v, a, b, gk, initial_state = res
    do, dht = g
    return chunk_dplr_bwd(
        q,
        k,
        v,
        a,
        b,
        gk,
        initial_state,
        scale=1,
        do=do,
        dht=dht,
    )


chunk_dplr.defvjp(chunk_dplr_fwd_jax, chunk_dplr_bwd_jax)


def transpose_head(x, head_first):
    # x = jnp.asarray(x,"bfloat16")
    if head_first:
        return jnp.transpose(x, (0, 2, 1, 3))
    else:
        return x


# @partial(jax.jit, static_argnames=['initial_state',"output_final_state","head_first","use_chunk"])
def generalized_delta_rule(
    r: jax.Array,
    w: jax.Array,
    k: jax.Array,
    v: jax.Array,
    a: jax.Array,
    b: jax.Array,
    initial_state: jax.Array = None,
    output_final_state: bool = True,
    head_first: bool = False,
):
    r"""
    Main interface function for chunked delta rule attention.

    Args:
        q (jax.Array):
            queries of shape `[B, T, H, K]`
        k (jax.Array):
            keys of shape `[B, T, H, K]`
        v (jax.Array):
            values of shape `[B, T, H, V]`
        a (jax.Array):
            activations of shape `[B, T, H, K]`
        b (jax.Array):
            betas of shape `[B, T, H, K]`
        gk (jax.Array):
            gk of shape `[B, T, H, K]`, decay term in log space!
        initial_state (Optional[jax.Array]):
            Initial state of shape `[N, H, K, V]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
        head_first (Optional[bool]):
            Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
            Default: `False`.

    Returns:
        o (jax.Array):
            Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
        final_state (jax.Array):
            Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
    """
    DTYPE = r.dtype
    r = transpose_head(r, head_first)
    k = transpose_head(k, head_first)
    v = transpose_head(v, head_first)
    a = transpose_head(a, head_first)
    b = transpose_head(b, head_first)

    if w is not None:
        log_w = cal_log_w(w)
    else:
        assert log_w is not None, "Either w or log_w must be provided!"
        log_w = transpose_head(log_w, head_first)
    o, final_state = chunk_dplr(
        r=r,
        k=k,
        v=v,
        a=a,
        b=b,
        gk=log_w,
        initial_state=initial_state,
    )
    if output_final_state:
        return jnp.asarray(o, DTYPE), final_state
    return jnp.asarray(o, DTYPE)
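The chunked path above is wrapped with jax.custom_vjp, so gradients flow through the Triton backward functions registered via defvjp. A minimal usage sketch, not part of the package: the import path matches the file above, but the shapes, dtypes, and random inputs are illustrative assumptions, and the Triton-backed kernels require a CUDA GPU.

# Usage sketch only; shapes follow the docstring, gk is derived internally as -exp(w).
import jax
import jax.numpy as jnp

from rwkv_ops.rwkv7_kernel.jax_op import generalized_delta_rule

B, T, H, K, V = 1, 64, 4, 64, 64
ks = jax.random.split(jax.random.PRNGKey(0), 6)
r, w, k, a, b = [jnp.asarray(jax.random.normal(kk, (B, T, H, K)), jnp.bfloat16) for kk in ks[:5]]
v = jnp.asarray(jax.random.normal(ks[5], (B, T, H, V)), jnp.bfloat16)

# scale is fixed to 1 inside chunk_dplr; the state defaults to zeros when initial_state is None.
o, final_state = generalized_delta_rule(r, w, k, v, a, b, initial_state=None, output_final_state=True)
print(o.shape, final_state.shape)  # (B, T, H, V) and, per the docstring, [N, H, K, V]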
rwkv_ops/rwkv7_kernel/native_keras_op.py
@@ -0,0 +1,95 @@
import keras
from keras import ops


def transpose_head(x, head_first):
    """
    Transpose the input tensor.

    Args:
        x: input tensor.
        head_first: boolean that decides whether to transpose.

    Returns:
        The transposed tensor if head_first is True, otherwise the original tensor.
    """
    x = ops.cast(x, "float32")
    if head_first:
        return ops.transpose(x, (0, 2, 1, 3))
    else:
        return x


def generalized_delta_rule(
    r,
    w,
    k,
    v,
    a,
    b,
    initial_state=None,
    output_final_state: bool = True,
    head_first: bool = False,
):
    """
    Implements the generalized delta rule.

    Args:
        r: input tensor.
        w: weight tensor.
        k, v, a, b: other input tensors.
        initial_state: initial state tensor.
        output_final_state: whether to output the final state.
        head_first: whether the head dimension comes first in the computation.

    Returns:
        The output, plus the final state depending on output_final_state.
    """
    DTYPE = r.dtype
    B, T, H, N = ops.shape(r)
    r = transpose_head(r, head_first)
    k = transpose_head(k, head_first)
    v = transpose_head(v, head_first)
    a = transpose_head(a, head_first)
    b = transpose_head(b, head_first)
    w = transpose_head(w, head_first)
    w = ops.exp(-ops.exp(w))

    if initial_state is not None:
        state = initial_state
        if ops.shape(state)[0] == 1:
            state = ops.broadcast_to(state, (B, H, N, N))
    else:
        state = ops.zeros((B, H, N, N), dtype="float32")
    out = ops.zeros((B, T, H, N), dtype=r.dtype)

    def step(t, inputs):
        """
        Run the computation for a single time step.

        Args:
            t: current time step.
            inputs: list containing the current state and the output.

        Returns:
            The updated state and output.
        """
        state, out = inputs
        kk = ops.reshape(k[:, t, :], (B, H, 1, N))
        rr = ops.reshape(r[:, t, :], (B, H, N, 1))
        vv = ops.reshape(v[:, t, :], (B, H, N, 1))
        aa = ops.reshape(a[:, t, :], (B, H, N, 1))
        bb = ops.reshape(b[:, t, :], (B, H, 1, N))
        state = state * w[:, t, :, None, :] + state @ aa @ bb + vv @ kk
        out = ops.slice_update(
            out, [0, t, 0, 0], ops.reshape((state @ rr), (B, 1, H, N))
        )
        return [state, out]

    state, out = ops.fori_loop(0, T, step, [state, out])

    if output_final_state:
        return ops.cast(out, DTYPE), state
    return ops.cast(out, DTYPE)
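The step closure above advances the RWKV7 recurrence one token at a time: the state decays by exp(-exp(w_t)), picks up a rank-one correction state @ a_t b_t^T, accumulates v_t k_t^T, and the output is state @ r_t. For reference, a loop-based NumPy restatement of the same update, given only as a sketch (zero initial state assumed, not part of the package):

# Reference recurrence only; mirrors the keras `step` function above, for checking the math.
import numpy as np

def naive_generalized_delta_rule(r, w, k, v, a, b):
    # r, w, k, v, a, b: [B, T, H, N]; returns out: [B, T, H, N] and state: [B, H, N, N]
    B, T, H, N = r.shape
    w = np.exp(-np.exp(w))                       # same decay transform as the keras path
    state = np.zeros((B, H, N, N), dtype=np.float32)
    out = np.zeros((B, T, H, N), dtype=np.float32)
    for t in range(T):
        kk = k[:, t].reshape(B, H, 1, N)
        rr = r[:, t].reshape(B, H, N, 1)
        vv = v[:, t].reshape(B, H, N, 1)
        aa = a[:, t].reshape(B, H, N, 1)
        bb = b[:, t].reshape(B, H, 1, N)
        state = state * w[:, t, :, None, :] + state @ aa @ bb + vv @ kk
        out[:, t] = (state @ rr).reshape(B, H, N)
    return out, state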
rwkv_ops/rwkv7_kernel/torch_kernel/__init__.py
@@ -0,0 +1,13 @@
from ..torch_kernel.chunk_A_fwd import *
from ..torch_kernel.chunk_A_bwd import *

# ---------- chunk_h ----------
from ..torch_kernel.chunk_h_fwd import *
from ..torch_kernel.chunk_h_bwd import *

# ---------- chunk_o ----------
from ..torch_kernel.chunk_o_fwd import *
from ..torch_kernel.chunk_o_bwd import *
from ..torch_kernel.cumsum import *
from ..torch_kernel.wy_fast_fwd import *
from ..torch_kernel.wy_fast_bwd import *
rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_bwd.py
@@ -0,0 +1,96 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang


import torch
import triton
from ..triton_kernel.chunk_A_bwd import *
from ..triton_kernel.utils import is_gather_supported
from ..get_torch_devices_info import check_shared_mem


def chunk_dplr_bwd_dqk_intra(
    q: torch.Tensor,
    k: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    gi: torch.Tensor,
    ge: torch.Tensor,
    dAqk: torch.Tensor,
    dAqb: torch.Tensor,
    dAak: torch.Tensor,
    dAab: torch.Tensor,
    dqg: torch.Tensor,
    dkg: torch.Tensor,
    dag: torch.Tensor,
    dbg: torch.Tensor,
    dgk_last: torch.Tensor,
    scale: float = 1.0,
    chunk_size: int = 16,
):
    B, T, H, K = q.shape
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
    BK = (
        min(64, triton.next_power_of_2(K))
        if check_shared_mem()
        else min(32, triton.next_power_of_2(K))
    )

    NT = triton.cdiv(T, BT)
    NK = triton.cdiv(K, BK)
    grid = (NK, NT, B * H)

    dq = torch.empty_like(q)
    dk = torch.empty_like(k)
    da = torch.empty_like(a)
    db = torch.empty_like(b)
    dgk = torch.empty_like(gi, dtype=torch.float)
    dgk_offset = torch.empty_like(gi, dtype=torch.float)

    chunk_dplr_bwd_kernel_intra[grid](
        q=q,
        k=k,
        a=a,
        b=b,
        gi=gi,
        ge=ge,
        dAqk=dAqk,
        dAqb=dAqb,
        dAak=dAak,
        dAab=dAab,
        dq=dq,
        dk=dk,
        dgk=dgk,
        dgk_offset=dgk_offset,
        dqg=dqg,
        dkg=dkg,
        dag=dag,
        dbg=dbg,
        da=da,
        db=db,
        scale=scale,
        T=T,
        H=H,
        K=K,
        BT=BT,
        BC=BT,
        BK=BK,
        GATHER_SUPPORTED=is_gather_supported,
    )

    dgk_output = torch.empty_like(dgk)

    def grid(meta):
        return (NT, triton.cdiv(K, meta["BK"]), B * H)

    chunk_dplr_bwd_dgk_kernel[grid](
        dgk=dgk,
        dgk_offset=dgk_offset,
        dgk_last=dgk_last,
        dgk_output=dgk_output,
        T=T,
        H=H,
        K=K,
        BT=BT,
    )
    return dq, dk, da, db, dgk_output
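For orientation, the first kernel above is launched on a 3-D grid of (NK, NT, B * H) program instances, where NT is the number of time chunks and NK the number of key blocks. A small worked example of that block arithmetic, with the shapes below chosen purely for illustration:

# Worked example of the launch-grid arithmetic above; the shapes are assumptions.
import triton

B, T, H, K = 8, 4096, 32, 64
chunk_size = 16
BT = min(chunk_size, max(16, triton.next_power_of_2(T)))  # 16
BK = min(64, triton.next_power_of_2(K))                   # 64 when check_shared_mem() passes
NT = triton.cdiv(T, BT)                                   # 256 time chunks
NK = triton.cdiv(K, BK)                                   # 1 key block
grid = (NK, NT, B * H)                                    # (1, 256, 256) kernel instances
print(grid)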
rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_fwd.py
@@ -0,0 +1,64 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang


import torch
import triton

from ..triton_kernel.utils import is_gather_supported

from ..triton_kernel.chunk_A_fwd import *


def chunk_dplr_fwd_intra(
    q: torch.Tensor,
    k: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    gi: torch.Tensor,
    ge: torch.Tensor,
    scale: float,
    chunk_size: int,
):
    B, T, H, K = k.shape
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))

    NT = triton.cdiv(T, BT)

    Aqk = q.new_empty(B, T, H, BT, dtype=q.dtype)
    Aqb = q.new_empty(B, T, H, BT, dtype=q.dtype)
    # involving matrix inverse and it'd be better to use float here.
    Aab = q.new_empty(B, T, H, BT, dtype=torch.float)
    Aak = q.new_empty(B, T, H, BT, dtype=torch.float)

    grid = (NT, B, H)
    BK = triton.next_power_of_2(K)
    qg = torch.empty_like(q)
    kg = torch.empty_like(k, dtype=q.dtype)
    ag = torch.empty_like(a, dtype=q.dtype)
    bg = torch.empty_like(b, dtype=q.dtype)
    chunk_dplr_fwd_A_kernel_intra_sub_intra[grid](
        q=q,
        k=k,
        a=a,
        b=b,
        gi=gi,
        ge=ge,
        Aqk=Aqk,
        Aqb=Aqb,
        Aab=Aab,
        Aak=Aak,
        qg=qg,
        kg=kg,
        ag=ag,
        bg=bg,
        scale=scale,
        T=T,
        H=H,
        K=K,
        BT=BT,
        BC=BT,
        BK=BK,
        GATHER_SUPPORTED=is_gather_supported,
    )
    return Aab, Aqk, Aak, Aqb, qg, kg, ag, bg
rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_bwd.py
@@ -0,0 +1,74 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from typing import Optional, Tuple

import torch
import triton

from ..get_torch_devices_info import check_shared_mem
from ..triton_kernel.chunk_h_bwd import *


def chunk_dplr_bwd_dhu(
    qg: torch.Tensor,
    bg: torch.Tensor,
    w: torch.Tensor,
    gk: torch.Tensor,
    h0: torch.Tensor,
    dht: Optional[torch.Tensor],
    do: torch.Tensor,
    dv: torch.Tensor,
    chunk_size: int = 64,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    B, T, H, K, V = *qg.shape, do.shape[-1]
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
    BK = triton.next_power_of_2(K)
    assert BK <= 256, (
        "current kernel does not support head dimension being larger than 256."
    )
    # H100
    if check_shared_mem("hopper", qg.device.index):
        BV = 64
        BC = 64 if K <= 128 else 32
    elif check_shared_mem("ampere", qg.device.index):  # A100
        BV = 32
        BC = 32
    else:  # e.g. 4090
        BV = 16
        BC = 16

    N, NT, chunk_offsets = B, triton.cdiv(T, BT), None

    BC = min(BT, BC)
    NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
    assert NK == 1, (
        "NK > 1 is not supported because it involves time-consuming synchronization"
    )

    dh = qg.new_empty(B, NT, H, K, V)
    dh0 = torch.empty_like(h0, dtype=torch.float32) if h0 is not None else None
    dv2 = torch.zeros_like(dv)

    grid = (NK, NV, N * H)
    chunk_dplr_bwd_kernel_dhu[grid](
        qg=qg,
        bg=bg,
        w=w,
        gk=gk,
        dht=dht,
        dh0=dh0,
        do=do,
        dh=dh,
        dv=dv,
        dv2=dv2,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BC=BC,
        BK=BK,
        BV=BV,
    )
    return dh, dh0, dv2
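The BV/BC block sizes above follow the detected shared-memory tier: check_shared_mem("hopper", ...) selects the H100-class configuration, "ampere" the A100-class one, and anything else falls back to small blocks. A stubbed sketch of that selection, with the real helper (which lives in rwkv_ops/rwkv7_kernel/get_torch_devices_info.py) replaced by plain flags for illustration:

# Stubbed sketch of the BV/BC selection logic above; the flags stand in for check_shared_mem.
def pick_block_sizes(K, hopper_smem, ampere_smem):
    if hopper_smem:                     # H100-class shared memory
        return 64, (64 if K <= 128 else 32)
    elif ampere_smem:                   # A100-class shared memory
        return 32, 32
    else:                               # smaller GPUs, e.g. an RTX 4090
        return 16, 16

print(pick_block_sizes(K=64, hopper_smem=False, ampere_smem=True))  # (32, 32) -> BV, BC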