rwkv-ops 0.6.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rwkv_ops/__init__.py +45 -0
- rwkv_ops/mhc_kernel/__init__.py +50 -0
- rwkv_ops/mhc_kernel/common_kernel/include/mhc_types.h +66 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_post_op.cuh +197 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_pre_op.cuh +212 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/rmsnorm.cuh +152 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/sinkhorn_knopp.cuh +158 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/stream_aggregate.cuh +141 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/stream_distribute.cuh +111 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/stream_mix.cuh +164 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/type_conversions.cuh +52 -0
- rwkv_ops/mhc_kernel/jax_kernel/CMakeLists.txt +47 -0
- rwkv_ops/mhc_kernel/jax_kernel/mhu_ffi.cu +652 -0
- rwkv_ops/mhc_kernel/jax_kernel/mhu_jax.py +939 -0
- rwkv_ops/mhc_kernel/native_keras_op.py +193 -0
- rwkv_ops/mhc_kernel/torch_kernel/mhc_cuda.cu +207 -0
- rwkv_ops/mhc_kernel/torch_kernel/mhc_op.cpp +296 -0
- rwkv_ops/mhc_kernel/torch_kernel/mhc_torch.py +306 -0
- rwkv_ops/rwkv6_kernel/__init__.py +120 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/gpu_ops.cpp +44 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernel_helpers.h +64 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernels.h +56 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/pybind11_kernel_helpers.h +41 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/rwkv_kernels.cu +512 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/gpu_ops.cpp +44 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernel_helpers.h +64 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernels.h +56 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/pybind11_kernel_helpers.h +41 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/rwkv_kernels.hip +514 -0
- rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py +722 -0
- rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py +90 -0
- rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_cuda.cu +397 -0
- rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_op.cpp +93 -0
- rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py +305 -0
- rwkv_ops/rwkv7_kernel/__init__.py +113 -0
- rwkv_ops/rwkv7_kernel/get_jax_devices_info.py +220 -0
- rwkv_ops/rwkv7_kernel/get_torch_devices_info.py +250 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel/CMakeLists.txt +42 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_ffi.cu +399 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_jax.py +311 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/CMakeLists.txt +42 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_ffi.cu +172 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_jax.py +190 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/__init__.py +9 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_bwd.py +95 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_fwd.py +60 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_bwd.py +78 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_fwd.py +80 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_bwd.py +150 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_fwd.py +45 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/cumsum.py +34 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_bwd.py +61 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_fwd.py +86 -0
- rwkv_ops/rwkv7_kernel/jax_op.py +382 -0
- rwkv_ops/rwkv7_kernel/mlx_op.py +118 -0
- rwkv_ops/rwkv7_kernel/native_keras_op.py +108 -0
- rwkv_ops/rwkv7_kernel/tf_eager_kernel.py +155 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_cuda.cu +235 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_op.cpp +63 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_torch.py +233 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_cuda.cu +101 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_op.cpp +56 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_torch.py +112 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/__init__.py +13 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_bwd.py +96 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_fwd.py +64 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_bwd.py +74 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_fwd.py +75 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_bwd.py +148 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_fwd.py +44 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/cumsum.py +31 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_bwd.py +63 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_fwd.py +79 -0
- rwkv_ops/rwkv7_kernel/torch_op.py +504 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/__init__.py +34 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_bwd.py +328 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_fwd.py +186 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_bwd.py +157 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_fwd.py +160 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_bwd.py +382 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_fwd.py +137 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/cumsum.py +86 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/utils.py +20 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_bwd.py +193 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_fwd.py +326 -0
- rwkv_ops-0.6.1.dist-info/METADATA +495 -0
- rwkv_ops-0.6.1.dist-info/RECORD +89 -0
- rwkv_ops-0.6.1.dist-info/WHEEL +4 -0
- rwkv_ops-0.6.1.dist-info/licenses/LICENSE.txt +201 -0
|
@@ -0,0 +1,328 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
import triton
|
|
6
|
+
import triton.language as tl
|
|
7
|
+
|
|
8
|
+
from ..triton_kernel.utils import exp, gather, use_cuda_graph
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8, 16, 32]
        for num_stages in [2, 3, 4]
    ],
    key=["BK", "BT", "K"],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=["T"])
