rwkv-ops 0.6.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rwkv_ops/__init__.py +45 -0
- rwkv_ops/mhc_kernel/__init__.py +50 -0
- rwkv_ops/mhc_kernel/common_kernel/include/mhc_types.h +66 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_post_op.cuh +197 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_pre_op.cuh +212 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/rmsnorm.cuh +152 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/sinkhorn_knopp.cuh +158 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/stream_aggregate.cuh +141 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/stream_distribute.cuh +111 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/stream_mix.cuh +164 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/type_conversions.cuh +52 -0
- rwkv_ops/mhc_kernel/jax_kernel/CMakeLists.txt +47 -0
- rwkv_ops/mhc_kernel/jax_kernel/mhu_ffi.cu +652 -0
- rwkv_ops/mhc_kernel/jax_kernel/mhu_jax.py +939 -0
- rwkv_ops/mhc_kernel/native_keras_op.py +193 -0
- rwkv_ops/mhc_kernel/torch_kernel/mhc_cuda.cu +207 -0
- rwkv_ops/mhc_kernel/torch_kernel/mhc_op.cpp +296 -0
- rwkv_ops/mhc_kernel/torch_kernel/mhc_torch.py +306 -0
- rwkv_ops/rwkv6_kernel/__init__.py +120 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/gpu_ops.cpp +44 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernel_helpers.h +64 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernels.h +56 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/pybind11_kernel_helpers.h +41 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/rwkv_kernels.cu +512 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/gpu_ops.cpp +44 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernel_helpers.h +64 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernels.h +56 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/pybind11_kernel_helpers.h +41 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/rwkv_kernels.hip +514 -0
- rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py +722 -0
- rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py +90 -0
- rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_cuda.cu +397 -0
- rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_op.cpp +93 -0
- rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py +305 -0
- rwkv_ops/rwkv7_kernel/__init__.py +113 -0
- rwkv_ops/rwkv7_kernel/get_jax_devices_info.py +220 -0
- rwkv_ops/rwkv7_kernel/get_torch_devices_info.py +250 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel/CMakeLists.txt +42 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_ffi.cu +399 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_jax.py +311 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/CMakeLists.txt +42 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_ffi.cu +172 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_jax.py +190 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/__init__.py +9 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_bwd.py +95 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_fwd.py +60 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_bwd.py +78 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_fwd.py +80 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_bwd.py +150 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_fwd.py +45 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/cumsum.py +34 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_bwd.py +61 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_fwd.py +86 -0
- rwkv_ops/rwkv7_kernel/jax_op.py +382 -0
- rwkv_ops/rwkv7_kernel/mlx_op.py +118 -0
- rwkv_ops/rwkv7_kernel/native_keras_op.py +108 -0
- rwkv_ops/rwkv7_kernel/tf_eager_kernel.py +155 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_cuda.cu +235 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_op.cpp +63 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_torch.py +233 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_cuda.cu +101 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_op.cpp +56 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_torch.py +112 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/__init__.py +13 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_bwd.py +96 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_fwd.py +64 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_bwd.py +74 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_fwd.py +75 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_bwd.py +148 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_fwd.py +44 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/cumsum.py +31 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_bwd.py +63 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_fwd.py +79 -0
- rwkv_ops/rwkv7_kernel/torch_op.py +504 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/__init__.py +34 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_bwd.py +328 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_fwd.py +186 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_bwd.py +157 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_fwd.py +160 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_bwd.py +382 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_fwd.py +137 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/cumsum.py +86 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/utils.py +20 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_bwd.py +193 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_fwd.py +326 -0
- rwkv_ops-0.6.1.dist-info/METADATA +495 -0
- rwkv_ops-0.6.1.dist-info/RECORD +89 -0
- rwkv_ops-0.6.1.dist-info/WHEEL +4 -0
- rwkv_ops-0.6.1.dist-info/licenses/LICENSE.txt +201 -0
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
|
3
|
+
|
|
4
|
+
from typing import Optional, Tuple
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
import triton
|
|
8
|
+
|
|
9
|
+
from ..get_torch_devices_info import check_shared_mem
|
|
10
|
+
from ..triton_kernel.chunk_h_bwd import *
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def chunk_dplr_bwd_dhu(
    qg: torch.Tensor,
    bg: torch.Tensor,
    w: torch.Tensor,
    gk: torch.Tensor,
    h0: Optional[torch.Tensor],
    dht: Optional[torch.Tensor],
    do: torch.Tensor,
    dv: torch.Tensor,
    chunk_size: int = 64,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
    """Launch ``chunk_dplr_bwd_kernel_dhu`` (chunked DPLR hidden-state backward).

    Args:
        qg, bg, w, gk: tensors forwarded unchanged to the kernel. ``qg`` is
            unpacked as ``(B, T, H, K)``.
        h0: initial state of the forward pass, or ``None``. It is only used
            here to decide whether ``dh0`` is allocated, and with what shape.
        dht: gradient w.r.t. the final state, or ``None``. Forwarded as-is.
        do: output gradient. Its last dimension is taken as ``V``.
        dv: passed to the kernel as ``dv``. ``dv2`` is allocated like it.
        chunk_size: upper bound for the time-chunk length ``BT``.

    Returns:
        ``(dh, dh0, dv2)``:
          * ``dh`` is a ``(B, NT, H, K, V)`` tensor with ``qg``'s dtype.
          * ``dh0`` is a float32 tensor shaped like ``h0``, or ``None`` when
            ``h0`` is ``None``.
          * ``dv2`` is a zero-initialised tensor shaped/typed like ``dv``.
        All three are filled by the kernel.

    Raises:
        AssertionError: if the padded head dimension ``K`` exceeds 256.
    """
    B, T, H, K, V = *qg.shape, do.shape[-1]
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
    BK = triton.next_power_of_2(K)
    assert BK <= 256, (
        "current kernel does not support head dimension being larger than 256."
    )

    # Tile sizes depend on the shared memory of the device holding the inputs.
    if check_shared_mem("hopper", qg.device.index):  # H100
        BV = 64
        BC = 64 if K <= 128 else 32
    elif check_shared_mem("ampere", qg.device.index):  # A100
        BV = 32
        BC = 32
    else:  # Etc: 4090
        BV = 16
        BC = 16

    N, NT = B, triton.cdiv(T, BT)

    # The sub-chunk can never be longer than the chunk itself.
    BC = min(BT, BC)
    NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
    # BK is the next power of two >= K, so a single K-block covers the head dim.
    assert NK == 1, (
        "NK > 1 is not supported because it involves time-consuming synchronization"
    )

    dh = qg.new_empty(B, NT, H, K, V)
    dh0 = torch.empty_like(h0, dtype=torch.float32) if h0 is not None else None
    # Zero-initialised rather than empty; presumably the kernel does not write
    # every element unconditionally. NOTE(review): confirm against the kernel.
    dv2 = torch.zeros_like(dv)

    grid = (NK, NV, N * H)
    chunk_dplr_bwd_kernel_dhu[grid](
        qg=qg,
        bg=bg,
        w=w,
        gk=gk,
        dht=dht,
        dh0=dh0,
        do=do,
        dh=dh,
        dv=dv,
        dv2=dv2,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BC=BC,
        BK=BK,
        BV=BV,
    )
    return dh, dh0, dv2
|
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
|
3
|
+
|
|
4
|
+
from typing import Optional, Tuple
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
import triton
|
|
8
|
+
|
|
9
|
+
from ..get_torch_devices_info import check_shared_mem
|
|
10
|
+
from ..triton_kernel.chunk_h_fwd import *
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def chunk_dplr_fwd_h(
    kg: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    u: torch.Tensor,
    bg: torch.Tensor,
    gk: torch.Tensor,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    chunk_size: int = 64,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Launch ``chunk_dplr_fwd_kernel_h`` (chunked DPLR hidden-state forward).

    Args:
        kg, v, w, u, bg, gk: tensors forwarded unchanged to the kernel.
            ``kg`` is unpacked as ``(B, T, H, K)``, and ``u``'s last
            dimension is taken as ``V``.
        initial_state: optional initial state, forwarded as ``h0``.
        output_final_state: when True, a float32 ``(B, H, K, V)`` buffer is
            allocated and passed as ``ht``. Otherwise ``None`` is passed.
        chunk_size: upper bound for the time-chunk length ``BT``.

    Returns:
        ``(h, v_new, final_state)``:
          * ``h`` is a ``(B, NT, H, K, V)`` tensor with ``kg``'s dtype.
          * ``v_new`` is shaped/typed like ``u``.
          * ``final_state`` is the ``ht`` buffer, or ``None``.
        All are filled by the kernel.

    Raises:
        AssertionError: if the padded head dimension ``K`` exceeds 256.
    """
    B, T, H, K, V = *kg.shape, u.shape[-1]
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))

    N, NT = B, triton.cdiv(T, BT)
    BK = triton.next_power_of_2(K)
    assert BK <= 256, "current kernel does not support head dimension larger than 256."

    # H100 can have larger block size
    if check_shared_mem("hopper", kg.device.index):
        BV = 64
        BC = 64 if K <= 128 else 32
    elif check_shared_mem("ampere", kg.device.index):  # A100
        BV = 32
        BC = 32
    else:
        BV = 16
        BC = 16

    # The sub-chunk can never be longer than the chunk itself.
    BC = min(BT, BC)
    NK = triton.cdiv(K, BK)
    NV = triton.cdiv(V, BV)
    # BK is the next power of two >= K, so a single K-block covers the head dim.
    assert NK == 1, (
        "NK > 1 is not supported because it involves time-consuming synchronization"
    )

    h = kg.new_empty(B, NT, H, K, V)
    final_state = (
        kg.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None
    )
    v_new = torch.empty_like(u)
    grid = (NK, NV, N * H)
    chunk_dplr_fwd_kernel_h[grid](
        kg=kg,
        v=v,
        w=w,
        bg=bg,
        u=u,
        v_new=v_new,
        h=h,
        gk=gk,
        h0=initial_state,
        ht=final_state,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BC=BC,
        BK=BK,
        BV=BV,
    )
    return h, v_new, final_state
|
|
@@ -0,0 +1,148 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
|
3
|
+
|
|
4
|
+
from typing import Tuple
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
import triton
|
|
8
|
+
|
|
9
|
+
from ..get_torch_devices_info import check_shared_mem
|
|
10
|
+
from ..triton_kernel.chunk_o_bwd import *
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def chunk_dplr_bwd_dv(
    A_qk: torch.Tensor,
    kg: torch.Tensor,
    do: torch.Tensor,
    dh: torch.Tensor,
    chunk_size: int = 64,
) -> torch.Tensor:
    """Launch ``chunk_dplr_bwd_kernel_dv`` and return the value gradient.

    ``kg`` is unpacked as ``(B, T, H, K)``, and ``do``'s last dimension gives
    ``V``. The returned tensor is allocated like ``do`` and filled by the
    kernel. The value-block size ``BV`` is chosen by the kernel's meta
    parameters, which is why the launch grid is a callable.
    """
    batch, seq_len, n_heads, k_dim = kg.shape
    v_dim = do.shape[-1]
    block_t = min(chunk_size, max(16, triton.next_power_of_2(seq_len)))
    n_chunks = triton.cdiv(seq_len, block_t)

    grad_v = torch.empty_like(do)

    def launch_grid(meta):
        # (value blocks, time chunks, batch * heads)
        return (triton.cdiv(v_dim, meta["BV"]), n_chunks, batch * n_heads)

    chunk_dplr_bwd_kernel_dv[launch_grid](
        A_qk=A_qk,
        kg=kg,
        do=do,
        dv=grad_v,
        dh=dh,
        T=seq_len,
        H=n_heads,
        K=k_dim,
        V=v_dim,
        BT=block_t,
    )
    return grad_v
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def chunk_dplr_bwd_o(
    k: torch.Tensor,
    b: torch.Tensor,
    v: torch.Tensor,
    v_new: torch.Tensor,
    gk: torch.Tensor,
    do: torch.Tensor,
    h: torch.Tensor,
    dh: torch.Tensor,
    dv: torch.Tensor,
    w: torch.Tensor,
    chunk_size: int = 64,
    scale: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Launch ``chunk_dplr_bwd_o_kernel`` and return the gradients it produces.

    Args:
        k, b, v, v_new, gk, do, h, dh, dv, w: tensors forwarded to the kernel.
            ``w`` is unpacked as ``(B, T, H, K)``, and ``v``'s last dimension
            is taken as ``V``.
        chunk_size: upper bound for the time-chunk length ``BT``.
        scale: accepted for interface compatibility. It is NOT forwarded to
            the kernel by this function.

    Returns:
        ``(dq, dk, dw, db, dgk_last)``:
          * ``dq`` and ``dk`` are allocated like ``k``.
          * ``dw`` is allocated like ``w``.
          * ``db`` is allocated like ``b``.
          * ``dgk_last`` is a float32 ``(B, NT, H, K)`` tensor.
        All are filled by the kernel.
    """
    B, T, H, K, V = *w.shape, v.shape[-1]

    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
    NT = triton.cdiv(T, BT)

    # Larger tiles only when the device reports enough shared memory. The
    # value tile is sized from V (previously the fallback mistakenly used K).
    max_tile = 64 if check_shared_mem() else 32
    BK = min(triton.next_power_of_2(K), max_tile)
    BV = min(triton.next_power_of_2(V), max_tile)
    NK = triton.cdiv(K, BK)

    dq = torch.empty_like(k)
    dk = torch.empty_like(k)
    dw = torch.empty_like(w)
    db = torch.empty_like(b)
    grid = (NK, NT, B * H)

    dgk_last = torch.empty(B, NT, H, K, dtype=torch.float, device=w.device)

    chunk_dplr_bwd_o_kernel[grid](
        k=k,
        b=b,
        v=v,
        v_new=v_new,
        h=h,
        do=do,
        dh=dh,
        dq=dq,
        dk=dk,
        db=db,
        dgk_last=dgk_last,
        w=w,
        dv=dv,
        dw=dw,
        gk=gk,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
    )
    return (dq, dk, dw, db, dgk_last)
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def chunk_dplr_bwd_dAu(
    v: torch.Tensor,
    v_new: torch.Tensor,
    do: torch.Tensor,
    A_qb: torch.Tensor,
    scale: float,
    chunk_size: int = 64,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Launch ``chunk_dplr_bwd_kernel_dAu`` and return its three outputs.

    Args:
        v: unpacked as ``(B, T, H, V)``.
        v_new, do, A_qb: forwarded unchanged to the kernel.
        scale: forwarded to the kernel as ``scale``.
        chunk_size: upper bound for the time-chunk length ``BT``.

    Returns:
        ``(dv_new, dA_qk, dA_qb)``:
          * ``dv_new`` is allocated like ``v_new``.
          * ``dA_qk`` and ``dA_qb`` are float32 ``(B, T, H, BT)`` tensors.
        All are filled by the kernel.
    """
    B, T, H, V = v.shape
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
    NT = triton.cdiv(T, BT)

    # Query the device that actually holds the inputs, as the sibling
    # launchers do, instead of the helper's default device.
    device_idx = v.device.index
    if check_shared_mem("ampere", device_idx):  # A100
        BV = min(triton.next_power_of_2(V), 128)
    elif check_shared_mem("ada", device_idx):  # 4090
        BV = min(triton.next_power_of_2(V), 64)
    else:
        BV = min(triton.next_power_of_2(V), 32)

    grid = (NT, B * H)
    dA_qk = torch.empty(B, T, H, BT, dtype=torch.float, device=v.device)
    dA_qb = torch.empty(B, T, H, BT, dtype=torch.float, device=v.device)
    dv_new = torch.empty_like(v_new)
    chunk_dplr_bwd_kernel_dAu[grid](
        v=v,
        do=do,
        v_new=v_new,
        A_qb=A_qb,
        dA_qk=dA_qk,
        dA_qb=dA_qb,
        dv_new=dv_new,
        scale=scale,
        T=T,
        H=H,
        V=V,
        BT=BT,
        BV=BV,
    )
    return dv_new, dA_qk, dA_qb
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
import triton
|
|
7
|
+
|
|
8
|
+
from ..triton_kernel.chunk_o_fwd import *
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def chunk_dplr_fwd_o(
    qg: torch.Tensor,
    v: torch.Tensor,
    v_new: torch.Tensor,
    A_qk: torch.Tensor,
    A_qb: torch.Tensor,
    h: torch.Tensor,
    chunk_size: int = 64,
) -> torch.Tensor:
    """Launch ``chunk_dplr_fwd_kernel_o`` and return the output tensor.

    ``qg`` is unpacked as ``(B, T, H, K)``, and ``v``'s last dimension gives
    ``V``. The result is allocated like ``v`` and filled by the kernel. The
    value-block size ``BV`` comes from the kernel's meta parameters, hence
    the callable grid.
    """
    batch, seq_len, n_heads, k_dim = qg.shape
    v_dim = v.shape[-1]
    block_t = min(chunk_size, max(16, triton.next_power_of_2(seq_len)))
    n_chunks = triton.cdiv(seq_len, block_t)

    out = torch.empty_like(v)

    def launch_grid(meta):
        # (value blocks, time chunks, batch * heads)
        return (triton.cdiv(v_dim, meta["BV"]), n_chunks, batch * n_heads)

    chunk_dplr_fwd_kernel_o[launch_grid](
        qg=qg,
        v=v,
        v_new=v_new,
        A_qk=A_qk,
        A_qb=A_qb,
        h=h,
        o=out,
        T=seq_len,
        H=n_heads,
        K=k_dim,
        V=v_dim,
        BT=block_t,
    )
    return out
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
from typing import Tuple

import torch
import triton

from ..triton_kernel.cumsum import *
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def chunk_rwkv6_fwd_cumsum(
    g: torch.Tensor,
    chunk_size: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Launch ``chunk_rwkv6_fwd_cumsum_kernel`` over ``g`` chunk by chunk.

    Args:
        g: unpacked as ``(B, T, H, S)``.
        chunk_size: chunk length ``BT`` along the time axis. It is used as-is,
            without power-of-two rounding.

    Returns:
        ``(gi, ge)``: two float32 tensors shaped like ``g``, both filled by
        the kernel. Presumably these are the inclusive and exclusive chunk-local
        cumulative sums (NOTE(review): confirm against the kernel).
    """
    B, T, H, S = g.shape
    BT = chunk_size
    NT = triton.cdiv(T, BT)

    # Keep the cumulative normalizer in fp32 regardless of g's dtype.
    gi = torch.empty_like(g, dtype=torch.float)
    ge = torch.empty_like(g, dtype=torch.float)

    def grid(meta):
        # (S blocks, time chunks, batch * heads); BS is chosen by the kernel.
        return (triton.cdiv(meta["S"], meta["BS"]), NT, B * H)

    chunk_rwkv6_fwd_cumsum_kernel[grid](
        g,
        T,
        gi,
        ge,
        H=H,
        S=S,
        BT=BT,
    )
    return gi, ge
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
|
3
|
+
|
|
4
|
+
from typing import Tuple
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
import triton
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
from ..get_torch_devices_info import check_shared_mem
|
|
11
|
+
from ..triton_kernel.wy_fast_bwd import *
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def chunk_dplr_bwd_wy(
    A_ab_inv: torch.Tensor,
    A_ak: torch.Tensor,
    v: torch.Tensor,
    ag: torch.Tensor,
    dw: torch.Tensor,
    du: torch.Tensor,
    dv0: torch.Tensor,
    chunk_size: int = 16,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Launch ``prepare_wy_repr_bwd_kernel`` (backward of the WY representation).

    Args:
        A_ab_inv, A_ak, v, ag, dw, du: made contiguous, then forwarded to the
            kernel. ``dw`` is unpacked as ``(B, T, H, K)``, and ``du``'s last
            dimension is taken as ``V``.
        dv0: forwarded as-is (not made contiguous here).
        chunk_size: upper bound for the time-chunk length ``BT``.

    Returns:
        ``(dA_ab, dA_ak, dv, dag)``:
          * ``dA_ab`` and ``dA_ak`` are float32 tensors shaped like
            ``A_ab_inv`` and ``A_ak``.
          * ``dv`` is allocated like ``v``.
          * ``dag`` is allocated like ``ag``.
        All are filled by the kernel.
    """
    A_ab_inv, A_ak, v, ag, dw, du = (
        x.contiguous() for x in (A_ab_inv, A_ak, v, ag, dw, du)
    )
    B, T, H, K, V = *dw.shape, du.shape[-1]
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))

    NT = triton.cdiv(T, BT)
    BK = min(triton.next_power_of_2(K), 64)
    # Wider value tile only when the device reports enough shared memory.
    BV = (
        min(triton.next_power_of_2(V), 64)
        if check_shared_mem()
        else min(triton.next_power_of_2(V), 32)
    )

    dA_ab = torch.empty_like(A_ab_inv, dtype=torch.float)
    dA_ak = torch.empty_like(A_ak, dtype=torch.float)
    dv = torch.empty_like(v)
    dag = torch.empty_like(ag)

    prepare_wy_repr_bwd_kernel[(NT, B * H)](
        A_ab_inv=A_ab_inv,
        A_ak=A_ak,
        ag=ag,
        v=v,
        dw=dw,
        du=du,
        dv=dv,
        dv0=dv0,
        dag=dag,
        dAak=dA_ak,
        dAab=dA_ab,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
    )
    return dA_ab, dA_ak, dv, dag
|
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
|
3
|
+
|
|
4
|
+
from typing import Tuple
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
import triton
|
|
8
|
+
|
|
9
|
+
from ..triton_kernel.wy_fast_fwd import *
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def wu_fwd(
    ag: torch.Tensor,
    v: torch.Tensor,
    A_ak: torch.Tensor,
    A_ab_inv: torch.Tensor,
    chunk_size: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Launch ``wu_fwd_kernel`` and return the ``(w, u)`` pair it fills.

    ``ag`` is unpacked as ``(B, T, H, K)``, and ``v``'s last dimension gives
    ``V``. ``w`` is allocated like ``ag`` and ``u`` like ``v``. Both key and
    value tiles are capped at 64.
    """
    batch, seq_len, n_heads, k_dim = ag.shape
    v_dim = v.shape[-1]
    block_t = min(chunk_size, max(triton.next_power_of_2(seq_len), 16))
    n_chunks = triton.cdiv(seq_len, block_t)
    block_k = min(triton.next_power_of_2(k_dim), 64)
    block_v = min(triton.next_power_of_2(v_dim), 64)

    w_out = torch.empty_like(ag)
    u_out = torch.empty_like(v)
    launch_grid = (n_chunks, batch * n_heads)
    wu_fwd_kernel[launch_grid](
        ag=ag,
        v=v,
        A_ak=A_ak,
        A_ab_inv=A_ab_inv,
        w=w_out,
        u=u_out,
        T=seq_len,
        H=n_heads,
        K=k_dim,
        V=v_dim,
        BT=block_t,
        BK=block_k,
        BV=block_v,
    )
    return w_out, u_out
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def prepare_wy_repr_fwd(
    ag: torch.Tensor,
    v: torch.Tensor,
    A_ak: torch.Tensor,
    A_ab: torch.Tensor,
    chunk_size: int = 64,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute ``A_ab_inv`` with a chunk kernel, then ``(w, u)`` via ``wu_fwd``.

    ``ag`` must be 4-D. Its first three dimensions are read as
    ``(B, T, H)``. The ``*_chunk64`` kernel is used when the effective chunk
    length is 64; otherwise the ``*_chunk32`` kernel is used. ``A_ab_inv``
    is allocated like ``A_ab`` and filled by that kernel (presumably the
    per-chunk inverse, as the name suggests).

    Returns:
        ``(w, u, A_ab_inv)``.
    """
    batch, seq_len, n_heads, _unused = ag.shape
    block_t = min(chunk_size, max(triton.next_power_of_2(seq_len), 16))

    n_chunks = triton.cdiv(seq_len, block_t)
    block_c = min(block_t, 32)
    if block_t == 64:
        inverse_kernel = prepare_wy_repr_fwd_kernel_chunk64
    else:
        inverse_kernel = prepare_wy_repr_fwd_kernel_chunk32

    A_ab_inv = torch.empty_like(A_ab)
    inverse_kernel[(n_chunks, batch * n_heads)](
        A_ab=A_ab,
        A_ab_inv=A_ab_inv,
        T=seq_len,
        H=n_heads,
        BT=block_t,
        BC=block_c,
    )
    w, u = wu_fwd(ag=ag, v=v, A_ak=A_ak, A_ab_inv=A_ab_inv, chunk_size=block_t)
    return w, u, A_ab_inv
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
# Alternate names: fwd_prepare_wy_repr -> prepare_wy_repr_fwd, fwd_wu -> wu_fwd.
fwd_prepare_wy_repr, fwd_wu = prepare_wy_repr_fwd, wu_fwd
|