fastvideo-kernel 0.2.5__tar.gz → 0.2.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47):
  1. {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/PKG-INFO +1 -1
  2. fastvideo_kernel-0.2.5/dist/fastvideo_kernel-0.2.5-cp310-cp310-manylinux_2_34_x86_64.manylinux_2_35_x86_64.whl → fastvideo_kernel-0.2.6/dist/fastvideo_kernel-0.2.6-cp310-cp310-manylinux_2_34_x86_64.manylinux_2_35_x86_64.whl +0 -0
  3. fastvideo_kernel-0.2.5/dist/fastvideo_kernel-0.2.5-cp311-cp311-manylinux_2_34_x86_64.manylinux_2_35_x86_64.whl → fastvideo_kernel-0.2.6/dist/fastvideo_kernel-0.2.6-cp311-cp311-manylinux_2_34_x86_64.manylinux_2_35_x86_64.whl +0 -0
  4. fastvideo_kernel-0.2.5/dist/fastvideo_kernel-0.2.5-cp312-cp312-manylinux_2_34_x86_64.manylinux_2_35_x86_64.whl → fastvideo_kernel-0.2.6/dist/fastvideo_kernel-0.2.6-cp312-cp312-manylinux_2_34_x86_64.manylinux_2_35_x86_64.whl +0 -0
  5. {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/pyproject.toml +1 -1
  6. {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/python/fastvideo_kernel/block_sparse_attn.py +2 -3
  7. {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/python/fastvideo_kernel/ops.py +0 -6
  8. {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/python/fastvideo_kernel/triton_kernels/block_sparse_attn_triton.py +314 -47
  9. fastvideo_kernel-0.2.6/python/fastvideo_kernel/version.py +1 -0
  10. fastvideo_kernel-0.2.5/python/fastvideo_kernel/version.py +0 -1
  11. {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/CMakeLists.txt +0 -0
  12. {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/LICENSE +0 -0
  13. {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/MANIFEST.in +0 -0
  14. {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/README.md +0 -0
  15. {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/benchmarks/bench_vsa.py +0 -0
  16. {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/build.sh +0 -0
  17. {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/csrc/attention/block_sparse_h100.cu +0 -0
  18. {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/csrc/attention/st_attn_h100.cu +0 -0
  19. {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/csrc/common_extension.cpp +0 -0
  20. {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/csrc/turbodiffusion/common/common.hpp +0 -0
  21. {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/csrc/turbodiffusion/common/launch.hpp +0 -0
  22. {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/csrc/turbodiffusion/common/load.hpp +0 -0
  23. {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/csrc/turbodiffusion/common/store.hpp +0 -0
  24. {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/csrc/turbodiffusion/gemm/gemm.cu +0 -0
  25. {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/csrc/turbodiffusion/gemm/kernel.hpp +0 -0
  26. {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/csrc/turbodiffusion/gemm/launch.hpp +0 -0
  27. {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/csrc/turbodiffusion/gemm/utils.hpp +0 -0
  28. {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/csrc/turbodiffusion/norm/layernorm.cu +0 -0
  29. {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/csrc/turbodiffusion/norm/layernorm.hpp +0 -0
  30. {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/csrc/turbodiffusion/norm/rmsnorm.cu +0 -0
  31. {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/csrc/turbodiffusion/norm/rmsnorm.hpp +0 -0
  32. {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/csrc/turbodiffusion/quant/quant.cu +0 -0
  33. {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/csrc/turbodiffusion/quant/quant.hpp +0 -0
  34. {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/python/fastvideo_kernel/__init__.py +0 -0
  35. {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/python/fastvideo_kernel/triton_kernels/index.py +0 -0
  36. {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/python/fastvideo_kernel/triton_kernels/sla_triton.py +0 -0
  37. {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/python/fastvideo_kernel/triton_kernels/st_attn_triton.py +0 -0
  38. {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/python/fastvideo_kernel/turbodiffusion_ops.py +0 -0
  39. {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/python/fastvideo_kernel/vmoba.py +0 -0
  40. {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/tests/__init__.py +0 -0
  41. {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/tests/support_flex_sta.py +0 -0
  42. {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/tests/test_sta.py +0 -0
  43. {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/tests/test_turbodiffusion.py +0 -0
  44. {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/tests/test_vmoba_correctness.py +0 -0
  45. {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/tests/test_vsa.py +0 -0
  46. {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/tests/test_vsa_forward.py +0 -0
  47. {fastvideo_kernel-0.2.5 → fastvideo_kernel-0.2.6}/tests/utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: fastvideo-kernel
3
- Version: 0.2.5
3
+ Version: 0.2.6
4
4
  Summary: Unified CUDA kernels for FastVideo
5
5
  Author-Email: Hao AI Lab <contact@haoailab.com>
6
6
  License: Apache License
@@ -9,7 +9,7 @@ build-backend = "scikit_build_core.build"
9
9
 
10
10
  [project]
11
11
  name = "fastvideo-kernel"
12
- version = "0.2.5"
12
+ version = "0.2.6"
13
13
  description = "Unified CUDA kernels for FastVideo"
14
14
  readme = "README.md"
15
15
  requires-python = ">=3.10"
@@ -287,9 +287,8 @@ def block_sparse_attn(
287
287
  block_sparse_fwd, block_sparse_bwd = _get_sm90_ops()
288
288
  if (not _force_triton()) and _is_sm90() and (block_sparse_fwd is not None) and (block_sparse_bwd is not None):
289
289
  return block_sparse_attn_sm90(q, k, v, block_map, variable_block_sizes)
290
- # Triton path: generally assumes q/k/v share the same padded length
291
- if q.shape[2] != k.shape[2] or q.shape[2] != v.shape[2]:
292
- raise RuntimeError("Triton fallback requires q/k/v to have the same padded length.")
290
+ # Triton path: supports q_seq_len != kv_seq_len as long as both are padded
291
+ # to a multiple of the block size (64 tokens).
293
292
  return block_sparse_attn_triton(q, k, v, block_map, variable_block_sizes)
294
293
 
295
294
 
@@ -141,12 +141,6 @@ def video_sparse_attn(
141
141
  # Use autograd-enabled wrapper so backward works (and still uses SM90 kernel when available)
142
142
  out_s = block_sparse_attn(q, k, v, mask, variable_block_sizes)[0]
143
143
  else:
144
- if q_seq_len != kv_seq_len:
145
- raise RuntimeError(
146
- "q/k have different lengths, but the compiled CUDA kernel (block_sparse_fwd) "
147
- "is not available. The Triton fallback currently requires q and k/v to have "
148
- "the same padded length."
149
- )
150
144
  # Triton-only forward (kept for environments without the wrapper deps)
151
145
  out_s, _ = triton_block_sparse_attn_forward(q, k, v, idx, num, variable_block_sizes)
152
146
 
@@ -29,7 +29,7 @@ configs = [
29
29
 
30
30
 
31
31
  # ──────────────────────────── SPARSE ADDITION BEGIN ───────────────────────────
32
- @triton.autotune(configs, key=["N_CTX", "HEAD_DIM"])
32
+ @triton.autotune(configs, key=["N_CTX_Q", "HEAD_DIM"])
33
33
  @triton.jit
34
34
  def _attn_fwd_sparse(
35
35
  Q,
@@ -60,7 +60,8 @@ def _attn_fwd_sparse(
60
60
  stride_on,
61
61
  Z,
62
62
  H,
63
- N_CTX, #
63
+ N_CTX_Q, #
64
+ N_CTX_KV, #
64
65
  HEAD_DIM: tl.constexpr, #
65
66
  BLOCK_M: tl.constexpr,
66
67
  BLOCK_N: tl.constexpr,
@@ -75,24 +76,29 @@ def _attn_fwd_sparse(
75
76
  off_hz = tl.program_id(1) # fused (batch, head)
76
77
  b = off_hz // H
77
78
  h = off_hz % H
78
- q_tiles = N_CTX // BLOCK_M
79
+ q_tiles = N_CTX_Q // BLOCK_M
79
80
  meta_base = ((b * H + h) * q_tiles + q_blk)
80
81
 
81
82
  kv_blocks = tl.load(q2k_num + meta_base) # int32
82
83
  kv_ptr = q2k_index + meta_base * max_kv_blks # ptr to list
83
84
 
84
85
  # ----- base pointers -----
85
- qvk_off = (b.to(tl.int64) * stride_qz + h.to(tl.int64) * stride_qh)
86
-
87
- Q_ptr = tl.make_block_ptr(base=Q + qvk_off,
88
- shape=(N_CTX, HEAD_DIM),
86
+ # Note: when q and kv have different sequence lengths, their per-(batch,head)
87
+ # strides differ, so we must compute separate base offsets.
88
+ q_off = (b.to(tl.int64) * stride_qz + h.to(tl.int64) * stride_qh)
89
+ k_off = (b.to(tl.int64) * stride_kz + h.to(tl.int64) * stride_kh)
90
+ v_off = (b.to(tl.int64) * stride_vz + h.to(tl.int64) * stride_vh)
91
+ o_off = (b.to(tl.int64) * stride_oz + h.to(tl.int64) * stride_oh)
92
+
93
+ Q_ptr = tl.make_block_ptr(base=Q + q_off,
94
+ shape=(N_CTX_Q, HEAD_DIM),
89
95
  strides=(stride_qm, stride_qk),
90
96
  offsets=(q_blk * BLOCK_M, 0),
91
97
  block_shape=(BLOCK_M, HEAD_DIM),
92
98
  order=(1, 0))
93
99
 
94
- K_base = tl.make_block_ptr(base=K + qvk_off,
95
- shape=(HEAD_DIM, N_CTX),
100
+ K_base = tl.make_block_ptr(base=K + k_off,
101
+ shape=(HEAD_DIM, N_CTX_KV),
96
102
  strides=(stride_kk, stride_kn),
97
103
  offsets=(0, 0),
98
104
  block_shape=(HEAD_DIM, BLOCK_N),
@@ -100,15 +106,15 @@ def _attn_fwd_sparse(
100
106
 
101
107
  v_order: tl.constexpr = (0, 1) if V.dtype.element_ty == tl.float8e5 else (1,
102
108
  0)
103
- V_base = tl.make_block_ptr(base=V + qvk_off,
104
- shape=(N_CTX, HEAD_DIM),
109
+ V_base = tl.make_block_ptr(base=V + v_off,
110
+ shape=(N_CTX_KV, HEAD_DIM),
105
111
  strides=(stride_vk, stride_vn),
106
112
  offsets=(0, 0),
107
113
  block_shape=(BLOCK_N, HEAD_DIM),
108
114
  order=v_order)
109
115
 
110
- O_ptr = tl.make_block_ptr(base=Out + qvk_off,
111
- shape=(N_CTX, HEAD_DIM),
116
+ O_ptr = tl.make_block_ptr(base=Out + o_off,
117
+ shape=(N_CTX_Q, HEAD_DIM),
112
118
  strides=(stride_om, stride_on),
113
119
  offsets=(q_blk * BLOCK_M, 0),
114
120
  block_shape=(BLOCK_M, HEAD_DIM),
@@ -150,7 +156,7 @@ def _attn_fwd_sparse(
150
156
  # ----- epilogue -----
151
157
  m_i += tl.math.log2(l_i)
152
158
  acc = acc / l_i[:, None]
153
- tl.store(M + off_hz * N_CTX + offs_m, m_i)
159
+ tl.store(M + off_hz * N_CTX_Q + offs_m, m_i)
154
160
  tl.store(O_ptr, acc.to(Out.type.element_ty))
155
161
 
156
162
 
@@ -201,7 +207,7 @@ def _attn_bwd_dkdv(
201
207
  stride_tok,
202
208
  stride_d, #
203
209
  H,
204
- N_CTX,
210
+ N_CTX_KV,
205
211
  BLOCK_M1: tl.constexpr, #
206
212
  BLOCK_N1: tl.constexpr, #
207
213
  HEAD_DIM: tl.constexpr, #
@@ -221,8 +227,8 @@ def _attn_bwd_dkdv(
221
227
  off_hz = tl.program_id(2) # fused (batch, head)
222
228
  b = off_hz // H
223
229
  h = off_hz % H
224
- q_tiles = N_CTX // BLOCK_N1
225
- meta_base = ((b * H + h) * q_tiles + kv_blk)
230
+ kv_tiles = N_CTX_KV // BLOCK_N1
231
+ meta_base = ((b * H + h) * kv_tiles + kv_blk)
226
232
 
227
233
  q_blocks = tl.load(k2q_num + meta_base) # int32
228
234
  q_ptr = k2q_index + meta_base * max_q_blks # ptr to list
@@ -302,16 +308,21 @@ def _attn_bwd_dq(
302
308
 
303
309
  kv_blocks = tl.load(q2k_num + meta_base) # int32
304
310
  kv_ptr = q2k_index + meta_base * max_kv_blks # ptr to list
305
- block_size = tl.load(variable_block_sizes + q_blk)
306
311
 
307
312
  for blk_idx in range(kv_blocks * 2):
308
- block_sparse_offset = (tl.load(kv_ptr + blk_idx // 2).to(tl.int32) * 2 +
309
- blk_idx % 2) * step_n * stride_tok
313
+ kv_idx = tl.load(kv_ptr + blk_idx // 2).to(tl.int32)
314
+ # variable_block_sizes is defined per KV block (tile). Mask must therefore
315
+ # use kv_idx (not q_blk). Also, because we split each 64-token block into
316
+ # two 32-token halves, the mask must account for the half-block offset.
317
+ block_size = tl.load(variable_block_sizes + kv_idx).to(tl.int32)
318
+ half = (blk_idx % 2).to(tl.int32)
319
+ block_sparse_offset = (kv_idx * 2 + half) * step_n * stride_tok
310
320
  kT = tl.load(kT_ptrs + block_sparse_offset)
311
321
  vT = tl.load(vT_ptrs + block_sparse_offset)
312
322
  qk = tl.dot(q, kT)
313
323
  p = tl.math.exp2(qk - m)
314
- mask = tl.arange(0, BLOCK_N2) < block_size.to(tl.int32)
324
+ offs_in_block = half * step_n + tl.arange(0, BLOCK_N2)
325
+ mask = offs_in_block < block_size
315
326
  p = tl.where(mask[None, :], p, 0.0)
316
327
  # Compute dP and dS.
317
328
  dp = tl.dot(do, vT).to(tl.float32)
@@ -467,19 +478,235 @@ def _attn_bwd(
467
478
  tl.store(dq_ptrs, dq)
468
479
 
469
480
 
481
+ @triton.jit
482
+ def _attn_bwd_dkdv_kernel(
483
+ Q,
484
+ K,
485
+ V,
486
+ sm_scale, #
487
+ DO, #
488
+ DK,
489
+ DV, #
490
+ M,
491
+ D,
492
+ k2q_index,
493
+ k2q_num,
494
+ max_q_blks,
495
+ variable_block_sizes,
496
+ # shared token/dim strides (assumed contiguous along token and dim)
497
+ stride_tok,
498
+ stride_d, #
499
+ # batch/head strides (may differ between Q and KV)
500
+ stride_qz,
501
+ stride_qh,
502
+ stride_kz,
503
+ stride_kh,
504
+ stride_vz,
505
+ stride_vh,
506
+ stride_doz,
507
+ stride_doh,
508
+ stride_dkz,
509
+ stride_dkh,
510
+ stride_dvz,
511
+ stride_dvh,
512
+ H,
513
+ N_CTX_Q,
514
+ N_CTX_KV,
515
+ BLOCK_M1: tl.constexpr, #
516
+ BLOCK_N1: tl.constexpr, #
517
+ HEAD_DIM: tl.constexpr):
518
+ """
519
+ Backward kernel that computes dK and dV for each KV block (64 tokens).
520
+ Grid:
521
+ pid0: kv_blk in [0, N_CTX_KV/BLOCK_N1)
522
+ pid2: fused (batch, head) in [0, B*H)
523
+ """
524
+ bhid = tl.program_id(2)
525
+ b = bhid // H
526
+ h = bhid % H
527
+ kv_blk = tl.program_id(0)
528
+
529
+ q_adj = (b.to(tl.int64) * stride_qz + h.to(tl.int64) * stride_qh)
530
+ kv_adj_k = (b.to(tl.int64) * stride_kz + h.to(tl.int64) * stride_kh)
531
+ kv_adj_v = (b.to(tl.int64) * stride_vz + h.to(tl.int64) * stride_vh)
532
+ do_adj = (b.to(tl.int64) * stride_doz + h.to(tl.int64) * stride_doh)
533
+ dk_adj = (b.to(tl.int64) * stride_dkz + h.to(tl.int64) * stride_dkh)
534
+ dv_adj = (b.to(tl.int64) * stride_dvz + h.to(tl.int64) * stride_dvh)
535
+
536
+ Q = Q + q_adj
537
+ K = K + kv_adj_k
538
+ V = V + kv_adj_v
539
+ DO = DO + do_adj
540
+ DK = DK + dk_adj
541
+ DV = DV + dv_adj
542
+
543
+ # M and D (delta) are always sized by Q length.
544
+ M = M + (bhid * N_CTX_Q).to(tl.int64)
545
+ D = D + (bhid * N_CTX_Q).to(tl.int64)
546
+
547
+ offs_k = tl.arange(0, HEAD_DIM)
548
+ start_n = kv_blk * BLOCK_N1
549
+ offs_n = start_n + tl.arange(0, BLOCK_N1)
550
+
551
+ # load K and V: they stay in SRAM throughout the inner loop.
552
+ k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
553
+ v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
554
+
555
+ dv_acc = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)
556
+ dk_acc = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)
557
+
558
+ num_steps = N_CTX_Q // BLOCK_M1
559
+ dk_acc, dv_acc = _attn_bwd_dkdv(
560
+ dk_acc,
561
+ dv_acc,
562
+ Q,
563
+ k,
564
+ v,
565
+ sm_scale,
566
+ DO,
567
+ M,
568
+ D,
569
+ k2q_index,
570
+ k2q_num,
571
+ max_q_blks,
572
+ variable_block_sizes,
573
+ stride_tok,
574
+ stride_d,
575
+ H,
576
+ N_CTX_KV,
577
+ BLOCK_M1=BLOCK_M1,
578
+ BLOCK_N1=BLOCK_N1,
579
+ HEAD_DIM=HEAD_DIM,
580
+ start_n=start_n,
581
+ start_m=0,
582
+ num_steps=num_steps,
583
+ )
584
+
585
+ dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
586
+ tl.store(dv_ptrs, dv_acc)
587
+
588
+ dk_acc *= sm_scale
589
+ dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
590
+ tl.store(dk_ptrs, dk_acc)
591
+
592
+
593
+ @triton.jit
594
+ def _attn_bwd_dq_kernel(
595
+ Q,
596
+ K,
597
+ V,
598
+ DO, #
599
+ DQ,
600
+ M,
601
+ D,
602
+ q2k_index,
603
+ q2k_num,
604
+ max_kv_blks,
605
+ variable_block_sizes,
606
+ # shared token/dim strides (assumed contiguous along token and dim)
607
+ stride_tok,
608
+ stride_d, #
609
+ # batch/head strides (may differ between Q and KV)
610
+ stride_qz,
611
+ stride_qh,
612
+ stride_kz,
613
+ stride_kh,
614
+ stride_vz,
615
+ stride_vh,
616
+ stride_doz,
617
+ stride_doh,
618
+ stride_dqz,
619
+ stride_dqh,
620
+ H,
621
+ N_CTX_Q,
622
+ BLOCK_M2: tl.constexpr, #
623
+ BLOCK_N2: tl.constexpr, #
624
+ HEAD_DIM: tl.constexpr):
625
+ """
626
+ Backward kernel that computes dQ for each Q block (64 tokens).
627
+ Grid:
628
+ pid0: q_blk in [0, N_CTX_Q/BLOCK_M2)
629
+ pid2: fused (batch, head) in [0, B*H)
630
+ """
631
+ LN2 = 0.6931471824645996 # = ln(2)
632
+ bhid = tl.program_id(2)
633
+ b = bhid // H
634
+ h = bhid % H
635
+ q_blk = tl.program_id(0)
636
+
637
+ q_adj = (b.to(tl.int64) * stride_qz + h.to(tl.int64) * stride_qh)
638
+ kv_adj_k = (b.to(tl.int64) * stride_kz + h.to(tl.int64) * stride_kh)
639
+ kv_adj_v = (b.to(tl.int64) * stride_vz + h.to(tl.int64) * stride_vh)
640
+ do_adj = (b.to(tl.int64) * stride_doz + h.to(tl.int64) * stride_doh)
641
+ dq_adj = (b.to(tl.int64) * stride_dqz + h.to(tl.int64) * stride_dqh)
642
+
643
+ Q = Q + q_adj
644
+ K = K + kv_adj_k
645
+ V = V + kv_adj_v
646
+ DO = DO + do_adj
647
+ DQ = DQ + dq_adj
648
+
649
+ M = M + (bhid * N_CTX_Q).to(tl.int64)
650
+ D = D + (bhid * N_CTX_Q).to(tl.int64)
651
+
652
+ offs_k = tl.arange(0, HEAD_DIM)
653
+ start_m = q_blk * BLOCK_M2
654
+ offs_m = start_m + tl.arange(0, BLOCK_M2)
655
+
656
+ q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
657
+ do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
658
+ m = tl.load(M + offs_m)[:, None]
659
+
660
+ dq_acc = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32)
661
+ num_steps = 0 # unused in _attn_bwd_dq
662
+ dq_acc = _attn_bwd_dq(
663
+ dq_acc,
664
+ q,
665
+ K,
666
+ V,
667
+ do,
668
+ m,
669
+ D,
670
+ q2k_index,
671
+ q2k_num,
672
+ max_kv_blks,
673
+ variable_block_sizes,
674
+ stride_tok,
675
+ stride_d,
676
+ H,
677
+ N_CTX_Q,
678
+ BLOCK_M2=BLOCK_M2,
679
+ BLOCK_N2=BLOCK_N2,
680
+ HEAD_DIM=HEAD_DIM,
681
+ start_m=start_m,
682
+ start_n=0,
683
+ num_steps=num_steps,
684
+ )
685
+
686
+ dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
687
+ dq_acc *= LN2
688
+ tl.store(dq_ptrs, dq_acc)
689
+
690
+
470
691
  # ──────────────────────────── SPARSE ADDITION BEGIN ───────────────────────────
471
692
  def triton_block_sparse_attn_forward(q, k, v, q2k_index, q2k_num,
472
693
  variable_block_sizes):
473
- B, H, T, D = q.shape
694
+ B, H, Tq, D = q.shape
695
+ Tkv = k.shape[2]
474
696
  sm_scale = 1.0 / math.sqrt(D)
475
697
  max_kv_blks = q2k_index.shape[-1]
476
- assert T % 64 == 0, f"T must be a multiple of 64, but got {T}"
698
+ assert Tq % 64 == 0, f"q length must be a multiple of 64, but got {Tq}"
699
+ assert Tkv % 64 == 0, f"kv length must be a multiple of 64, but got {Tkv}"
477
700
  assert q2k_num.shape[
478
- -1] == T // 64, f"shape mismatch, T // 64 = {T // 64}, q2k_num.shape[-2] = {q2k_num.shape[-2]}"
701
+ -1] == Tq // 64, f"shape mismatch, Tq // 64 = {Tq // 64}, q2k_num.shape[-2] = {q2k_num.shape[-2]}"
702
+ assert variable_block_sizes.numel() == Tkv // 64, (
703
+ f"shape mismatch, variable_block_sizes must have length {Tkv // 64}, "
704
+ f"got {variable_block_sizes.numel()}"
705
+ )
479
706
  o = torch.empty_like(q)
480
- M = torch.empty((B, H, T), dtype=torch.float32, device=q.device)
707
+ M = torch.empty((B, H, Tq), dtype=torch.float32, device=q.device)
481
708
 
482
- grid = lambda _: (triton.cdiv(T, 64), B * H, 1)
709
+ grid = lambda _: (triton.cdiv(Tq, 64), B * H, 1)
483
710
  _attn_fwd_sparse[grid](q,
484
711
  k,
485
712
  v,
@@ -508,7 +735,8 @@ def triton_block_sparse_attn_forward(q, k, v, q2k_index, q2k_num,
508
735
  o.stride(3),
509
736
  B,
510
737
  H,
511
- T,
738
+ Tq,
739
+ Tkv,
512
740
  HEAD_DIM=D,
513
741
  STAGE=3)
514
742
 
@@ -518,21 +746,21 @@ def triton_block_sparse_attn_forward(q, k, v, q2k_index, q2k_num,
518
746
  def triton_block_sparse_attn_backward(do, q, k, v, o, M, q2k_index, q2k_num,
519
747
  k2q_index, k2q_num, variable_block_sizes):
520
748
  assert do.is_contiguous()
521
- assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride()
522
749
 
523
- B, H, T, D = q.shape
750
+ B, H, Tq, D = q.shape
751
+ Tkv = k.shape[2]
524
752
  sm_scale = 1.0 / math.sqrt(D)
525
753
  dq = torch.empty_like(q)
526
754
  dk = torch.empty_like(k)
527
755
  dv = torch.empty_like(v)
528
- BATCH, N_HEAD, N_CTX = q.shape[:3]
756
+ BATCH, N_HEAD = q.shape[:2]
529
757
  BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 64, 64, 32
530
758
  RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2)
531
759
  arg_k = k
532
760
  arg_k = arg_k * (sm_scale * RCP_LN2)
533
761
  PRE_BLOCK = 64
534
- assert N_CTX % PRE_BLOCK == 0
535
- pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD)
762
+ assert Tq % PRE_BLOCK == 0
763
+ pre_grid = (Tq // PRE_BLOCK, BATCH * N_HEAD)
536
764
  delta = torch.empty_like(M)
537
765
  _attn_bwd_preprocess[pre_grid](
538
766
  o,
@@ -540,7 +768,7 @@ def triton_block_sparse_attn_backward(do, q, k, v, o, M, q2k_index, q2k_num,
540
768
  delta, #
541
769
  BATCH,
542
770
  N_HEAD,
543
- N_CTX, #
771
+ Tq, #
544
772
  BLOCK_M=PRE_BLOCK,
545
773
  HEAD_DIM=D #
546
774
  )
@@ -548,36 +776,75 @@ def triton_block_sparse_attn_backward(do, q, k, v, o, M, q2k_index, q2k_num,
548
776
  max_q_blks = k2q_index.shape[-1]
549
777
  max_kv_blks = q2k_index.shape[-1]
550
778
 
551
- grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD)
552
- _attn_bwd[grid](
779
+ # dK/dV kernel: grid over KV blocks
780
+ grid_kv = (Tkv // BLOCK_N1, 1, BATCH * N_HEAD)
781
+ _attn_bwd_dkdv_kernel[grid_kv](
553
782
  q,
554
783
  arg_k,
555
784
  v,
556
785
  sm_scale,
557
786
  do,
558
- dq,
559
787
  dk,
560
- dv, #
788
+ dv,
561
789
  M,
562
- delta, #
563
- q2k_index,
564
- q2k_num,
565
- max_kv_blks,
790
+ delta,
566
791
  k2q_index,
567
792
  k2q_num,
568
793
  max_q_blks,
569
794
  variable_block_sizes,
795
+ q.stride(2),
796
+ q.stride(3),
570
797
  q.stride(0),
571
798
  q.stride(1),
572
- q.stride(2),
573
- q.stride(3), #
799
+ arg_k.stride(0),
800
+ arg_k.stride(1),
801
+ v.stride(0),
802
+ v.stride(1),
803
+ do.stride(0),
804
+ do.stride(1),
805
+ dk.stride(0),
806
+ dk.stride(1),
807
+ dv.stride(0),
808
+ dv.stride(1),
574
809
  N_HEAD,
575
- N_CTX, #
810
+ Tq,
811
+ Tkv,
576
812
  BLOCK_M1=BLOCK_M1,
577
- BLOCK_N1=BLOCK_N1, #
813
+ BLOCK_N1=BLOCK_N1,
814
+ HEAD_DIM=D,
815
+ )
816
+
817
+ # dQ kernel: grid over Q blocks
818
+ grid_q = (Tq // BLOCK_M2, 1, BATCH * N_HEAD)
819
+ _attn_bwd_dq_kernel[grid_q](
820
+ q,
821
+ arg_k,
822
+ v,
823
+ do,
824
+ dq,
825
+ M,
826
+ delta,
827
+ q2k_index,
828
+ q2k_num,
829
+ max_kv_blks,
830
+ variable_block_sizes,
831
+ q.stride(2),
832
+ q.stride(3),
833
+ q.stride(0),
834
+ q.stride(1),
835
+ arg_k.stride(0),
836
+ arg_k.stride(1),
837
+ v.stride(0),
838
+ v.stride(1),
839
+ do.stride(0),
840
+ do.stride(1),
841
+ dq.stride(0),
842
+ dq.stride(1),
843
+ N_HEAD,
844
+ Tq,
578
845
  BLOCK_M2=BLOCK_M2,
579
- BLOCK_N2=BLOCK_N2, #
580
- HEAD_DIM=D #
846
+ BLOCK_N2=BLOCK_N2,
847
+ HEAD_DIM=D,
581
848
  )
582
849
 
583
850
  return dq, dk, dv
@@ -0,0 +1 @@
1
+ __version__ = "0.2.6"
@@ -1 +0,0 @@
1
- __version__ = "0.2.5"