def chunk_dplr_bwd_kernel_intra(
    q,
    k,
    a,
    b,
    gi,
    ge,
    dAqk,
    dAqb,
    dAak,
    dAab,
    dqg,
    dkg,
    dag,
    dbg,
    T,
    dq,
    dk,
    da,
    db,
    dgk,
    dgk_offset,
    scale: tl.constexpr,
    H: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BK: tl.constexpr,
    GATHER_SUPPORTED: tl.constexpr,
):
    # Intra-chunk backward pass: converts the block-level score gradients
    # dAqk/dAqb/dAak/dAab into per-token gradients dq/dk/da/db, folds in the
    # inter-chunk contributions dqg/dkg/dag/dbg, and writes the gate
    # gradients dgk / dgk_offset.
    # Grid: (ceil(K / BK), time blocks of size BC, batch * H).
    # Fix vs. original: a verbatim duplicate of the b_dA_qk_j computation in
    # the GATHER_SUPPORTED branch was removed (pure redundant recomputation).
    i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    # Dead variable-length branch (cu_seqlens path) kept for parity with the
    # upstream kernel; Triton prunes the `if False` arm at compile time, so
    # the undefined names inside it are never resolved.
    if False:
        i_n, i_t = (
            tl.load(chunk_indices + i_t * 2).to(tl.int32),
            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
        )
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
    else:
        bos, eos = (i_b * T).to(tl.int32), (i_b * T + T).to(tl.int32)

    if i_t * BT >= T:
        return

    # offset calculation: advance every pointer to this (batch, head) slice
    ge += (bos * H + i_h) * K
    gi += (bos * H + i_h) * K
    q += (bos * H + i_h) * K
    a += (bos * H + i_h) * K
    b += (bos * H + i_h) * K
    k += (bos * H + i_h) * K
    dq += (bos * H + i_h) * K
    dk += (bos * H + i_h) * K
    da += (bos * H + i_h) * K
    db += (bos * H + i_h) * K
    dqg += (bos * H + i_h) * K
    dag += (bos * H + i_h) * K
    dkg += (bos * H + i_h) * K
    dbg += (bos * H + i_h) * K
    dgk += (bos * H + i_h) * K
    dgk_offset += (bos * H + i_h) * K
    dAqk += (bos * H + i_h) * BT
    dAqb += (bos * H + i_h) * BT
    dAak += (bos * H + i_h) * BT
    dAab += (bos * H + i_h) * BT

    stride_qk = H * K
    stride_A = H * BT

    p_ge = tl.make_block_ptr(
        ge, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)
    )
    p_gi = tl.make_block_ptr(
        gi, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)
    )
    # [BC, BK]
    b_ge = tl.load(p_ge, boundary_check=(0, 1))
    b_gi = tl.load(p_gi, boundary_check=(0, 1))
    b_dq = tl.zeros([BC, BK], dtype=tl.float32)
    b_da = tl.zeros([BC, BK], dtype=tl.float32)
    b_dk = tl.zeros([BC, BK], dtype=tl.float32)
    b_db = tl.zeros([BC, BK], dtype=tl.float32)
    # intra chunk gradient calculation
    p_dAqk = tl.make_block_ptr(
        dAqk, (T, BT), (stride_A, 1), (i_t * BT, 0), (BC, BC), (1, 0)
    )
    p_dAab = tl.make_block_ptr(
        dAab, (T, BT), (stride_A, 1), (i_t * BT, 0), (BC, BC), (1, 0)
    )
    p_dAqb = tl.make_block_ptr(
        dAqb, (T, BT), (stride_A, 1), (i_t * BT, 0), (BC, BC), (1, 0)
    )
    p_dAak = tl.make_block_ptr(
        dAak, (T, BT), (stride_A, 1), (i_t * BT, 0), (BC, BC), (1, 0)
    )
    o_i = tl.arange(0, BC)
    p_k = tl.make_block_ptr(
        k, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)
    )
    p_b = tl.make_block_ptr(
        b, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)
    )
    p_a = tl.make_block_ptr(
        a, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)
    )
    p_q = tl.make_block_ptr(
        q, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)
    )
    b_k = tl.load(p_k, boundary_check=(0, 1))
    b_b = tl.load(p_b, boundary_check=(0, 1))
    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_a = tl.load(p_a, boundary_check=(0, 1))
    b_dAqk = tl.load(p_dAqk, boundary_check=(0, 1))
    b_dAab = tl.load(p_dAab, boundary_check=(0, 1))
    b_dAqb = tl.load(p_dAqb, boundary_check=(0, 1))
    b_dAak = tl.load(p_dAak, boundary_check=(0, 1))

    # inter chunk gradient calculation
    o_k = i_k * BK + tl.arange(0, BK)
    m_k = o_k < K
    # intra chunk gradient calculation: scan the rows j of this BC block
    for j in range(0, min(BC, T - i_t * BT)):
        # trick to index the j-th row / column of a block without scalar loads
        if GATHER_SUPPORTED:
            row_idx = tl.full([1, BK], j, dtype=tl.int16)
            col_idx = tl.full([BC, 1], j, dtype=tl.int16)
            row_idx_bc = tl.full([1, BC], j, dtype=tl.int16)
            # [1, BK]: row j of each input block
            b_kj = gather(b_k, row_idx, axis=0)
            b_bj = gather(b_b, row_idx, axis=0)
            b_gij = gather(b_gi, row_idx, axis=0)
            b_gej = gather(b_ge, row_idx, axis=0)
            b_qj = gather(b_q, row_idx, axis=0)
            b_aj = gather(b_a, row_idx, axis=0)
            # [BC, 1]: column j of each dA block
            b_dAqk_j = gather(b_dAqk, col_idx, axis=1)
            b_dAab_j = gather(b_dAab, col_idx, axis=1)
            b_dAqb_j = gather(b_dAqb, col_idx, axis=1)
            b_dAak_j = gather(b_dAak, col_idx, axis=1)
            # [1, BC] -> [BC, 1]: row j of each dA block, transposed
            # (the original computed b_dA_qk_j twice; one copy removed)
            b_dA_qk_j = tl.sum(gather(b_dAqk, row_idx_bc, axis=0), 0)[:, None]
            b_dA_ab_j = tl.sum(gather(b_dAab, row_idx_bc, axis=0), 0)[:, None]
            b_dA_qb_j = tl.sum(gather(b_dAqb, row_idx_bc, axis=0), 0)[:, None]
            b_dA_ak_j = tl.sum(gather(b_dAak, row_idx_bc, axis=0), 0)[:, None]
        else:
            # fallback: select row/column j with a one-hot mask + reduction
            mask_idx = tl.arange(0, BC) == j
            b_kj = tl.sum(tl.where(mask_idx[:, None], b_k, 0), 0)[None, :]
            b_bj = tl.sum(tl.where(mask_idx[:, None], b_b, 0), 0)[None, :]
            b_gij = tl.sum(tl.where(mask_idx[:, None], b_gi, 0), 0)[None, :]
            b_gej = tl.sum(tl.where(mask_idx[:, None], b_ge, 0), 0)[None, :]
            b_dAqk_j = tl.sum(tl.where(mask_idx[None, :], b_dAqk, 0), 1)[:, None]
            b_dAab_j = tl.sum(tl.where(mask_idx[None, :], b_dAab, 0), 1)[:, None]
            b_dAqb_j = tl.sum(tl.where(mask_idx[None, :], b_dAqb, 0), 1)[:, None]
            b_dAak_j = tl.sum(tl.where(mask_idx[None, :], b_dAak, 0), 1)[:, None]
            b_dA_qk_j = tl.sum(tl.where(mask_idx[:, None], b_dAqk, 0), 0)[:, None]
            b_dA_ab_j = tl.sum(tl.where(mask_idx[:, None], b_dAab, 0), 0)[:, None]
            b_dA_qb_j = tl.sum(tl.where(mask_idx[:, None], b_dAqb, 0), 0)[:, None]
            b_dA_ak_j = tl.sum(tl.where(mask_idx[:, None], b_dAak, 0), 0)[:, None]
            # [1, BK] b_qj, b_aj
            b_qj = tl.sum(tl.where(mask_idx[:, None], b_q, 0), 0)[None, :]
            b_aj = tl.sum(tl.where(mask_idx[:, None], b_a, 0), 0)[None, :]

        # rows at/after j contribute to dq; strictly after j to da
        m_e = o_i[:, None] > j
        m_i = o_i[:, None] >= j
        tmp1 = exp(b_gi - b_gij)
        tmp2 = exp(b_ge - b_gij)
        b_dq += tl.where(m_i, b_dAqk_j * b_kj * tmp1, 0.0)
        b_dq += tl.where(m_i, b_dAqb_j * b_bj * tmp1, 0.0)
        b_da += tl.where(m_e, b_dAab_j * b_bj * tmp2, 0.0)
        b_da += tl.where(m_e, b_dAak_j * b_kj * tmp2, 0.0)

        # rows at/before j contribute to dk/db; strictly before j via a-paths
        m_i = o_i[:, None] <= j
        m_e = o_i[:, None] < j
        tmp1 = exp(b_gij - b_gi)
        tmp2 = exp(b_gej - b_gi)
        b_dk += tl.where(m_i, b_dA_qk_j * b_qj * tmp1, 0.0)
        b_dk += tl.where(m_e, b_dA_ak_j * b_aj * tmp2, 0.0)
        b_db += tl.where(m_i, b_dA_qb_j * b_qj * tmp1, 0.0)
        b_db += tl.where(m_e, b_dA_ab_j * b_aj * tmp2, 0.0)

    # post processing: fold in inter-chunk gradients and store
    p_dq = tl.make_block_ptr(
        dq, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)
    )
    p_dk = tl.make_block_ptr(
        dk, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)
    )
    p_da = tl.make_block_ptr(
        da, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)
    )
    p_db = tl.make_block_ptr(
        db, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)
    )
    p_dgk = tl.make_block_ptr(
        dgk, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)
    )
    p_dgk_offset = tl.make_block_ptr(
        dgk_offset, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)
    )
    p_dqg = tl.make_block_ptr(
        dqg, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)
    )
    p_dkg = tl.make_block_ptr(
        dkg, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)
    )
    p_dag = tl.make_block_ptr(
        dag, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)
    )
    p_dbg = tl.make_block_ptr(
        dbg, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)
    )
    # gate values at the last valid row of this BT chunk
    p_gn = gi + (min(i_t * BT + BT, T) - 1) * stride_qk + o_k
    p_gn = tl.max_contiguous(tl.multiple_of(p_gn, BK), BK)
    b_gn = tl.load(p_gn, mask=m_k, other=0)
    b_da += tl.load(p_dag, boundary_check=(0, 1)) * exp(b_ge)
    b_dq += tl.load(p_dqg, boundary_check=(0, 1)) * exp(b_gi) * scale
    tmp = exp(b_gn[None, :] - b_gi)
    b_dk += tl.load(p_dkg, boundary_check=(0, 1)).to(tl.float32) * tmp
    b_db += tl.load(p_dbg, boundary_check=(0, 1)).to(tl.float32) * tmp
    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_da, b_da.to(p_da.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_db, b_db.to(p_db.dtype.element_ty), boundary_check=(0, 1))
    b_dgk = (b_dq * b_q + b_da * b_a - b_dk * b_k - b_db * b_b).to(tl.float32)
    b_dgk_offset = b_da * b_a
    tl.store(p_dgk, b_dgk.to(p_dgk.dtype.element_ty), boundary_check=(0, 1))
    tl.store(
        p_dgk_offset,
        b_dgk_offset.to(p_dgk_offset.dtype.element_ty),
        boundary_check=(0, 1),
    )
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
@triton.autotune(
    configs=[
        triton.Config({"BK": BK}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8, 16, 32]
        for num_stages in [2, 3, 4]
        for BK in [32, 64]
    ],
    key=["BK", "BT", "K"],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=["T"])
