rwkv-ops 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rwkv-ops might be problematic. Click here for more details.
- rwkv_ops/__init__.py +26 -0
- rwkv_ops/rwkv7_kernel/__init__.py +153 -0
- rwkv_ops/rwkv7_kernel/get_jax_devices_info.py +221 -0
- rwkv_ops/rwkv7_kernel/get_torch_devices_info.py +250 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/__init__.py +9 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_bwd.py +95 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_fwd.py +60 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_bwd.py +78 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_fwd.py +80 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_bwd.py +150 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_fwd.py +45 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/cumsum.py +34 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_bwd.py +61 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_fwd.py +86 -0
- rwkv_ops/rwkv7_kernel/jax_op.py +382 -0
- rwkv_ops/rwkv7_kernel/native_keras_op.py +95 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/__init__.py +13 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_bwd.py +96 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_fwd.py +64 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_bwd.py +74 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_fwd.py +75 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_bwd.py +148 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_fwd.py +44 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/cumsum.py +31 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_bwd.py +63 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_fwd.py +79 -0
- rwkv_ops/rwkv7_kernel/torch_op.py +523 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/__init__.py +34 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_bwd.py +328 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_fwd.py +186 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_bwd.py +157 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_fwd.py +160 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_bwd.py +382 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_fwd.py +137 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/cumsum.py +86 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/utils.py +20 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_bwd.py +193 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_fwd.py +326 -0
- rwkv_ops-0.1.0.dist-info/LICENSE.txt +201 -0
- rwkv_ops-0.1.0.dist-info/METADATA +118 -0
- rwkv_ops-0.1.0.dist-info/RECORD +43 -0
- rwkv_ops-0.1.0.dist-info/WHEEL +5 -0
- rwkv_ops-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
|
3
|
+
|
|
4
|
+
from typing import Optional, Tuple
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
import triton
|
|
8
|
+
|
|
9
|
+
from ..get_torch_devices_info import check_shared_mem
|
|
10
|
+
from ..triton_kernel.chunk_h_fwd import *
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def chunk_dplr_fwd_h(
    kg: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    u: torch.Tensor,
    bg: torch.Tensor,
    gk: torch.Tensor,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    chunk_size: int = 64,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Allocate output buffers and launch ``chunk_dplr_fwd_kernel_h``.

    Args:
        kg: 4-D tensor whose shape is unpacked as ``(B, T, H, K)``.
        v, w, bg, gk: Tensors forwarded unchanged to the kernel.
        u: Tensor whose last dimension is taken as ``V``; ``v_new`` is
            allocated with the same shape and dtype.
        initial_state: Optional tensor forwarded to the kernel as ``h0``.
        output_final_state: When true, a float32 ``(B, H, K, V)`` buffer is
            allocated and handed to the kernel as ``ht``.
        chunk_size: Upper bound for the time block size ``BT``.

    Returns:
        ``(h, v_new, final_state)``. ``h`` is allocated as
        ``(B, NT, H, K, V)`` with ``kg``'s dtype and device, and
        ``final_state`` is ``None`` unless ``output_final_state`` is set.

    Raises:
        AssertionError: If ``next_power_of_2(K)`` exceeds 256.
    """
    B, T, H, K, V = *kg.shape, u.shape[-1]
    # Time block: a power of two of at least 16, capped by chunk_size.
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
    NT = triton.cdiv(T, BT)

    BK = triton.next_power_of_2(K)
    assert BK <= 256, "current kernel does not support head dimension larger than 256."

    # H100 can have larger block size.
    # NOTE(review): check_shared_mem is defined elsewhere; presumably it reports
    # whether the device offers at least that architecture's shared memory.
    device_index = kg.device.index
    if check_shared_mem("hopper", device_index):
        BV = 64
        BC = 64 if K <= 128 else 32
    elif check_shared_mem("ampere", device_index):  # A100
        BV = 32
        BC = 32
    else:
        BV = 16
        BC = 16
    BC = min(BT, BC)

    NK = triton.cdiv(K, BK)
    NV = triton.cdiv(V, BV)
    assert NK == 1, (
        "NK > 1 is not supported because it involves time-consuming synchronization"
    )

    h = kg.new_empty(B, NT, H, K, V)
    final_state = (
        kg.new_empty(B, H, K, V, dtype=torch.float32) if output_final_state else None
    )
    v_new = torch.empty_like(u)

    grid = (NK, NV, B * H)
    chunk_dplr_fwd_kernel_h[grid](
        kg=kg,
        v=v,
        w=w,
        bg=bg,
        u=u,
        v_new=v_new,
        h=h,
        gk=gk,
        h0=initial_state,
        ht=final_state,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BC=BC,
        BK=BK,
        BV=BV,
    )
    return h, v_new, final_state
|
|
@@ -0,0 +1,148 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
|
3
|
+
|
|
4
|
+
from typing import Tuple
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
import triton
|
|
8
|
+
|
|
9
|
+
from ..get_torch_devices_info import check_shared_mem
|
|
10
|
+
from ..triton_kernel.chunk_o_bwd import *
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def chunk_dplr_bwd_dv(
    A_qk: torch.Tensor,
    kg: torch.Tensor,
    do: torch.Tensor,
    dh: torch.Tensor,
    chunk_size: int = 64,
) -> torch.Tensor:
    """Launch ``chunk_dplr_bwd_kernel_dv`` and return the ``dv`` buffer it fills.

    ``kg`` is unpacked as ``(B, T, H, K)`` and ``V`` is the last dimension of
    ``do``. The result is allocated with ``torch.empty_like(do)``.
    """
    dims = (*kg.shape, do.shape[-1])
    batch, seq_len, heads, key_dim, val_dim = dims

    # Time block: a power of two of at least 16, capped by chunk_size.
    block_t = min(chunk_size, max(16, triton.next_power_of_2(seq_len)))
    num_chunks = triton.cdiv(seq_len, block_t)

    dv = torch.empty_like(do)

    # BV is not passed here; the grid reads it from the launch meta-parameters.
    chunk_dplr_bwd_kernel_dv[
        lambda meta: (
            triton.cdiv(val_dim, meta["BV"]),
            num_chunks,
            batch * heads,
        )
    ](
        A_qk=A_qk,
        kg=kg,
        do=do,
        dv=dv,
        dh=dh,
        T=seq_len,
        H=heads,
        K=key_dim,
        V=val_dim,
        BT=block_t,
    )
    return dv
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def chunk_dplr_bwd_o(
    k: torch.Tensor,
    b: torch.Tensor,
    v: torch.Tensor,
    v_new: torch.Tensor,
    gk: torch.Tensor,
    do: torch.Tensor,
    h: torch.Tensor,
    dh: torch.Tensor,
    dv: torch.Tensor,
    w: torch.Tensor,
    chunk_size: int = 64,
    scale: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Allocate gradient buffers and launch ``chunk_dplr_bwd_o_kernel``.

    Args:
        k, b, v, v_new, gk, do, h, dh, dv: Tensors forwarded unchanged to the
            kernel. ``V`` is taken from the last dimension of ``v``.
        w: 4-D tensor whose shape is unpacked as ``(B, T, H, K)``.
        chunk_size: Upper bound for the time block size ``BT``.
        scale: Accepted for interface compatibility; this wrapper does not
            forward it to the kernel.

    Returns:
        ``(dq, dk, dw, db, dgk_last)``. ``dq`` and ``dk`` are allocated like
        ``k``, ``dw`` like ``w``, ``db`` like ``b``, and ``dgk_last`` is a
        float32 ``(B, NT, H, K)`` tensor on ``w``'s device.
    """
    B, T, H, K, V = *w.shape, v.shape[-1]

    # Time block: a power of two of at least 16, capped by chunk_size.
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
    NT = triton.cdiv(T, BT)

    # One capability query decides the cap for both feature block sizes.
    # The V block is derived from V; the earlier fallback used K by mistake.
    block_cap = 64 if check_shared_mem() else 32
    BK = min(triton.next_power_of_2(K), block_cap)
    BV = min(triton.next_power_of_2(V), block_cap)
    NK = triton.cdiv(K, BK)

    dq = torch.empty_like(k)
    dk = torch.empty_like(k)
    dw = torch.empty_like(w)
    db = torch.empty_like(b)
    dgk_last = torch.empty(B, NT, H, K, dtype=torch.float, device=w.device)

    grid = (NK, NT, B * H)
    chunk_dplr_bwd_o_kernel[grid](
        k=k,
        b=b,
        v=v,
        v_new=v_new,
        h=h,
        do=do,
        dh=dh,
        dq=dq,
        dk=dk,
        db=db,
        dgk_last=dgk_last,
        w=w,
        dv=dv,
        dw=dw,
        gk=gk,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
    )
    return dq, dk, dw, db, dgk_last
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def chunk_dplr_bwd_dAu(
    v: torch.Tensor,
    v_new: torch.Tensor,
    do: torch.Tensor,
    A_qb: torch.Tensor,
    scale: float,
    chunk_size: int = 64,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Allocate gradient buffers and launch ``chunk_dplr_bwd_kernel_dAu``.

    Args:
        v: 4-D tensor whose shape is unpacked as ``(B, T, H, V)``.
        v_new, do, A_qb: Tensors forwarded unchanged to the kernel.
        scale: Scalar forwarded to the kernel.
        chunk_size: Upper bound for the time block size ``BT``.

    Returns:
        ``(dv_new, dA_qk, dA_qb)``. ``dv_new`` is allocated like ``v_new``;
        ``dA_qk`` and ``dA_qb`` are float32 ``(B, T, H, BT)`` tensors on
        ``v``'s device.
    """
    B, T, H, V = v.shape

    # Time block: a power of two of at least 16, capped by chunk_size.
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
    NT = triton.cdiv(T, BT)

    # Choose the cap for the V block from the first capability tier that matches.
    if check_shared_mem("ampere"):  # A100
        bv_cap = 128
    elif check_shared_mem("ada"):  # 4090
        bv_cap = 64
    else:
        bv_cap = 32
    BV = min(triton.next_power_of_2(V), bv_cap)

    grid = (NT, B * H)
    dA_qk = torch.empty(B, T, H, BT, dtype=torch.float, device=v.device)
    dA_qb = torch.empty(B, T, H, BT, dtype=torch.float, device=v.device)
    dv_new = torch.empty_like(v_new)
    chunk_dplr_bwd_kernel_dAu[grid](
        v=v,
        do=do,
        v_new=v_new,
        A_qb=A_qb,
        dA_qk=dA_qk,
        dA_qb=dA_qb,
        dv_new=dv_new,
        scale=scale,
        T=T,
        H=H,
        V=V,
        BT=BT,
        BV=BV,
    )
    return dv_new, dA_qk, dA_qb
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
import triton
|
|
7
|
+
|
|
8
|
+
from ..triton_kernel.chunk_o_fwd import *
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def chunk_dplr_fwd_o(
    qg: torch.Tensor,
    v: torch.Tensor,
    v_new: torch.Tensor,
    A_qk: torch.Tensor,
    A_qb: torch.Tensor,
    h: torch.Tensor,
    chunk_size: int = 64,
) -> torch.Tensor:
    """Launch ``chunk_dplr_fwd_kernel_o`` and return the output buffer ``o``.

    ``qg`` is unpacked as ``(B, T, H, K)`` and ``V`` is the last dimension of
    ``v``. The output is allocated with ``torch.empty_like(v)``.
    """
    dims = (*qg.shape, v.shape[-1])
    batch, seq_len, heads, key_dim, val_dim = dims

    # Time block: a power of two of at least 16, capped by chunk_size.
    block_t = min(chunk_size, max(16, triton.next_power_of_2(seq_len)))
    num_chunks = triton.cdiv(seq_len, block_t)

    o = torch.empty_like(v)

    def launch_grid(meta):
        # BV is not passed here; it is read from the launch meta-parameters.
        v_blocks = triton.cdiv(val_dim, meta["BV"])
        return (v_blocks, num_chunks, batch * heads)

    kernel_args = dict(
        qg=qg,
        v=v,
        v_new=v_new,
        A_qk=A_qk,
        A_qb=A_qb,
        h=h,
        o=o,
        T=seq_len,
        H=heads,
        K=key_dim,
        V=val_dim,
        BT=block_t,
    )
    chunk_dplr_fwd_kernel_o[launch_grid](**kernel_args)
    return o
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
from typing import Tuple

import torch
import triton

from ..triton_kernel.cumsum import *
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def chunk_rwkv6_fwd_cumsum(
    g: torch.Tensor,
    chunk_size: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Launch ``chunk_rwkv6_fwd_cumsum_kernel`` over ``g`` in time chunks.

    Args:
        g: 4-D tensor whose shape is unpacked as ``(B, T, H, S)``.
        chunk_size: Time block size ``BT`` passed to the kernel as-is.

    Returns:
        ``(gi, ge)``: two float32 tensors with ``g``'s shape, filled by the
        kernel.
    """
    B, T, H, S = g.shape
    BT = chunk_size
    NT = triton.cdiv(T, BT)

    # Keep the cumulative normalizer in fp32 regardless of g's dtype.
    gi = torch.empty_like(g, dtype=torch.float)
    ge = torch.empty_like(g, dtype=torch.float)

    def grid(meta):
        # BS is not passed here; it is read from the launch meta-parameters.
        return (triton.cdiv(meta["S"], meta["BS"]), NT, B * H)

    chunk_rwkv6_fwd_cumsum_kernel[grid](
        g,
        T,
        gi,
        ge,
        H=H,
        S=S,
        BT=BT,
    )
    return gi, ge
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
|
3
|
+
|
|
4
|
+
from typing import Tuple
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
import triton
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
from ..get_torch_devices_info import check_shared_mem
|
|
11
|
+
from ..triton_kernel.wy_fast_bwd import *
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def chunk_dplr_bwd_wy(
    A_ab_inv: torch.Tensor,
    A_ak: torch.Tensor,
    v: torch.Tensor,
    ag: torch.Tensor,
    dw: torch.Tensor,
    du: torch.Tensor,
    dv0: torch.Tensor,
    chunk_size: int = 16,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Allocate gradient buffers and launch ``prepare_wy_repr_bwd_kernel``.

    Args:
        A_ab_inv, A_ak, v, ag: Tensors made contiguous, then forwarded to the
            kernel.
        dw: Made contiguous; its shape is unpacked as ``(B, T, H, K)``.
        du: Made contiguous; its last dimension is taken as ``V``.
        dv0: Forwarded to the kernel as-is.
        chunk_size: Upper bound for the time block size ``BT``.

    Returns:
        ``(dA_ab, dA_ak, dv, dag)``. ``dA_ab`` and ``dA_ak`` are float32
        tensors shaped like ``A_ab_inv`` and ``A_ak``; ``dv`` and ``dag`` are
        allocated like ``v`` and ``ag``.
    """
    # NOTE(review): dv0 is not made contiguous here, unlike the other tensor
    # inputs; confirm callers always pass a contiguous dv0.
    A_ab_inv, A_ak, v, ag, dw, du = (
        t.contiguous() for t in (A_ab_inv, A_ak, v, ag, dw, du)
    )
    B, T, H, K, V = *dw.shape, du.shape[-1]

    # Time block: a power of two of at least 16, capped by chunk_size.
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
    NT = triton.cdiv(T, BT)

    BK = min(triton.next_power_of_2(K), 64)
    # The V block is capped at 64 when check_shared_mem() passes, else at 32.
    BV = min(triton.next_power_of_2(V), 64 if check_shared_mem() else 32)

    dA_ab = torch.empty_like(A_ab_inv, dtype=torch.float)
    dA_ak = torch.empty_like(A_ak, dtype=torch.float)
    dv = torch.empty_like(v)
    dag = torch.empty_like(ag)

    prepare_wy_repr_bwd_kernel[(NT, B * H)](
        A_ab_inv=A_ab_inv,
        A_ak=A_ak,
        ag=ag,
        v=v,
        dw=dw,
        du=du,
        dv=dv,
        dv0=dv0,
        dag=dag,
        dAak=dA_ak,
        dAab=dA_ab,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
    )
    return dA_ab, dA_ak, dv, dag
|
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
|
3
|
+
|
|
4
|
+
from typing import Tuple
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
import triton
|
|
8
|
+
|
|
9
|
+
from ..triton_kernel.wy_fast_fwd import *
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def wu_fwd(
    ag: torch.Tensor,
    v: torch.Tensor,
    A_ak: torch.Tensor,
    A_ab_inv: torch.Tensor,
    chunk_size: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Launch ``wu_fwd_kernel`` and return the ``(w, u)`` buffers it fills.

    ``ag`` is unpacked as ``(B, T, H, K)`` and ``V`` is the last dimension of
    ``v``. ``w`` is allocated like ``ag`` and ``u`` like ``v``.
    """
    dims = (*ag.shape, v.shape[-1])
    batch, seq_len, heads, key_dim, val_dim = dims

    # Time block: a power of two of at least 16, capped by chunk_size.
    block_t = min(chunk_size, max(triton.next_power_of_2(seq_len), 16))
    num_chunks = triton.cdiv(seq_len, block_t)

    # Both feature block sizes are powers of two capped at 64.
    block_k = min(triton.next_power_of_2(key_dim), 64)
    block_v = min(triton.next_power_of_2(val_dim), 64)

    w = torch.empty_like(ag)
    u = torch.empty_like(v)

    launch_grid = (num_chunks, batch * heads)
    wu_fwd_kernel[launch_grid](
        ag=ag,
        v=v,
        A_ak=A_ak,
        A_ab_inv=A_ab_inv,
        w=w,
        u=u,
        T=seq_len,
        H=heads,
        K=key_dim,
        V=val_dim,
        BT=block_t,
        BK=block_k,
        BV=block_v,
    )
    return w, u
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def prepare_wy_repr_fwd(
    ag: torch.Tensor,
    v: torch.Tensor,
    A_ak: torch.Tensor,
    A_ab: torch.Tensor,
    chunk_size: int = 64,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Fill ``A_ab_inv`` from ``A_ab`` with a chunk kernel, then run ``wu_fwd``.

    The ``chunk64`` kernel variant is used when the time block size is exactly
    64; every other block size uses the ``chunk32`` variant.

    Returns:
        ``(w, u, A_ab_inv)``: ``w`` and ``u`` come from ``wu_fwd``, and
        ``A_ab_inv`` is allocated like ``A_ab``.
    """
    batch, seq_len, heads, _ = ag.shape

    # Time block: a power of two of at least 16, capped by chunk_size.
    block_t = min(chunk_size, max(triton.next_power_of_2(seq_len), 16))
    num_chunks = triton.cdiv(seq_len, block_t)
    sub_block = min(block_t, 32)

    if block_t == 64:
        selected_kernel = prepare_wy_repr_fwd_kernel_chunk64
    else:
        selected_kernel = prepare_wy_repr_fwd_kernel_chunk32

    A_ab_inv = torch.empty_like(A_ab)
    launch_grid = (num_chunks, batch * heads)
    selected_kernel[launch_grid](
        A_ab=A_ab,
        A_ab_inv=A_ab_inv,
        T=seq_len,
        H=heads,
        BT=block_t,
        BC=sub_block,
    )

    # Pass the resolved block size so wu_fwd derives the same chunking.
    w, u = wu_fwd(
        ag=ag,
        v=v,
        A_ak=A_ak,
        A_ab_inv=A_ab_inv,
        chunk_size=block_t,
    )
    return w, u, A_ab_inv
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
# Alternate module-level names bound to the two entry points defined above.
fwd_prepare_wy_repr = prepare_wy_repr_fwd
fwd_wu = wu_fwd
|