fastvideo-kernel 0.2.5__tar.gz → 0.2.6__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/PKG-INFO +1 -1
- fastvideo_kernel-0.2.5/dist/fastvideo_kernel-0.2.5-cp310-cp310-manylinux_2_34_x86_64.manylinux_2_35_x86_64.whl → fastvideo_kernel-0.2.6/dist/fastvideo_kernel-0.2.6-cp310-cp310-manylinux_2_34_x86_64.manylinux_2_35_x86_64.whl +0 -0
- fastvideo_kernel-0.2.5/dist/fastvideo_kernel-0.2.5-cp311-cp311-manylinux_2_34_x86_64.manylinux_2_35_x86_64.whl → fastvideo_kernel-0.2.6/dist/fastvideo_kernel-0.2.6-cp311-cp311-manylinux_2_34_x86_64.manylinux_2_35_x86_64.whl +0 -0
- fastvideo_kernel-0.2.5/dist/fastvideo_kernel-0.2.5-cp312-cp312-manylinux_2_34_x86_64.manylinux_2_35_x86_64.whl → fastvideo_kernel-0.2.6/dist/fastvideo_kernel-0.2.6-cp312-cp312-manylinux_2_34_x86_64.manylinux_2_35_x86_64.whl +0 -0
- {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/pyproject.toml +1 -1
- {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/python/fastvideo_kernel/block_sparse_attn.py +2 -3
- {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/python/fastvideo_kernel/ops.py +0 -6
- {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/python/fastvideo_kernel/triton_kernels/block_sparse_attn_triton.py +314 -47
- fastvideo_kernel-0.2.6/python/fastvideo_kernel/version.py +1 -0
- fastvideo_kernel-0.2.5/python/fastvideo_kernel/version.py +0 -1
- {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/CMakeLists.txt +0 -0
- {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/LICENSE +0 -0
- {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/MANIFEST.in +0 -0
- {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/README.md +0 -0
- {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/benchmarks/bench_vsa.py +0 -0
- {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/build.sh +0 -0
- {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/csrc/attention/block_sparse_h100.cu +0 -0
- {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/csrc/attention/st_attn_h100.cu +0 -0
- {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/csrc/common_extension.cpp +0 -0
- {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/csrc/turbodiffusion/common/common.hpp +0 -0
- {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/csrc/turbodiffusion/common/launch.hpp +0 -0
- {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/csrc/turbodiffusion/common/load.hpp +0 -0
- {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/csrc/turbodiffusion/common/store.hpp +0 -0
- {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/csrc/turbodiffusion/gemm/gemm.cu +0 -0
- {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/csrc/turbodiffusion/gemm/kernel.hpp +0 -0
- {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/csrc/turbodiffusion/gemm/launch.hpp +0 -0
- {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/csrc/turbodiffusion/gemm/utils.hpp +0 -0
- {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/csrc/turbodiffusion/norm/layernorm.cu +0 -0
- {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/csrc/turbodiffusion/norm/layernorm.hpp +0 -0
- {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/csrc/turbodiffusion/norm/rmsnorm.cu +0 -0
- {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/csrc/turbodiffusion/norm/rmsnorm.hpp +0 -0
- {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/csrc/turbodiffusion/quant/quant.cu +0 -0
- {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/csrc/turbodiffusion/quant/quant.hpp +0 -0
- {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/python/fastvideo_kernel/__init__.py +0 -0
- {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/python/fastvideo_kernel/triton_kernels/index.py +0 -0
- {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/python/fastvideo_kernel/triton_kernels/sla_triton.py +0 -0
- {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/python/fastvideo_kernel/triton_kernels/st_attn_triton.py +0 -0
- {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/python/fastvideo_kernel/turbodiffusion_ops.py +0 -0
- {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/python/fastvideo_kernel/vmoba.py +0 -0
- {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/tests/__init__.py +0 -0
- {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/tests/support_flex_sta.py +0 -0
- {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/tests/test_sta.py +0 -0
- {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/tests/test_turbodiffusion.py +0 -0
- {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/tests/test_vmoba_correctness.py +0 -0
- {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/tests/test_vsa.py +0 -0
- {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/tests/test_vsa_forward.py +0 -0
- {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/tests/utils.py +0 -0
{fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/python/fastvideo_kernel/block_sparse_attn.py
RENAMED
|
@@ -287,9 +287,8 @@ def block_sparse_attn(
|
|
|
287
287
|
block_sparse_fwd, block_sparse_bwd = _get_sm90_ops()
|
|
288
288
|
if (not _force_triton()) and _is_sm90() and (block_sparse_fwd is not None) and (block_sparse_bwd is not None):
|
|
289
289
|
return block_sparse_attn_sm90(q, k, v, block_map, variable_block_sizes)
|
|
290
|
-
# Triton path:
|
|
291
|
-
|
|
292
|
-
raise RuntimeError("Triton fallback requires q/k/v to have the same padded length.")
|
|
290
|
+
# Triton path: supports q_seq_len != kv_seq_len as long as both are padded
|
|
291
|
+
# to a multiple of the block size (64 tokens).
|
|
293
292
|
return block_sparse_attn_triton(q, k, v, block_map, variable_block_sizes)
|
|
294
293
|
|
|
295
294
|
|
|
@@ -141,12 +141,6 @@ def video_sparse_attn(
|
|
|
141
141
|
# Use autograd-enabled wrapper so backward works (and still uses SM90 kernel when available)
|
|
142
142
|
out_s = block_sparse_attn(q, k, v, mask, variable_block_sizes)[0]
|
|
143
143
|
else:
|
|
144
|
-
if q_seq_len != kv_seq_len:
|
|
145
|
-
raise RuntimeError(
|
|
146
|
-
"q/k have different lengths, but the compiled CUDA kernel (block_sparse_fwd) "
|
|
147
|
-
"is not available. The Triton fallback currently requires q and k/v to have "
|
|
148
|
-
"the same padded length."
|
|
149
|
-
)
|
|
150
144
|
# Triton-only forward (kept for environments without the wrapper deps)
|
|
151
145
|
out_s, _ = triton_block_sparse_attn_forward(q, k, v, idx, num, variable_block_sizes)
|
|
152
146
|
|
|
@@ -29,7 +29,7 @@ configs = [
|
|
|
29
29
|
|
|
30
30
|
|
|
31
31
|
# ──────────────────────────── SPARSE ADDITION BEGIN ───────────────────────────
|
|
32
|
-
@triton.autotune(configs, key=["
|
|
32
|
+
@triton.autotune(configs, key=["N_CTX_Q", "HEAD_DIM"])
|
|
33
33
|
@triton.jit
|
|
34
34
|
def _attn_fwd_sparse(
|
|
35
35
|
Q,
|
|
@@ -60,7 +60,8 @@ def _attn_fwd_sparse(
|
|
|
60
60
|
stride_on,
|
|
61
61
|
Z,
|
|
62
62
|
H,
|
|
63
|
-
|
|
63
|
+
N_CTX_Q, #
|
|
64
|
+
N_CTX_KV, #
|
|
64
65
|
HEAD_DIM: tl.constexpr, #
|
|
65
66
|
BLOCK_M: tl.constexpr,
|
|
66
67
|
BLOCK_N: tl.constexpr,
|
|
@@ -75,24 +76,29 @@ def _attn_fwd_sparse(
|
|
|
75
76
|
off_hz = tl.program_id(1) # fused (batch, head)
|
|
76
77
|
b = off_hz // H
|
|
77
78
|
h = off_hz % H
|
|
78
|
-
q_tiles =
|
|
79
|
+
q_tiles = N_CTX_Q // BLOCK_M
|
|
79
80
|
meta_base = ((b * H + h) * q_tiles + q_blk)
|
|
80
81
|
|
|
81
82
|
kv_blocks = tl.load(q2k_num + meta_base) # int32
|
|
82
83
|
kv_ptr = q2k_index + meta_base * max_kv_blks # ptr to list
|
|
83
84
|
|
|
84
85
|
# ----- base pointers -----
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
86
|
+
# Note: when q and kv have different sequence lengths, their per-(batch,head)
|
|
87
|
+
# strides differ, so we must compute separate base offsets.
|
|
88
|
+
q_off = (b.to(tl.int64) * stride_qz + h.to(tl.int64) * stride_qh)
|
|
89
|
+
k_off = (b.to(tl.int64) * stride_kz + h.to(tl.int64) * stride_kh)
|
|
90
|
+
v_off = (b.to(tl.int64) * stride_vz + h.to(tl.int64) * stride_vh)
|
|
91
|
+
o_off = (b.to(tl.int64) * stride_oz + h.to(tl.int64) * stride_oh)
|
|
92
|
+
|
|
93
|
+
Q_ptr = tl.make_block_ptr(base=Q + q_off,
|
|
94
|
+
shape=(N_CTX_Q, HEAD_DIM),
|
|
89
95
|
strides=(stride_qm, stride_qk),
|
|
90
96
|
offsets=(q_blk * BLOCK_M, 0),
|
|
91
97
|
block_shape=(BLOCK_M, HEAD_DIM),
|
|
92
98
|
order=(1, 0))
|
|
93
99
|
|
|
94
|
-
K_base = tl.make_block_ptr(base=K +
|
|
95
|
-
shape=(HEAD_DIM,
|
|
100
|
+
K_base = tl.make_block_ptr(base=K + k_off,
|
|
101
|
+
shape=(HEAD_DIM, N_CTX_KV),
|
|
96
102
|
strides=(stride_kk, stride_kn),
|
|
97
103
|
offsets=(0, 0),
|
|
98
104
|
block_shape=(HEAD_DIM, BLOCK_N),
|
|
@@ -100,15 +106,15 @@ def _attn_fwd_sparse(
|
|
|
100
106
|
|
|
101
107
|
v_order: tl.constexpr = (0, 1) if V.dtype.element_ty == tl.float8e5 else (1,
|
|
102
108
|
0)
|
|
103
|
-
V_base = tl.make_block_ptr(base=V +
|
|
104
|
-
shape=(
|
|
109
|
+
V_base = tl.make_block_ptr(base=V + v_off,
|
|
110
|
+
shape=(N_CTX_KV, HEAD_DIM),
|
|
105
111
|
strides=(stride_vk, stride_vn),
|
|
106
112
|
offsets=(0, 0),
|
|
107
113
|
block_shape=(BLOCK_N, HEAD_DIM),
|
|
108
114
|
order=v_order)
|
|
109
115
|
|
|
110
|
-
O_ptr = tl.make_block_ptr(base=Out +
|
|
111
|
-
shape=(
|
|
116
|
+
O_ptr = tl.make_block_ptr(base=Out + o_off,
|
|
117
|
+
shape=(N_CTX_Q, HEAD_DIM),
|
|
112
118
|
strides=(stride_om, stride_on),
|
|
113
119
|
offsets=(q_blk * BLOCK_M, 0),
|
|
114
120
|
block_shape=(BLOCK_M, HEAD_DIM),
|
|
@@ -150,7 +156,7 @@ def _attn_fwd_sparse(
|
|
|
150
156
|
# ----- epilogue -----
|
|
151
157
|
m_i += tl.math.log2(l_i)
|
|
152
158
|
acc = acc / l_i[:, None]
|
|
153
|
-
tl.store(M + off_hz *
|
|
159
|
+
tl.store(M + off_hz * N_CTX_Q + offs_m, m_i)
|
|
154
160
|
tl.store(O_ptr, acc.to(Out.type.element_ty))
|
|
155
161
|
|
|
156
162
|
|
|
@@ -201,7 +207,7 @@ def _attn_bwd_dkdv(
|
|
|
201
207
|
stride_tok,
|
|
202
208
|
stride_d, #
|
|
203
209
|
H,
|
|
204
|
-
|
|
210
|
+
N_CTX_KV,
|
|
205
211
|
BLOCK_M1: tl.constexpr, #
|
|
206
212
|
BLOCK_N1: tl.constexpr, #
|
|
207
213
|
HEAD_DIM: tl.constexpr, #
|
|
@@ -221,8 +227,8 @@ def _attn_bwd_dkdv(
|
|
|
221
227
|
off_hz = tl.program_id(2) # fused (batch, head)
|
|
222
228
|
b = off_hz // H
|
|
223
229
|
h = off_hz % H
|
|
224
|
-
|
|
225
|
-
meta_base = ((b * H + h) *
|
|
230
|
+
kv_tiles = N_CTX_KV // BLOCK_N1
|
|
231
|
+
meta_base = ((b * H + h) * kv_tiles + kv_blk)
|
|
226
232
|
|
|
227
233
|
q_blocks = tl.load(k2q_num + meta_base) # int32
|
|
228
234
|
q_ptr = k2q_index + meta_base * max_q_blks # ptr to list
|
|
@@ -302,16 +308,21 @@ def _attn_bwd_dq(
|
|
|
302
308
|
|
|
303
309
|
kv_blocks = tl.load(q2k_num + meta_base) # int32
|
|
304
310
|
kv_ptr = q2k_index + meta_base * max_kv_blks # ptr to list
|
|
305
|
-
block_size = tl.load(variable_block_sizes + q_blk)
|
|
306
311
|
|
|
307
312
|
for blk_idx in range(kv_blocks * 2):
|
|
308
|
-
|
|
309
|
-
|
|
313
|
+
kv_idx = tl.load(kv_ptr + blk_idx // 2).to(tl.int32)
|
|
314
|
+
# variable_block_sizes is defined per KV block (tile). Mask must therefore
|
|
315
|
+
# use kv_idx (not q_blk). Also, because we split each 64-token block into
|
|
316
|
+
# two 32-token halves, the mask must account for the half-block offset.
|
|
317
|
+
block_size = tl.load(variable_block_sizes + kv_idx).to(tl.int32)
|
|
318
|
+
half = (blk_idx % 2).to(tl.int32)
|
|
319
|
+
block_sparse_offset = (kv_idx * 2 + half) * step_n * stride_tok
|
|
310
320
|
kT = tl.load(kT_ptrs + block_sparse_offset)
|
|
311
321
|
vT = tl.load(vT_ptrs + block_sparse_offset)
|
|
312
322
|
qk = tl.dot(q, kT)
|
|
313
323
|
p = tl.math.exp2(qk - m)
|
|
314
|
-
|
|
324
|
+
offs_in_block = half * step_n + tl.arange(0, BLOCK_N2)
|
|
325
|
+
mask = offs_in_block < block_size
|
|
315
326
|
p = tl.where(mask[None, :], p, 0.0)
|
|
316
327
|
# Compute dP and dS.
|
|
317
328
|
dp = tl.dot(do, vT).to(tl.float32)
|
|
@@ -467,19 +478,235 @@ def _attn_bwd(
|
|
|
467
478
|
tl.store(dq_ptrs, dq)
|
|
468
479
|
|
|
469
480
|
|
|
481
|
+
@triton.jit
|
|
482
|
+
def _attn_bwd_dkdv_kernel(
|
|
483
|
+
Q,
|
|
484
|
+
K,
|
|
485
|
+
V,
|
|
486
|
+
sm_scale, #
|
|
487
|
+
DO, #
|
|
488
|
+
DK,
|
|
489
|
+
DV, #
|
|
490
|
+
M,
|
|
491
|
+
D,
|
|
492
|
+
k2q_index,
|
|
493
|
+
k2q_num,
|
|
494
|
+
max_q_blks,
|
|
495
|
+
variable_block_sizes,
|
|
496
|
+
# shared token/dim strides (assumed contiguous along token and dim)
|
|
497
|
+
stride_tok,
|
|
498
|
+
stride_d, #
|
|
499
|
+
# batch/head strides (may differ between Q and KV)
|
|
500
|
+
stride_qz,
|
|
501
|
+
stride_qh,
|
|
502
|
+
stride_kz,
|
|
503
|
+
stride_kh,
|
|
504
|
+
stride_vz,
|
|
505
|
+
stride_vh,
|
|
506
|
+
stride_doz,
|
|
507
|
+
stride_doh,
|
|
508
|
+
stride_dkz,
|
|
509
|
+
stride_dkh,
|
|
510
|
+
stride_dvz,
|
|
511
|
+
stride_dvh,
|
|
512
|
+
H,
|
|
513
|
+
N_CTX_Q,
|
|
514
|
+
N_CTX_KV,
|
|
515
|
+
BLOCK_M1: tl.constexpr, #
|
|
516
|
+
BLOCK_N1: tl.constexpr, #
|
|
517
|
+
HEAD_DIM: tl.constexpr):
|
|
518
|
+
"""
|
|
519
|
+
Backward kernel that computes dK and dV for each KV block (64 tokens).
|
|
520
|
+
Grid:
|
|
521
|
+
pid0: kv_blk in [0, N_CTX_KV/BLOCK_N1)
|
|
522
|
+
pid2: fused (batch, head) in [0, B*H)
|
|
523
|
+
"""
|
|
524
|
+
bhid = tl.program_id(2)
|
|
525
|
+
b = bhid // H
|
|
526
|
+
h = bhid % H
|
|
527
|
+
kv_blk = tl.program_id(0)
|
|
528
|
+
|
|
529
|
+
q_adj = (b.to(tl.int64) * stride_qz + h.to(tl.int64) * stride_qh)
|
|
530
|
+
kv_adj_k = (b.to(tl.int64) * stride_kz + h.to(tl.int64) * stride_kh)
|
|
531
|
+
kv_adj_v = (b.to(tl.int64) * stride_vz + h.to(tl.int64) * stride_vh)
|
|
532
|
+
do_adj = (b.to(tl.int64) * stride_doz + h.to(tl.int64) * stride_doh)
|
|
533
|
+
dk_adj = (b.to(tl.int64) * stride_dkz + h.to(tl.int64) * stride_dkh)
|
|
534
|
+
dv_adj = (b.to(tl.int64) * stride_dvz + h.to(tl.int64) * stride_dvh)
|
|
535
|
+
|
|
536
|
+
Q = Q + q_adj
|
|
537
|
+
K = K + kv_adj_k
|
|
538
|
+
V = V + kv_adj_v
|
|
539
|
+
DO = DO + do_adj
|
|
540
|
+
DK = DK + dk_adj
|
|
541
|
+
DV = DV + dv_adj
|
|
542
|
+
|
|
543
|
+
# M and D (delta) are always sized by Q length.
|
|
544
|
+
M = M + (bhid * N_CTX_Q).to(tl.int64)
|
|
545
|
+
D = D + (bhid * N_CTX_Q).to(tl.int64)
|
|
546
|
+
|
|
547
|
+
offs_k = tl.arange(0, HEAD_DIM)
|
|
548
|
+
start_n = kv_blk * BLOCK_N1
|
|
549
|
+
offs_n = start_n + tl.arange(0, BLOCK_N1)
|
|
550
|
+
|
|
551
|
+
# load K and V: they stay in SRAM throughout the inner loop.
|
|
552
|
+
k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
|
|
553
|
+
v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
|
|
554
|
+
|
|
555
|
+
dv_acc = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)
|
|
556
|
+
dk_acc = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)
|
|
557
|
+
|
|
558
|
+
num_steps = N_CTX_Q // BLOCK_M1
|
|
559
|
+
dk_acc, dv_acc = _attn_bwd_dkdv(
|
|
560
|
+
dk_acc,
|
|
561
|
+
dv_acc,
|
|
562
|
+
Q,
|
|
563
|
+
k,
|
|
564
|
+
v,
|
|
565
|
+
sm_scale,
|
|
566
|
+
DO,
|
|
567
|
+
M,
|
|
568
|
+
D,
|
|
569
|
+
k2q_index,
|
|
570
|
+
k2q_num,
|
|
571
|
+
max_q_blks,
|
|
572
|
+
variable_block_sizes,
|
|
573
|
+
stride_tok,
|
|
574
|
+
stride_d,
|
|
575
|
+
H,
|
|
576
|
+
N_CTX_KV,
|
|
577
|
+
BLOCK_M1=BLOCK_M1,
|
|
578
|
+
BLOCK_N1=BLOCK_N1,
|
|
579
|
+
HEAD_DIM=HEAD_DIM,
|
|
580
|
+
start_n=start_n,
|
|
581
|
+
start_m=0,
|
|
582
|
+
num_steps=num_steps,
|
|
583
|
+
)
|
|
584
|
+
|
|
585
|
+
dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
|
|
586
|
+
tl.store(dv_ptrs, dv_acc)
|
|
587
|
+
|
|
588
|
+
dk_acc *= sm_scale
|
|
589
|
+
dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
|
|
590
|
+
tl.store(dk_ptrs, dk_acc)
|
|
591
|
+
|
|
592
|
+
|
|
593
|
+
@triton.jit
|
|
594
|
+
def _attn_bwd_dq_kernel(
|
|
595
|
+
Q,
|
|
596
|
+
K,
|
|
597
|
+
V,
|
|
598
|
+
DO, #
|
|
599
|
+
DQ,
|
|
600
|
+
M,
|
|
601
|
+
D,
|
|
602
|
+
q2k_index,
|
|
603
|
+
q2k_num,
|
|
604
|
+
max_kv_blks,
|
|
605
|
+
variable_block_sizes,
|
|
606
|
+
# shared token/dim strides (assumed contiguous along token and dim)
|
|
607
|
+
stride_tok,
|
|
608
|
+
stride_d, #
|
|
609
|
+
# batch/head strides (may differ between Q and KV)
|
|
610
|
+
stride_qz,
|
|
611
|
+
stride_qh,
|
|
612
|
+
stride_kz,
|
|
613
|
+
stride_kh,
|
|
614
|
+
stride_vz,
|
|
615
|
+
stride_vh,
|
|
616
|
+
stride_doz,
|
|
617
|
+
stride_doh,
|
|
618
|
+
stride_dqz,
|
|
619
|
+
stride_dqh,
|
|
620
|
+
H,
|
|
621
|
+
N_CTX_Q,
|
|
622
|
+
BLOCK_M2: tl.constexpr, #
|
|
623
|
+
BLOCK_N2: tl.constexpr, #
|
|
624
|
+
HEAD_DIM: tl.constexpr):
|
|
625
|
+
"""
|
|
626
|
+
Backward kernel that computes dQ for each Q block (64 tokens).
|
|
627
|
+
Grid:
|
|
628
|
+
pid0: q_blk in [0, N_CTX_Q/BLOCK_M2)
|
|
629
|
+
pid2: fused (batch, head) in [0, B*H)
|
|
630
|
+
"""
|
|
631
|
+
LN2 = 0.6931471824645996 # = ln(2)
|
|
632
|
+
bhid = tl.program_id(2)
|
|
633
|
+
b = bhid // H
|
|
634
|
+
h = bhid % H
|
|
635
|
+
q_blk = tl.program_id(0)
|
|
636
|
+
|
|
637
|
+
q_adj = (b.to(tl.int64) * stride_qz + h.to(tl.int64) * stride_qh)
|
|
638
|
+
kv_adj_k = (b.to(tl.int64) * stride_kz + h.to(tl.int64) * stride_kh)
|
|
639
|
+
kv_adj_v = (b.to(tl.int64) * stride_vz + h.to(tl.int64) * stride_vh)
|
|
640
|
+
do_adj = (b.to(tl.int64) * stride_doz + h.to(tl.int64) * stride_doh)
|
|
641
|
+
dq_adj = (b.to(tl.int64) * stride_dqz + h.to(tl.int64) * stride_dqh)
|
|
642
|
+
|
|
643
|
+
Q = Q + q_adj
|
|
644
|
+
K = K + kv_adj_k
|
|
645
|
+
V = V + kv_adj_v
|
|
646
|
+
DO = DO + do_adj
|
|
647
|
+
DQ = DQ + dq_adj
|
|
648
|
+
|
|
649
|
+
M = M + (bhid * N_CTX_Q).to(tl.int64)
|
|
650
|
+
D = D + (bhid * N_CTX_Q).to(tl.int64)
|
|
651
|
+
|
|
652
|
+
offs_k = tl.arange(0, HEAD_DIM)
|
|
653
|
+
start_m = q_blk * BLOCK_M2
|
|
654
|
+
offs_m = start_m + tl.arange(0, BLOCK_M2)
|
|
655
|
+
|
|
656
|
+
q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
|
|
657
|
+
do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
|
|
658
|
+
m = tl.load(M + offs_m)[:, None]
|
|
659
|
+
|
|
660
|
+
dq_acc = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32)
|
|
661
|
+
num_steps = 0 # unused in _attn_bwd_dq
|
|
662
|
+
dq_acc = _attn_bwd_dq(
|
|
663
|
+
dq_acc,
|
|
664
|
+
q,
|
|
665
|
+
K,
|
|
666
|
+
V,
|
|
667
|
+
do,
|
|
668
|
+
m,
|
|
669
|
+
D,
|
|
670
|
+
q2k_index,
|
|
671
|
+
q2k_num,
|
|
672
|
+
max_kv_blks,
|
|
673
|
+
variable_block_sizes,
|
|
674
|
+
stride_tok,
|
|
675
|
+
stride_d,
|
|
676
|
+
H,
|
|
677
|
+
N_CTX_Q,
|
|
678
|
+
BLOCK_M2=BLOCK_M2,
|
|
679
|
+
BLOCK_N2=BLOCK_N2,
|
|
680
|
+
HEAD_DIM=HEAD_DIM,
|
|
681
|
+
start_m=start_m,
|
|
682
|
+
start_n=0,
|
|
683
|
+
num_steps=num_steps,
|
|
684
|
+
)
|
|
685
|
+
|
|
686
|
+
dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
|
|
687
|
+
dq_acc *= LN2
|
|
688
|
+
tl.store(dq_ptrs, dq_acc)
|
|
689
|
+
|
|
690
|
+
|
|
470
691
|
# ──────────────────────────── SPARSE ADDITION BEGIN ───────────────────────────
|
|
471
692
|
def triton_block_sparse_attn_forward(q, k, v, q2k_index, q2k_num,
|
|
472
693
|
variable_block_sizes):
|
|
473
|
-
B, H,
|
|
694
|
+
B, H, Tq, D = q.shape
|
|
695
|
+
Tkv = k.shape[2]
|
|
474
696
|
sm_scale = 1.0 / math.sqrt(D)
|
|
475
697
|
max_kv_blks = q2k_index.shape[-1]
|
|
476
|
-
assert
|
|
698
|
+
assert Tq % 64 == 0, f"q length must be a multiple of 64, but got {Tq}"
|
|
699
|
+
assert Tkv % 64 == 0, f"kv length must be a multiple of 64, but got {Tkv}"
|
|
477
700
|
assert q2k_num.shape[
|
|
478
|
-
-1] ==
|
|
701
|
+
-1] == Tq // 64, f"shape mismatch, Tq // 64 = {Tq // 64}, q2k_num.shape[-2] = {q2k_num.shape[-2]}"
|
|
702
|
+
assert variable_block_sizes.numel() == Tkv // 64, (
|
|
703
|
+
f"shape mismatch, variable_block_sizes must have length {Tkv // 64}, "
|
|
704
|
+
f"got {variable_block_sizes.numel()}"
|
|
705
|
+
)
|
|
479
706
|
o = torch.empty_like(q)
|
|
480
|
-
M = torch.empty((B, H,
|
|
707
|
+
M = torch.empty((B, H, Tq), dtype=torch.float32, device=q.device)
|
|
481
708
|
|
|
482
|
-
grid = lambda _: (triton.cdiv(
|
|
709
|
+
grid = lambda _: (triton.cdiv(Tq, 64), B * H, 1)
|
|
483
710
|
_attn_fwd_sparse[grid](q,
|
|
484
711
|
k,
|
|
485
712
|
v,
|
|
@@ -508,7 +735,8 @@ def triton_block_sparse_attn_forward(q, k, v, q2k_index, q2k_num,
|
|
|
508
735
|
o.stride(3),
|
|
509
736
|
B,
|
|
510
737
|
H,
|
|
511
|
-
|
|
738
|
+
Tq,
|
|
739
|
+
Tkv,
|
|
512
740
|
HEAD_DIM=D,
|
|
513
741
|
STAGE=3)
|
|
514
742
|
|
|
@@ -518,21 +746,21 @@ def triton_block_sparse_attn_forward(q, k, v, q2k_index, q2k_num,
|
|
|
518
746
|
def triton_block_sparse_attn_backward(do, q, k, v, o, M, q2k_index, q2k_num,
|
|
519
747
|
k2q_index, k2q_num, variable_block_sizes):
|
|
520
748
|
assert do.is_contiguous()
|
|
521
|
-
assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride()
|
|
522
749
|
|
|
523
|
-
B, H,
|
|
750
|
+
B, H, Tq, D = q.shape
|
|
751
|
+
Tkv = k.shape[2]
|
|
524
752
|
sm_scale = 1.0 / math.sqrt(D)
|
|
525
753
|
dq = torch.empty_like(q)
|
|
526
754
|
dk = torch.empty_like(k)
|
|
527
755
|
dv = torch.empty_like(v)
|
|
528
|
-
BATCH, N_HEAD
|
|
756
|
+
BATCH, N_HEAD = q.shape[:2]
|
|
529
757
|
BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 64, 64, 32
|
|
530
758
|
RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2)
|
|
531
759
|
arg_k = k
|
|
532
760
|
arg_k = arg_k * (sm_scale * RCP_LN2)
|
|
533
761
|
PRE_BLOCK = 64
|
|
534
|
-
assert
|
|
535
|
-
pre_grid = (
|
|
762
|
+
assert Tq % PRE_BLOCK == 0
|
|
763
|
+
pre_grid = (Tq // PRE_BLOCK, BATCH * N_HEAD)
|
|
536
764
|
delta = torch.empty_like(M)
|
|
537
765
|
_attn_bwd_preprocess[pre_grid](
|
|
538
766
|
o,
|
|
@@ -540,7 +768,7 @@ def triton_block_sparse_attn_backward(do, q, k, v, o, M, q2k_index, q2k_num,
|
|
|
540
768
|
delta, #
|
|
541
769
|
BATCH,
|
|
542
770
|
N_HEAD,
|
|
543
|
-
|
|
771
|
+
Tq, #
|
|
544
772
|
BLOCK_M=PRE_BLOCK,
|
|
545
773
|
HEAD_DIM=D #
|
|
546
774
|
)
|
|
@@ -548,36 +776,75 @@ def triton_block_sparse_attn_backward(do, q, k, v, o, M, q2k_index, q2k_num,
|
|
|
548
776
|
max_q_blks = k2q_index.shape[-1]
|
|
549
777
|
max_kv_blks = q2k_index.shape[-1]
|
|
550
778
|
|
|
551
|
-
|
|
552
|
-
|
|
779
|
+
# dK/dV kernel: grid over KV blocks
|
|
780
|
+
grid_kv = (Tkv // BLOCK_N1, 1, BATCH * N_HEAD)
|
|
781
|
+
_attn_bwd_dkdv_kernel[grid_kv](
|
|
553
782
|
q,
|
|
554
783
|
arg_k,
|
|
555
784
|
v,
|
|
556
785
|
sm_scale,
|
|
557
786
|
do,
|
|
558
|
-
dq,
|
|
559
787
|
dk,
|
|
560
|
-
dv,
|
|
788
|
+
dv,
|
|
561
789
|
M,
|
|
562
|
-
delta,
|
|
563
|
-
q2k_index,
|
|
564
|
-
q2k_num,
|
|
565
|
-
max_kv_blks,
|
|
790
|
+
delta,
|
|
566
791
|
k2q_index,
|
|
567
792
|
k2q_num,
|
|
568
793
|
max_q_blks,
|
|
569
794
|
variable_block_sizes,
|
|
795
|
+
q.stride(2),
|
|
796
|
+
q.stride(3),
|
|
570
797
|
q.stride(0),
|
|
571
798
|
q.stride(1),
|
|
572
|
-
|
|
573
|
-
|
|
799
|
+
arg_k.stride(0),
|
|
800
|
+
arg_k.stride(1),
|
|
801
|
+
v.stride(0),
|
|
802
|
+
v.stride(1),
|
|
803
|
+
do.stride(0),
|
|
804
|
+
do.stride(1),
|
|
805
|
+
dk.stride(0),
|
|
806
|
+
dk.stride(1),
|
|
807
|
+
dv.stride(0),
|
|
808
|
+
dv.stride(1),
|
|
574
809
|
N_HEAD,
|
|
575
|
-
|
|
810
|
+
Tq,
|
|
811
|
+
Tkv,
|
|
576
812
|
BLOCK_M1=BLOCK_M1,
|
|
577
|
-
BLOCK_N1=BLOCK_N1,
|
|
813
|
+
BLOCK_N1=BLOCK_N1,
|
|
814
|
+
HEAD_DIM=D,
|
|
815
|
+
)
|
|
816
|
+
|
|
817
|
+
# dQ kernel: grid over Q blocks
|
|
818
|
+
grid_q = (Tq // BLOCK_M2, 1, BATCH * N_HEAD)
|
|
819
|
+
_attn_bwd_dq_kernel[grid_q](
|
|
820
|
+
q,
|
|
821
|
+
arg_k,
|
|
822
|
+
v,
|
|
823
|
+
do,
|
|
824
|
+
dq,
|
|
825
|
+
M,
|
|
826
|
+
delta,
|
|
827
|
+
q2k_index,
|
|
828
|
+
q2k_num,
|
|
829
|
+
max_kv_blks,
|
|
830
|
+
variable_block_sizes,
|
|
831
|
+
q.stride(2),
|
|
832
|
+
q.stride(3),
|
|
833
|
+
q.stride(0),
|
|
834
|
+
q.stride(1),
|
|
835
|
+
arg_k.stride(0),
|
|
836
|
+
arg_k.stride(1),
|
|
837
|
+
v.stride(0),
|
|
838
|
+
v.stride(1),
|
|
839
|
+
do.stride(0),
|
|
840
|
+
do.stride(1),
|
|
841
|
+
dq.stride(0),
|
|
842
|
+
dq.stride(1),
|
|
843
|
+
N_HEAD,
|
|
844
|
+
Tq,
|
|
578
845
|
BLOCK_M2=BLOCK_M2,
|
|
579
|
-
BLOCK_N2=BLOCK_N2,
|
|
580
|
-
HEAD_DIM=D
|
|
846
|
+
BLOCK_N2=BLOCK_N2,
|
|
847
|
+
HEAD_DIM=D,
|
|
581
848
|
)
|
|
582
849
|
|
|
583
850
|
return dq, dk, dv
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
__version__ = "0.2.6"
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
__version__ = "0.2.5"
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/python/fastvideo_kernel/triton_kernels/index.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
{fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/python/fastvideo_kernel/turbodiffusion_ops.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|