def chunk_dplr_bwd_kernel_dgk(
    dgk,
    dgk_offset,
    dgk_last,
    T,
    dgk_output,
    H: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
):
    # Combines the per-token gate gradients into the final output:
    # dgk_output = reverse_cumsum(dgk within the chunk) + dgk_last - dgk_offset.
    # Grid: (time blocks of size BT, ceil(K / BK), batch * H).
    i_t, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    # Dead variable-length branch (cu_seqlens path) kept for parity with the
    # upstream kernel; Triton prunes the `if False` arm at compile time.
    if False:
        i_tg = i_t
        i_n, i_t = (
            tl.load(chunk_indices + i_t * 2).to(tl.int32),
            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
        )
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        # i_tg: global chunk index used to address dgk_last (one row per chunk)
        i_tg = (i_b * NT + i_t).to(tl.int32)
        bos, eos = (i_b * T).to(tl.int32), (i_b * T + T).to(tl.int32)

    stride_qk = H * K
    # advance pointers to this (batch, head) slice; dgk_last is per-chunk
    dgk += (bos * H + i_h) * K
    dgk_offset += (bos * H + i_h) * K
    dgk_last += (i_tg * H + i_h) * K
    dgk_output += (bos * H + i_h) * K
    p_dgk_last = dgk_last + tl.arange(0, BK) + i_k * BK
    m_k = tl.arange(0, BK) + i_k * BK < K
    # [BK]: chunk-final gradient carried into every row of this chunk
    b_dgk_last = tl.load(p_dgk_last, mask=m_k, other=0)
    p_dgk_offset = tl.make_block_ptr(
        dgk_offset, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
    )
    p_dgk = tl.make_block_ptr(
        dgk, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
    )
    b_dgk = tl.load(p_dgk, boundary_check=(0, 1))
    b_dgk_offset = tl.load(p_dgk_offset, boundary_check=(0, 1))
    # Equivalent matmul formulation, kept for reference:
    # m_inv_cumsum = (tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :]).to(tl.float32)
    # b_dgk_cumsum = tl.dot(m_inv_cumsum, b_dgk, allow_tf32=False)
    b_dgk_cumsum = tl.cumsum(b_dgk, 0, reverse=True)
    b_dgk_cumsum += b_dgk_last[None, :]
    b_dgk_cumsum -= b_dgk_offset
    p_dgk_output = tl.make_block_ptr(
        dgk_output, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
    )
    tl.store(
        p_dgk_output,
        b_dgk_cumsum.to(p_dgk_output.dtype.element_ty),
        boundary_check=(0, 1),
    )
