rwkv-ops 0.6.1 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rwkv_ops/__init__.py +45 -0
- rwkv_ops/mhc_kernel/__init__.py +50 -0
- rwkv_ops/mhc_kernel/common_kernel/include/mhc_types.h +66 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_post_op.cuh +197 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_pre_op.cuh +212 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/rmsnorm.cuh +152 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/sinkhorn_knopp.cuh +158 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/stream_aggregate.cuh +141 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/stream_distribute.cuh +111 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/stream_mix.cuh +164 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/type_conversions.cuh +52 -0
- rwkv_ops/mhc_kernel/jax_kernel/CMakeLists.txt +47 -0
- rwkv_ops/mhc_kernel/jax_kernel/mhu_ffi.cu +652 -0
- rwkv_ops/mhc_kernel/jax_kernel/mhu_jax.py +939 -0
- rwkv_ops/mhc_kernel/native_keras_op.py +193 -0
- rwkv_ops/mhc_kernel/torch_kernel/mhc_cuda.cu +207 -0
- rwkv_ops/mhc_kernel/torch_kernel/mhc_op.cpp +296 -0
- rwkv_ops/mhc_kernel/torch_kernel/mhc_torch.py +306 -0
- rwkv_ops/rwkv6_kernel/__init__.py +120 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/gpu_ops.cpp +44 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernel_helpers.h +64 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernels.h +56 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/pybind11_kernel_helpers.h +41 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/rwkv_kernels.cu +512 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/gpu_ops.cpp +44 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernel_helpers.h +64 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernels.h +56 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/pybind11_kernel_helpers.h +41 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/rwkv_kernels.hip +514 -0
- rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py +722 -0
- rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py +90 -0
- rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_cuda.cu +397 -0
- rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_op.cpp +93 -0
- rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py +305 -0
- rwkv_ops/rwkv7_kernel/__init__.py +113 -0
- rwkv_ops/rwkv7_kernel/get_jax_devices_info.py +220 -0
- rwkv_ops/rwkv7_kernel/get_torch_devices_info.py +250 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel/CMakeLists.txt +42 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_ffi.cu +399 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_jax.py +311 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/CMakeLists.txt +42 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_ffi.cu +172 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_jax.py +190 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/__init__.py +9 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_bwd.py +95 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_fwd.py +60 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_bwd.py +78 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_fwd.py +80 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_bwd.py +150 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_fwd.py +45 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/cumsum.py +34 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_bwd.py +61 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_fwd.py +86 -0
- rwkv_ops/rwkv7_kernel/jax_op.py +382 -0
- rwkv_ops/rwkv7_kernel/mlx_op.py +118 -0
- rwkv_ops/rwkv7_kernel/native_keras_op.py +108 -0
- rwkv_ops/rwkv7_kernel/tf_eager_kernel.py +155 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_cuda.cu +235 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_op.cpp +63 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_torch.py +233 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_cuda.cu +101 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_op.cpp +56 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_torch.py +112 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/__init__.py +13 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_bwd.py +96 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_fwd.py +64 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_bwd.py +74 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_fwd.py +75 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_bwd.py +148 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_fwd.py +44 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/cumsum.py +31 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_bwd.py +63 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_fwd.py +79 -0
- rwkv_ops/rwkv7_kernel/torch_op.py +504 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/__init__.py +34 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_bwd.py +328 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_fwd.py +186 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_bwd.py +157 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_fwd.py +160 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_bwd.py +382 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_fwd.py +137 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/cumsum.py +86 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/utils.py +20 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_bwd.py +193 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_fwd.py +326 -0
- rwkv_ops-0.6.1.dist-info/METADATA +495 -0
- rwkv_ops-0.6.1.dist-info/RECORD +89 -0
- rwkv_ops-0.6.1.dist-info/WHEEL +4 -0
- rwkv_ops-0.6.1.dist-info/licenses/LICENSE.txt +201 -0
rwkv_ops/rwkv7_kernel/torch_op.py
@@ -0,0 +1,504 @@
# -*- coding: utf-8 -*-
"""
This file implements the forward and backward passes of a chunked delta rule attention
mechanism, optimized with Triton kernels for GPU acceleration. It includes functions for
forward propagation, backward gradient computation, and integration with PyTorch's
autograd system.
"""

import warnings
from typing import Optional

import torch
import triton

from .torch_kernel.chunk_A_bwd import chunk_dplr_bwd_dqk_intra
from .torch_kernel.chunk_A_fwd import chunk_dplr_fwd_intra
from .torch_kernel.chunk_h_bwd import chunk_dplr_bwd_dhu
from .torch_kernel.chunk_h_fwd import chunk_dplr_fwd_h

from .torch_kernel.chunk_o_bwd import (
    chunk_dplr_bwd_dAu,
    chunk_dplr_bwd_dv,
    chunk_dplr_bwd_o,
)
from .torch_kernel.chunk_o_fwd import chunk_dplr_fwd_o
from .torch_kernel.wy_fast_bwd import chunk_dplr_bwd_wy
from .torch_kernel.wy_fast_fwd import prepare_wy_repr_fwd
from .torch_kernel.cumsum import chunk_rwkv6_fwd_cumsum
from .get_torch_devices_info import (
    autocast_custom_bwd,
    autocast_custom_fwd,
    input_guard,
)


def cast(x, dtype):
    if x is None or x.dtype == dtype:
        return x
    return x.to(dtype)


def chunk_dplr_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    gk: torch.Tensor,
    scale: float = 1,
    initial_state: torch.Tensor = None,
    output_final_state: bool = True,
    chunk_size: int = 16,
):
    """
    Forward pass of chunked delta rule attention.

    Args:
        q (torch.Tensor): Queries tensor [B, T, H, K]
        k (torch.Tensor): Keys tensor [B, T, H, K]
        v (torch.Tensor): Values tensor [B, T, H, V]
        a (torch.Tensor): Activations tensor [B, T, H, K]
        b (torch.Tensor): Betas tensor [B, T, H, K]
        gk (torch.Tensor): Log decay tensor [B, T, H, K]
        scale (float): Scale factor for attention scores
        initial_state (Optional[torch.Tensor]): Initial state for recurrent processing
        output_final_state (bool): Whether to return the final state
        chunk_size (int): Chunk size for processing

    Returns:
        o (torch.Tensor): Output tensor [B, T, H, V]
        final_state (Optional[torch.Tensor]): Final state if requested
    """
    T = q.shape[1]
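    # Chunk length: the next power of two of T, floored at 16 and capped at `chunk_size`.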
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
    gi, ge = chunk_rwkv6_fwd_cumsum(gk, BT)

    A_ab, A_qk, A_ak, A_qb, qg, kg, ag, bg = chunk_dplr_fwd_intra(
        q=q,
        k=k,
        a=a,
        b=b,
        gi=gi,
        ge=ge,
        scale=scale,
        chunk_size=BT,
    )

    del ge

    # A_ab, A_ak, gi, ge: torch.float32
    # A_qk, A_qb, qg, kg, ag, bg: dtype=q.dtype, e.g. bf16
    w, u, _ = prepare_wy_repr_fwd(ag=ag, A_ab=A_ab, A_ak=A_ak, v=v, chunk_size=BT)

    del A_ab, A_ak
    h, v_new, final_state = chunk_dplr_fwd_h(
        kg=kg,
        bg=bg,
        v=v,
        w=w,
        u=u,
        gk=gi,
        initial_state=initial_state,
        output_final_state=output_final_state,
        chunk_size=BT,
    )

    del u, kg, bg, gi

    o = chunk_dplr_fwd_o(
        qg=qg, v=v, v_new=v_new, A_qk=A_qk, A_qb=A_qb, h=h, chunk_size=BT
    )
    del v_new, h, A_qk, A_qb

    return o, final_state


def chunk_dplr_bwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    gk: torch.Tensor,
    initial_state: torch.Tensor,
    scale,
    do,
    dht,
    BT: int = 16,
):
    """
    Backward pass of chunked delta rule attention.

    Args:
        q (torch.Tensor): Queries tensor [B, T, H, K]
        k (torch.Tensor): Keys tensor [B, T, H, K]
        v (torch.Tensor): Values tensor [B, T, H, V]
        a (torch.Tensor): Activations tensor [B, T, H, K]
        b (torch.Tensor): Betas tensor [B, T, H, K]
        gk (torch.Tensor): Log decay tensor [B, T, H, K]
        initial_state (torch.Tensor): Initial state for recurrent processing
        scale (float): Scale factor for attention scores
        do (torch.Tensor): Gradient of the outputs
        dht (torch.Tensor): Gradient of the final hidden state
        BT (int): Chunk size for processing

    Returns:
        dq (torch.Tensor): Gradient of queries
        dk (torch.Tensor): Gradient of keys
        dv (torch.Tensor): Gradient of values
        da (torch.Tensor): Gradient of activations
        db (torch.Tensor): Gradient of betas
        dgk (torch.Tensor): Gradient of log decays
        dh0 (torch.Tensor): Gradient of the initial state
    """
    # ******* Recompute the forward intermediates here; caching them all during the
    # forward pass would exhaust GPU memory. *******
    gi, ge = chunk_rwkv6_fwd_cumsum(gk, BT)
    A_ab, A_qk, A_ak, A_qb, qg, kg, ag, bg = chunk_dplr_fwd_intra(
        q=q,
        k=k,
        a=a,
        b=b,
        gi=gi,
        ge=ge,
        scale=scale,
        chunk_size=BT,
    )
    w, u, A_ab_inv = prepare_wy_repr_fwd(
        ag=ag, A_ab=A_ab, A_ak=A_ak, v=v, chunk_size=BT
    )
    del A_ab
    h, v_new, _ = chunk_dplr_fwd_h(
        kg=kg, bg=bg, v=v, w=w, u=u, gk=gi, initial_state=initial_state, chunk_size=BT
    )
    del u
    # ******* end of recomputation *******
    # A_ak, A_ab_inv, gi, ge: torch.float32
    # A_qk, A_qb, qg, kg, ag, bg, v_new: dtype=q.dtype, e.g. bf16

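    # Output-path gradients: the contribution to d(v_new) from the intra-chunk output
    # computation, plus gradients of the intra-chunk attention maps A_qk and A_qb.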
    dv_new_intra, dA_qk, dA_qb = chunk_dplr_bwd_dAu(
        v=v, v_new=v_new, do=do, A_qb=A_qb, scale=scale, chunk_size=BT
    )

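    # Backpropagate through the chunked hidden-state recurrence: per-chunk state
    # gradients (dh), the initial-state gradient (dh0), and the gradient of v_new.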
    dh, dh0, dv_new = chunk_dplr_bwd_dhu(
        qg=qg,
        bg=bg,
        w=w,
        gk=gi,
        h0=initial_state,
        dht=dht,
        do=do,
        dv=dv_new_intra,
        chunk_size=BT,
    )

    dv = chunk_dplr_bwd_dv(A_qk=A_qk, kg=kg, do=do, dh=dh, chunk_size=BT)
    del A_qk

    dqg, dkg, dw, dbg, dgk_last = chunk_dplr_bwd_o(
        k=kg,
        b=bg,
        v=v,
        v_new=v_new,
        do=do,
        h=h,
        dh=dh,
        dv=dv_new,
        w=w,
        gk=gi,
        chunk_size=BT,
        scale=scale,
    )
    del v_new

    dA_ab, dA_ak, dv, dag = chunk_dplr_bwd_wy(
        A_ab_inv=A_ab_inv,
        A_ak=A_ak,
        v=v,
        ag=ag,
        dw=dw,
        du=dv_new,
        dv0=dv,
        chunk_size=BT,
    )
    del A_ak

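    # Fold the attention-map gradients (dA_*) and gated-tensor gradients (dqg, dkg,
    # dag, dbg) back into gradients of the raw inputs q, k, a, b and the log decay gk.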
    dq, dk, da, db, dgk = chunk_dplr_bwd_dqk_intra(
        q=q,
        k=k,
        a=a,
        b=b,
        gi=gi,
        ge=ge,
        dAqk=dA_qk,
        dAqb=dA_qb,
        dAak=dA_ak,
        dAab=dA_ab,
        dgk_last=dgk_last,
        dqg=dqg,
        dkg=dkg,
        dag=dag,
        dbg=dbg,
        chunk_size=BT,
        scale=scale,
    )

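    # One gradient per forward input of ChunkDPLRDeltaRuleFunction:
    # (q, k, v, a, b, gk, scale, initial_state, output_final_state, cu_seqlens);
    # scale, output_final_state and cu_seqlens need no gradient, hence None.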
    return (
        dq.to(q),
        dk.to(k),
        dv.to(v),
        da.to(a),
        db.to(b),
        dgk.to(gk),
        None,
        dh0,
        None,
        None,
    )


class ChunkDPLRDeltaRuleFunction(torch.autograd.Function):
    @staticmethod
    @input_guard
    @autocast_custom_fwd
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        a: torch.Tensor,
        b: torch.Tensor,
        gk: torch.Tensor,
        scale: float = 1,
        initial_state: torch.Tensor = None,
        output_final_state: bool = True,
        cu_seqlens: Optional[torch.LongTensor] = None,
    ):
        chunk_size = 16
        o, final_state = chunk_dplr_fwd(
            q=q,
            k=k,
            v=v,
            a=a,
            b=b,
            gk=gk,
            scale=scale,
            initial_state=initial_state,
            output_final_state=output_final_state,
            chunk_size=chunk_size,
        )
        ctx.save_for_backward(q, k, v, a, b, gk, initial_state)
        ctx.cu_seqlens = cu_seqlens
        ctx.scale = scale
        ctx.chunk_size = chunk_size
        return o.to(q.dtype), final_state

    @staticmethod
    @input_guard
    @autocast_custom_bwd
    def backward(ctx, do: torch.Tensor, dht: torch.Tensor):
        q, k, v, a, b, gk, initial_state = ctx.saved_tensors
        BT = ctx.chunk_size
        cu_seqlens = ctx.cu_seqlens
        scale = ctx.scale

        return chunk_dplr_bwd(
            q=q,
            k=k,
            v=v,
            a=a,
            b=b,
            gk=gk,
            scale=scale,
            initial_state=initial_state,
            do=do,
            dht=dht,
            BT=BT,
        )


@torch.compiler.disable
def chunk_dplr_delta_rule(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    gk: torch.Tensor,
    scale: Optional[float] = None,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
):
    r"""
    Main interface function for chunked delta rule attention.

    Args:
        q (torch.Tensor):
            queries of shape `[B, T, H, K]`
        k (torch.Tensor):
            keys of shape `[B, T, H, K]`
        v (torch.Tensor):
            values of shape `[B, T, H, V]`
        a (torch.Tensor):
            activations of shape `[B, T, H, K]`
        b (torch.Tensor):
            betas of shape `[B, T, H, K]`
        gk (torch.Tensor):
            decay term of shape `[B, T, H, K]`, given in log space!
        scale (Optional[float]):
            Scale factor for the attention scores.
            If not provided, it defaults to `1 / sqrt(K)`. Default: `None`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[N, H, K, V]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, H, V]`.
        final_state (torch.Tensor):
            Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
    """
    if q.dtype == torch.float32:
        warnings.warn(
            """ChunkDeltaRuleFunction does not support float32 on some platforms. Please use bfloat16/float16.
If you need float32, you will have to resolve this yourself.""",
            category=RuntimeWarning,
            stacklevel=2,
        )
    if cu_seqlens is not None:
        if q.shape[0] != 1:
            raise ValueError(
                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
                f"Please flatten variable-length inputs before processing."
            )
        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
            raise ValueError(
                f"The number of initial states is expected to be equal to the number of input sequences, "
                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
            )
    scale = k.shape[-1] ** -0.5 if scale is None else scale
    o, final_state = ChunkDPLRDeltaRuleFunction.apply(
        q,
        k,
        v,
        a,
        b,
        gk,
        scale,
        initial_state,
        output_final_state,
        cu_seqlens,
    )
    return o, final_state


def chunk_rwkv7(
    r: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    w: torch.Tensor = None,
    log_w: torch.Tensor = None,
    scale: float = 1.0,
    initial_state: torch.Tensor = None,
    output_final_state: bool = True,
):
    """
    Interface function for RWKV-7 attention.
    """

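    # `w` is the raw RWKV-7 decay parameter; the log-space decay consumed by the DPLR
    # kernel is log_w = -exp(w), so the per-channel decay exp(log_w) always lies in (0, 1).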
    if w is not None:
        log_w = -torch.exp(w)
    else:
        assert log_w is not None, "Either w or log_w must be provided!"

    return chunk_dplr_delta_rule(
        q=r,
        k=k,
        v=v,
        a=a,
        b=b,
        gk=log_w,
        scale=scale,
        initial_state=initial_state,
        output_final_state=output_final_state,
    )


def transpose_head(x, head_first):
    if head_first:
        x = torch.permute(x, dims=(0, 2, 1, 3))
    out = cast(x, torch.bfloat16).contiguous()
    return out


def generalized_delta_rule(
    r: torch.Tensor,
    w: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    initial_state: torch.Tensor = None,
    output_final_state: bool = True,
    head_first: bool = False,
):
    dtype = r.dtype
    r = transpose_head(r, head_first)
    k = transpose_head(k, head_first)
    v = transpose_head(v, head_first)
    a = transpose_head(a, head_first)
    b = transpose_head(b, head_first)
    w = transpose_head(w, head_first)
    if w.device.type == "cuda":
        out, state = chunk_rwkv7(
            r=r,
            k=k,
            v=v,
            a=a,
            b=b,
            w=w,
            initial_state=initial_state,
            output_final_state=output_final_state,
        )
    else:
        from .native_keras_op import generalized_delta_rule

        out, state = generalized_delta_rule(
            r=r,
            k=k,
            v=v,
            a=a,
            b=b,
            w=w,
            initial_state=initial_state,
            output_final_state=output_final_state,
        )
    out = transpose_head(out, head_first)
    if output_final_state:
        return out, cast(state, dtype)
    else:
        return out
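For orientation, here is a minimal usage sketch of the public entry point defined above. It is illustrative only and not part of the packaged file; the dummy tensor shapes follow the docstrings, and importing torch_op directly (rather than through whatever rwkv_ops/__init__.py re-exports) is an assumption.

import torch
from rwkv_ops.rwkv7_kernel.torch_op import generalized_delta_rule

B, T, H, K, V = 1, 64, 4, 64, 64
dev, dt = "cuda", torch.bfloat16  # the module warns against float32
r = torch.randn(B, T, H, K, device=dev, dtype=dt)
k = torch.randn(B, T, H, K, device=dev, dtype=dt)
v = torch.randn(B, T, H, V, device=dev, dtype=dt)
a = torch.randn(B, T, H, K, device=dev, dtype=dt)
b = torch.randn(B, T, H, K, device=dev, dtype=dt)
w = torch.randn(B, T, H, K, device=dev, dtype=dt)  # raw decay; mapped to log_w = -exp(w)

# On CUDA this dispatches to chunk_rwkv7 -> chunk_dplr_delta_rule (Triton path);
# on other devices it falls back to the native Keras implementation.
out, state = generalized_delta_rule(
    r=r, w=w, k=k, v=v, a=a, b=b, initial_state=None, output_final_state=True
)
print(out.shape, state.shape)  # (B, T, H, V) and (B, H, K, V)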
rwkv_ops/rwkv7_kernel/triton_kernel/__init__.py
@@ -0,0 +1,34 @@
# ---------- chunk_A ----------
from .chunk_A_bwd import (
    chunk_dplr_bwd_kernel_intra,
    chunk_dplr_bwd_dgk_kernel,
)
from .chunk_A_fwd import chunk_dplr_fwd_A_kernel_intra_sub_intra

# ---------- chunk_h ----------
from .chunk_h_bwd import chunk_dplr_bwd_kernel_dhu
from .chunk_h_fwd import chunk_dplr_fwd_kernel_h

# ---------- chunk_o ----------
from .chunk_o_bwd import (
    chunk_dplr_bwd_kernel_dAu,
    chunk_dplr_bwd_o_kernel,
    chunk_dplr_bwd_kernel_dv,
)
from .chunk_o_fwd import chunk_dplr_fwd_kernel_o

# ---------- cumsum ----------
from .cumsum import chunk_rwkv6_fwd_cumsum_kernel

# ---------- wy_fast ----------
from .wy_fast_bwd import (
    prepare_wy_repr_bwd_kernel,
)
from .wy_fast_fwd import (
    prepare_wy_repr_fwd_kernel_chunk32,
    prepare_wy_repr_fwd_kernel_chunk64,
    wu_fwd_kernel,
)

# ---------- utils ----------
from .utils import *
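This __init__.py simply re-exports the low-level kernels (plus everything in utils) under rwkv_ops.rwkv7_kernel.triton_kernel, so the backend wrappers can import them from one place. A hedged sketch of such an import follows; treating these names as Triton JIT kernels, and the trailing comments, are inferences from the module layout rather than something stated in this diff.

from rwkv_ops.rwkv7_kernel.triton_kernel import (
    chunk_rwkv6_fwd_cumsum_kernel,  # chunked cumulative sum of the log decays (assumed)
    chunk_dplr_fwd_kernel_h,        # chunked hidden-state recurrence, forward (assumed)
)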