|
|
@@ -0,0 +1,186 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
import triton
|
|
6
|
+
import triton.language as tl
|
|
7
|
+
|
|
8
|
+
from ..triton_kernel.utils import exp, gather, use_cuda_graph
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8, 16, 32]
        for num_stages in [2, 3, 4]
    ],
    key=["BK", "BT"],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=["T"])
def chunk_dplr_fwd_A_kernel_intra_sub_intra(
    q,
    k,
    a,
    b,
    gi,
    ge,
    T,
    qg,
    kg,
    ag,
    bg,
    Aqk,
    Aqb,
    Aab,
    Aak,
    scale: tl.constexpr,
    H: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BK: tl.constexpr,
    GATHER_SUPPORTED: tl.constexpr,
):
    # Intra-(sub-)chunk forward: writes the gate-scaled inputs qg/kg/ag/bg and
    # the four lower-triangular score blocks Aqk/Aqb/Aab/Aak, one column j per
    # loop iteration.  Grid: (time blocks of size BT, batch, H).
    # Fix vs. original: the Aab and Aak stores cast through Aqb/Aqk's element
    # type; each output now casts to its own buffer's element type (identical
    # behavior when all A buffers share a dtype, correct when they do not).
    i_t, i_b, i_h = tl.program_id(0), tl.program_id(1), tl.program_id(2)

    # Dead variable-length branch (cu_seqlens path) kept for parity with the
    # upstream kernel; Triton prunes the `if False` arm at compile time.
    if False:
        i_n, i_t = (
            tl.load(chunk_indices + i_t * 2).to(tl.int32),
            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
        )
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    if i_t * BT >= T:
        return

    o_i = tl.arange(0, BC)
    o_k = tl.arange(0, BK)
    m_k = o_k < K
    m_A = (i_t * BT + tl.arange(0, BC)) < T
    last_idx = min((i_t + 1) * BT, T) - 1
    # flat offsets of this block's rows inside the (.., H, BT) A buffers
    o_A = (bos + i_t * BT + tl.arange(0, BC)) * H * BT + i_h * BT
    p_q = tl.make_block_ptr(
        q + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, 0), (BC, BK), (1, 0)
    )
    p_k = tl.make_block_ptr(
        k + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, 0), (BC, BK), (1, 0)
    )
    p_a = tl.make_block_ptr(
        a + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, 0), (BC, BK), (1, 0)
    )
    p_b = tl.make_block_ptr(
        b + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, 0), (BC, BK), (1, 0)
    )
    p_gi = tl.make_block_ptr(
        gi + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, 0), (BC, BK), (1, 0)
    )
    p_ge = tl.make_block_ptr(
        ge + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, 0), (BC, BK), (1, 0)
    )
    # gate values at the last valid row of this chunk
    p_g_last = gi + (bos * H + i_h) * K + last_idx * H * K + tl.arange(0, BK)
    b_g_last = tl.load(p_g_last, mask=m_k, other=0)
    p_qg = tl.make_block_ptr(
        qg + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, 0), (BC, BK), (1, 0)
    )
    p_kg = tl.make_block_ptr(
        kg + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, 0), (BC, BK), (1, 0)
    )
    p_ag = tl.make_block_ptr(
        ag + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, 0), (BC, BK), (1, 0)
    )
    p_bg = tl.make_block_ptr(
        bg + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, 0), (BC, BK), (1, 0)
    )

    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_q = b_q * scale
    b_k = tl.load(p_k, boundary_check=(0, 1))
    b_a = tl.load(p_a, boundary_check=(0, 1))
    b_b = tl.load(p_b, boundary_check=(0, 1))
    b_gi = tl.load(p_gi, boundary_check=(0, 1)).to(tl.float32)
    b_ge = tl.load(p_ge, boundary_check=(0, 1)).to(tl.float32)

    # deal with decay term: q/a decay forward, k/b decay toward chunk end
    g_exp = exp(b_gi)
    g_exp_inv = exp(-b_gi + b_g_last[None, :])
    b_qg = b_q * g_exp
    b_kg = b_k * g_exp_inv
    b_bg = b_b * g_exp_inv
    b_ag = b_a * exp(b_ge)
    tl.store(
        p_qg,
        b_qg.to(p_qg.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
    tl.store(
        p_bg,
        b_bg.to(p_bg.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
    tl.store(
        p_ag,
        b_ag.to(p_ag.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
    tl.store(
        p_kg,
        b_kg.to(p_kg.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
    # tl.debug_barrier()

    b_q = b_q.to(b_k.dtype)
    # inner attn: one column j of each lower-triangular score block per pass
    for j in range(0, min(BC, T - i_t * BT)):
        # a trick to index the j-th row of b_k, b_g, b_b
        if GATHER_SUPPORTED:
            row_idx = tl.full([1, BK], j, dtype=tl.int16)
            # [1, BK]
            b_k_j = gather(b_k, row_idx, axis=0)
            b_gk_j = gather(b_gi, row_idx, axis=0)
            b_b_j = gather(b_b, row_idx, axis=0)
        else:
            mask = tl.arange(0, BC) == j
            b_k_j = tl.sum(tl.where(mask[:, None], b_k, 0), 0)[None, :]
            b_gk_j = tl.sum(tl.where(mask[:, None], b_gi, 0), 0)[None, :]
            b_b_j = tl.sum(tl.where(mask[:, None], b_b, 0), 0)[None, :]
        tmp = exp(b_gi - b_gk_j)
        b_A_qk = tl.sum(b_q * b_k_j * tmp, 1)
        # inclusive mask (i >= j) for the q-side scores
        m_i = (o_i >= j).to(tl.float32)
        b_A_qk = b_A_qk * m_i
        b_A_qb = tl.sum(b_q * b_b_j * tmp, 1)
        b_A_qb = b_A_qb * m_i
        tmp2 = exp(b_ge - b_gk_j)
        b_A_ak = tl.sum(b_a * b_k_j * tmp2, 1)
        # strict mask (i > j) for the a-side scores
        m_i2 = (o_i > j).to(tl.float32)
        b_A_ak = b_A_ak * m_i2
        b_A_ab = tl.sum(b_a * b_b_j * tmp2, 1)
        b_A_ab = b_A_ab * m_i2

        tl.store(
            Aqk + o_A + j,
            b_A_qk.to(dtype=Aqk.dtype.element_ty, fp_downcast_rounding="rtne"),
            mask=m_A,
        )
        tl.store(
            Aqb + o_A + j,
            b_A_qb.to(dtype=Aqb.dtype.element_ty, fp_downcast_rounding="rtne"),
            mask=m_A,
        )
        tl.store(
            Aab + o_A + j,
            # fixed: cast with Aab's own element type (was Aqb's)
            b_A_ab.to(dtype=Aab.dtype.element_ty, fp_downcast_rounding="rtne"),
            mask=m_A,
        )
        tl.store(
            Aak + o_A + j,
            # fixed: cast with Aak's own element type (was Aqk's)
            b_A_ak.to(dtype=Aak.dtype.element_ty, fp_downcast_rounding="rtne"),
            mask=m_A,
        )
|
|
@@ -0,0 +1,157 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
import triton
|
|
6
|
+
import triton.language as tl
|
|
7
|
+
|
|
8
|
+
from ..triton_kernel.utils import exp, use_cuda_graph
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@triton.heuristics(
    {
        "USE_FINAL_STATE_GRADIENT": lambda args: args["dht"] is not None,
        "USE_INITIAL_STATE": lambda args: args["dh0"] is not None,
    }
)
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8, 16, 32]
        for num_stages in [2, 3, 4]
    ],
    key=["BT", "BK", "BV", "V"],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=["T"])
def chunk_dplr_bwd_kernel_dhu(
    qg,
    bg,
    w,
    gk,
    dht,
    dv,
    do,
    T,
    dh,
    dh0,
    dv2,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_FINAL_STATE_GRADIENT: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
):
    # Backward recurrence over the per-chunk state gradient dh, walking the
    # chunks in reverse time order.  For each chunk it:
    #   1. stores the incoming dh for that chunk,
    #   2. updates dv2 = dv + bg @ dh for every BC sub-chunk,
    #   3. accumulates qg @ do + w @ dv2 into dh, decayed by exp(gk_last).
    # Optionally seeds dh from dht and writes the final dh to dh0.
    # Grid: (ceil(K / BK), ceil(V / BV), batch * H).
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_h = i_nh // H, i_nh % H
    # Dead variable-length branch (cu_seqlens path) kept for parity with the
    # upstream kernel; Triton prunes the `if False` arm at compile time.
    if False:
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
        NT = tl.cdiv(T, BT)
        boh = tl.load(chunk_offsets + i_n).to(tl.int32)
    else:
        bos, eos = i_n * T, i_n * T + T
        NT = tl.cdiv(T, BT)
        boh = i_n * NT

    # [BK, BV]: running state gradient, carried from later chunks
    b_dh = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_FINAL_STATE_GRADIENT:
        p_dht = tl.make_block_ptr(
            dht + i_nh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)
        )
        b_dh += tl.load(p_dht, boundary_check=(0, 1))

    mask_k = tl.arange(0, BK) < K
    for i_t in range(NT - 1, -1, -1):
        p_dh = tl.make_block_ptr(
            dh + ((boh + i_t) * H + i_h) * K * V,
            (K, V),
            (V, 1),
            (i_k * BK, i_v * BV),
            (BK, BV),
            (1, 0),
        )
        # dh for chunk i_t is the state gradient *before* this chunk's update,
        # so the store must precede the accumulation below
        tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
        b_dh_tmp = tl.zeros([BK, BV], dtype=tl.float32)
        # walk the BC-sized sub-chunks of this chunk, also in reverse
        for i_c in range(tl.cdiv(BT, BC) - 1, -1, -1):
            # qg and w are addressed transposed, as (K, T) tiles
            p_qg = tl.make_block_ptr(
                qg + (bos * H + i_h) * K,
                (K, T),
                (1, H * K),
                (i_k * BK, i_t * BT + i_c * BC),
                (BK, BC),
                (0, 1),
            )
            p_bg = tl.make_block_ptr(
                bg + (bos * H + i_h) * K,
                (T, K),
                (H * K, 1),
                (i_t * BT + i_c * BC, i_k * BK),
                (BC, BK),
                (1, 0),
            )
            p_w = tl.make_block_ptr(
                w + (bos * H + i_h) * K,
                (K, T),
                (1, H * K),
                (i_k * BK, i_t * BT + i_c * BC),
                (BK, BC),
                (0, 1),
            )
            p_dv = tl.make_block_ptr(
                dv + (bos * H + i_h) * V,
                (T, V),
                (H * V, 1),
                (i_t * BT + i_c * BC, i_v * BV),
                (BC, BV),
                (1, 0),
            )
            p_do = tl.make_block_ptr(
                do + (bos * H + i_h) * V,
                (T, V),
                (H * V, 1),
                (i_t * BT + i_c * BC, i_v * BV),
                (BC, BV),
                (1, 0),
            )
            p_dv2 = tl.make_block_ptr(
                dv2 + (bos * H + i_h) * V,
                (T, V),
                (H * V, 1),
                (i_t * BT + i_c * BC, i_v * BV),
                (BC, BV),
                (1, 0),
            )
            # [BK, BT]
            b_qg = tl.load(p_qg, boundary_check=(0, 1))
            # [BT, BK]
            b_bg = tl.load(p_bg, boundary_check=(0, 1))
            b_w = tl.load(p_w, boundary_check=(0, 1))
            # [BT, V]
            b_do = tl.load(p_do, boundary_check=(0, 1))
            b_dv = tl.load(p_dv, boundary_check=(0, 1))
            # augment dv with the state-gradient contribution via bg
            b_dv2 = b_dv + tl.dot(b_bg, b_dh.to(b_bg.dtype))
            # NOTE(review): cast uses p_dv's element type for the p_dv2 store;
            # presumably dv and dv2 share a dtype — confirm against caller.
            tl.store(p_dv2, b_dv2.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
            # [BK, BV]
            b_dh_tmp += tl.dot(b_qg, b_do.to(b_qg.dtype))
            b_dh_tmp += tl.dot(b_w, b_dv2.to(b_qg.dtype))
        last_idx = min((i_t + 1) * BT, T) - 1
        # gate at the last valid row of the chunk decays the carried dh
        bg_last = tl.load(
            gk + ((bos + last_idx) * H + i_h) * K + tl.arange(0, BK), mask=mask_k
        )
        b_dh *= exp(bg_last)[:, None]
        b_dh += b_dh_tmp

    if USE_INITIAL_STATE:
        p_dh0 = tl.make_block_ptr(
            dh0 + i_nh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)
        )
        tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
|