tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.12.0.dev20251213__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (59)
  1. tests/kernels/fused_moe_v1_test.py +303 -34
  2. tests/kernels/mla_v1_test.py +129 -41
  3. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  4. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +3 -1
  5. tests/kernels/ragged_paged_attention_kernel_v3_test.py +3 -1
  6. tests/lora/test_layers.py +4 -1
  7. tests/lora/test_lora_perf.py +53 -0
  8. tests/test_envs.py +110 -12
  9. tests/test_quantization.py +3 -0
  10. tests/test_utils.py +1 -2
  11. tpu_inference/distributed/tpu_connector.py +1 -1
  12. tpu_inference/envs.py +92 -8
  13. tpu_inference/executors/ray_distributed_executor.py +5 -1
  14. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  15. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  16. tpu_inference/kernels/fused_moe/v1/kernel.py +712 -143
  17. tpu_inference/kernels/mla/v1/kernel.py +98 -120
  18. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  19. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  20. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  21. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +82 -32
  22. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +146 -85
  23. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +2 -1
  24. tpu_inference/kernels/ragged_paged_attention/v3/util.py +2 -1
  25. tpu_inference/layers/common/attention_interface.py +7 -1
  26. tpu_inference/layers/common/sharding.py +11 -7
  27. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +232 -64
  28. tpu_inference/layers/jax/attention/gpt_oss_attention.py +5 -5
  29. tpu_inference/layers/vllm/fused_moe.py +170 -208
  30. tpu_inference/layers/vllm/linear_common.py +43 -21
  31. tpu_inference/layers/vllm/quantization/common.py +11 -6
  32. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -3
  33. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +74 -65
  34. tpu_inference/layers/vllm/quantization/mxfp4.py +140 -94
  35. tpu_inference/layers/vllm/quantization/unquantized.py +103 -80
  36. tpu_inference/models/common/model_loader.py +78 -22
  37. tpu_inference/models/jax/deepseek_v3.py +185 -64
  38. tpu_inference/models/jax/gpt_oss.py +3 -3
  39. tpu_inference/models/jax/llama_eagle3.py +4 -5
  40. tpu_inference/models/jax/qwen2_5_vl.py +161 -47
  41. tpu_inference/models/jax/utils/quantization/quantization_utils.py +7 -8
  42. tpu_inference/models/jax/utils/weight_utils.py +203 -155
  43. tpu_inference/models/vllm/vllm_model_wrapper.py +11 -5
  44. tpu_inference/platforms/tpu_platform.py +29 -48
  45. tpu_inference/runner/compilation_manager.py +112 -46
  46. tpu_inference/runner/kv_cache.py +40 -20
  47. tpu_inference/runner/kv_cache_manager.py +40 -31
  48. tpu_inference/runner/persistent_batch_manager.py +40 -2
  49. tpu_inference/runner/structured_decoding_manager.py +2 -3
  50. tpu_inference/runner/tpu_runner.py +94 -51
  51. tpu_inference/runner/utils.py +2 -2
  52. tpu_inference/spec_decode/jax/eagle3.py +71 -22
  53. tpu_inference/utils.py +41 -14
  54. tpu_inference/worker/tpu_worker.py +43 -45
  55. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/METADATA +8 -9
  56. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/RECORD +59 -58
  57. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/WHEEL +0 -0
  58. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/licenses/LICENSE +0 -0
  59. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/top_level.txt +0 -0
@@ -319,7 +319,7 @@ def _ragged_paged_attention_kernel(
  debug_print("[RPA debug] q_len={}", q_len)
  debug_print("[RPA debug] kv_len={}", kv_len)

- def flash_attention(
+ def flash_attention_step1_qk_softmax(
  q, # [actual_bq_sz * num_q_heads_per_kv_head, head_dim]
  k, # [bkv_sz, head_dim]
  v, # [bkv_sz, head_dim]
@@ -335,7 +335,6 @@ def _ragged_paged_attention_kernel(
  assert k.dtype == v.dtype
  head_l_ref = l_ref.at[kv_head_idx, :q.shape[0]]
  head_m_ref = m_ref.at[kv_head_idx, :q.shape[0]]
- head_acc_ref = acc_ref.at[kv_head_idx, :q.shape[0]]

  def load_with_init(ref, init_val):
  return jnp.where(bkv_idx == 0, jnp.full_like(ref, init_val),
@@ -376,15 +375,32 @@ def _ragged_paged_attention_kernel(
  head_m_ref[...] = m_curr
  p = jnp.exp(s - broadcast_minor(m_curr, s.shape))

- pv = jnp.einsum("nm,md->nd", p, v, preferred_element_type=jnp.float32)
- if v_scale is not None:
- pv *= v_scale
-
  p_rowsum = jnp.sum(p, axis=1, keepdims=True)
  exp_m_diff = jnp.exp(m_prev - m_curr)
  l_prev = load_with_init(head_l_ref, 0.0)
  l_curr = exp_m_diff * l_prev + p_rowsum
  head_l_ref[...] = l_curr
+
+ return p, exp_m_diff
+
+ def flash_attention_step2_pv(
+ q_shape_0,
+ v, # [bkv_sz, head_dim]
+ p, # from step1
+ exp_m_diff, # from step1
+ *,
+ bkv_idx,
+ kv_head_idx,
+ ):
+ head_acc_ref = acc_ref.at[kv_head_idx, :q_shape_0]
+
+ def load_with_init(ref, init_val):
+ return jnp.where(bkv_idx == 0, jnp.full_like(ref, init_val),
+ ref[...])
+
+ pv = jnp.einsum("nm,md->nd", p, v, preferred_element_type=jnp.float32)
+ if v_scale is not None:
+ pv *= v_scale
  o_prev = load_with_init(head_acc_ref, 0.0)
  o_curr = broadcast_minor(exp_m_diff, o_prev.shape) * o_prev + pv
  head_acc_ref[...] = o_curr
@@ -463,19 +479,16 @@ def _ragged_paged_attention_kernel(
  unroll=False,
  )

- # Fetch kv directly from new kv.
- @pl.when(bkv_sz_frm_new > 0)
- def _fetch_bkv_from_new_kv():
- new_kv_len_start = q_end - kv_left_frm_new
- debug_print("[RPA debug] new_kv_len_start={}",
- new_kv_len_start)
- debug_print("[RPA debug] offset_in_bkv={}", offset)
- _async_copy(
- kv_hbm_ref.at[pl.ds(new_kv_len_start, bkv_sz_frm_new)],
- vmem_ref.at[pl.ds(offset, bkv_sz_frm_new)],
- sem,
- wait,
- )
+ size = lax.select(bkv_sz_frm_new > 0, bkv_sz_frm_new, 0)
+ new_kv_len_start = q_end - kv_left_frm_new
+ debug_print("[RPA debug] new_kv_len_start={}", new_kv_len_start)
+ debug_print("[RPA debug] offset_in_bkv={}", offset)
+ _async_copy(
+ kv_hbm_ref.at[pl.ds(new_kv_len_start, size)],
+ vmem_ref.at[pl.ds(offset, size)],
+ sem,
+ wait,
+ )

  return kv_len_start + offset, bkv_sz_frm_new
  else:
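
Note: the hunk above swaps a `pl.when`-guarded DMA for an unconditional copy whose length is clamped to zero with `lax.select`, so an empty fetch degenerates into a no-op instead of a predicated branch. A minimal JAX sketch of the same idea on a plain array (the function, names, and shapes here are illustrative, not the kernel's `_async_copy` API):

import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def copy_prefix(dst, src, n):
    # Clamp a possibly non-positive count instead of branching around the
    # copy; a zero-length copy then simply leaves dst untouched.
    size = lax.select(n > 0, n, jnp.zeros_like(n))
    idx = lax.broadcasted_iota(jnp.int32, dst.shape, 0)
    return jnp.where(idx < size, src, dst)

dst = jnp.zeros(8)
src = jnp.arange(8.0)
copy_prefix(dst, src, 3)   # first 3 elements taken from src
copy_prefix(dst, src, -2)  # clamped to 0, dst returned unchanged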
@@ -842,6 +855,11 @@ def _ragged_paged_attention_kernel(

  # Flash attention with cur bkv and bq
  # NOTE: kv_packing is divided by 2 because k and v are packed together.
+ prev_bq_shape_0 = None
+ prev_kv_head_bv = None
+ prev_kv_head_idx = None
+ prev_kv_head_p = None
+ prev_kv_head_exp_m_diff = None
  heads_per_load = max(1, kv_packing // 2)
  for kv_head_start in range(0, actual_num_kv_heads,
  heads_per_load):
@@ -853,21 +871,53 @@ def _ragged_paged_attention_kernel(
  )
  assert len(bkv_lst) == heads_per_load
  for i in range(heads_per_load):
- kv_head_idx = kv_head_start + i
- if kv_head_idx >= actual_num_kv_heads:
+ cur_kv_head_idx = kv_head_start + i
+ if cur_kv_head_idx >= actual_num_kv_heads:
  break
- bq = load_bq(bq_sem_idx,
- kv_head_idx,
- actual_bq_sz=actual_bq_sz)
+
+ cur_kv_head_bq = load_bq(bq_sem_idx,
+ cur_kv_head_idx,
+ actual_bq_sz=actual_bq_sz)
  bk, bv = bkv_lst[i]
- flash_attention(
- bq,
- bk,
- bv,
- bq_idx=bq_idx,
- bkv_idx=bkv_idx,
- kv_head_idx=kv_head_idx,
- )
+ # FlashAttention is divided into `flash_attention_step1_qk_softmax`
+ # and `flash_attention_step2_pv` to pipeline the computation.
+ # `step2_pv` for the previous KV head, which depends on the softmax
+ # output, is overlapped with `step1_qk_softmax` for the current KV
+ # head, reducing overall wait times.
+ cur_kv_head_p, cur_kv_head_exp_m_diff = (
+ flash_attention_step1_qk_softmax(
+ cur_kv_head_bq,
+ bk,
+ bv,
+ bq_idx=bq_idx,
+ bkv_idx=bkv_idx,
+ kv_head_idx=cur_kv_head_idx,
+ ))
+ if prev_bq_shape_0 is not None:
+ flash_attention_step2_pv(
+ prev_bq_shape_0,
+ prev_kv_head_bv,
+ prev_kv_head_p,
+ prev_kv_head_exp_m_diff,
+ bkv_idx=bkv_idx,
+ kv_head_idx=prev_kv_head_idx,
+ )
+ prev_bq_shape_0 = cur_kv_head_bq.shape[0]
+ prev_kv_head_bv = bv
+ prev_kv_head_p = cur_kv_head_p
+ prev_kv_head_exp_m_diff = cur_kv_head_exp_m_diff
+ prev_kv_head_idx = cur_kv_head_idx
+
+ # Execute pv of last attention head.
+ assert prev_bq_shape_0 is not None
+ flash_attention_step2_pv(
+ prev_bq_shape_0,
+ prev_kv_head_bv,
+ prev_kv_head_p,
+ prev_kv_head_exp_m_diff,
+ bkv_idx=bkv_idx,
+ kv_head_idx=prev_kv_head_idx,
+ )

  lax.fori_loop(0, num_bkv, compute_with_bkv, None, unroll=False)

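Note for reviewers: the rewritten loop above is a two-stage software pipeline written in plain Python (unrolled at trace time). The previous head's P@V contraction is deferred until after the next head's QK/softmax has been issued, with a final drain call for the last head. A stripped-down sketch of the pattern with placeholder stage functions (names and shapes are illustrative, not the kernel's API):

import jax.numpy as jnp

def step1_qk_softmax(q, k):
    # Stage 1: scores and (unnormalized) softmax weights for one head.
    s = q @ k.T
    return jnp.exp(s - s.max(axis=-1, keepdims=True))

def step2_pv(out, head, p, v):
    # Stage 2: the P @ V contraction, which consumes stage 1's output.
    return out.at[head].set(p @ v)

def pipelined_heads(qs, ks, vs):
    # qs/ks/vs: [num_heads, seq, dim]. Stage 2 of head i-1 is interleaved
    # with stage 1 of head i, mirroring the prev_* bookkeeping above.
    num_heads, seq, dim = qs.shape
    out = jnp.zeros((num_heads, seq, vs.shape[-1]))
    prev = None
    for head in range(num_heads):
        p = step1_qk_softmax(qs[head], ks[head])
        if prev is not None:
            out = step2_pv(out, *prev)
        prev = (head, p, vs[head])
    return step2_pv(out, *prev)  # drain: stage 2 for the last head

As the in-diff comment says, this lets one head's step 2 overlap with the next head's step 1 instead of serializing on the softmax result.
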
@@ -267,6 +267,7 @@ def _ragged_paged_attention_kernel(
  *,
  sm_scale: float,
  sliding_window: int | None = None,
+ strict_sliding_window: bool = True,
  soft_cap: float | None = None,
  mask_value: float = DEFAULT_MASK_VALUE,
  q_scale: float | None = None,
@@ -317,19 +318,20 @@ def _ragged_paged_attention_kernel(
  q_len = q_end - q_start
  kv_len = kv_lens_ref[seq_idx]

- bkv_idx_start = 0 if sliding_window is None else jnp.maximum(
- kv_len - sliding_window, 0) // bkv_sz
-
  if sliding_window is None:
- next_bkv_idx_start = 0
+ bkv_idx_start = next_seq_bkv_idx_start = 0
  else:
+ bkv_idx_start = jnp.maximum(kv_len - q_len - sliding_window,
+ 0) // bkv_sz

  def get_next_bkv_idx_start():
  next_kv_len = kv_lens_ref[seq_idx + 1]
- return jnp.maximum(next_kv_len - sliding_window, 0) // bkv_sz
+ next_q_len = cu_q_lens_ref[seq_idx + 2] - q_end
+ return jnp.maximum(next_kv_len - next_q_len - sliding_window,
+ 0) // bkv_sz

- next_bkv_idx_start = lax.cond(seq_idx + 1 < num_seqs,
- get_next_bkv_idx_start, lambda: 0)
+ next_seq_bkv_idx_start = lax.cond(seq_idx + 1 < num_seqs,
+ get_next_bkv_idx_start, lambda: 0)

  def debug_print(msg, *args):
  if debug_mode:
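
The new start-block formula also subtracts the query length, so the first KV block loaded is the one visible to the oldest query token in the block (at absolute position kv_len - q_len), not just to the newest token. A small worked example with made-up sizes:

# Illustrative values, not taken from the diff.
kv_len, q_len, sliding_window, bkv_sz = 1000, 200, 128, 64

old_start = max(kv_len - sliding_window, 0) // bkv_sz          # 872 // 64 = 13
new_start = max(kv_len - q_len - sliding_window, 0) // bkv_sz  # 672 // 64 = 10

With the old formula the kernel would start three blocks later in this example, skipping keys that the earliest query tokens can still attend to.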
@@ -350,7 +352,7 @@ def _ragged_paged_attention_kernel(
  debug_print("[RPA debug] q_len={}", q_len)
  debug_print("[RPA debug] kv_len={}", kv_len)

- def flash_attention(
+ def flash_attention_step1_qk_softmax(
  q, # [actual_bq_sz * num_q_heads_per_kv_head, actual_head_dim_x2]
  kv, # [bkv_sz, actual_head_dim_x2]
  *,
@@ -364,7 +366,6 @@ def _ragged_paged_attention_kernel(
  assert kv.shape == (bkv_sz, actual_head_dim_x2)
  head_l_ref = l_ref.at[kv_head_idx, :q.shape[0]]
  head_m_ref = m_ref.at[kv_head_idx, :q.shape[0]]
- head_acc_ref = acc_ref.at[kv_head_idx, :q.shape[0]]

  def load_with_init(ref, init_val):
  return jnp.where(bkv_idx == bkv_idx_start,
@@ -386,16 +387,19 @@ def _ragged_paged_attention_kernel(
  s *= k_scale
  if q_scale is not None:
  s *= q_scale
+ if soft_cap is not None:
+ s = soft_cap * jnp.tanh(s / soft_cap)

  q_span = (kv_len - q_len + bq_idx * bq_sz +
  lax.broadcasted_iota(jnp.int32, s.shape, 0) //
  num_q_heads_per_kv_head)
  k_span = bkv_idx * bkv_sz + lax.broadcasted_iota(jnp.int32, s.shape, 1)
- mask = q_span < k_span
+ mask = k_span <= q_span

- if soft_cap is not None:
- s = soft_cap * jnp.tanh(s / soft_cap)
- s += jnp.where(mask, mask_value, 0.0)
+ if sliding_window is not None and strict_sliding_window:
+ mask = jnp.logical_and(mask, q_span - sliding_window < k_span)
+
+ s = jnp.where(mask, s, mask_value)
  s_rowmax = jnp.max(s, axis=1, keepdims=True)

  if attention_sink_ref is not None:
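
The rewritten masking keeps exactly the attended positions (k_span <= q_span, and optionally strictly inside the sliding window) and overwrites the rest with mask_value via jnp.where, instead of adding mask_value to masked-out scores after soft-capping. A standalone sketch of the same keep-mask with positions passed in explicitly (shapes and names are illustrative):

import jax.numpy as jnp

def keep_mask(q_pos, k_pos, sliding_window=None, strict_sliding_window=True):
    # q_pos: [num_q] absolute token positions of the queries.
    # k_pos: [num_k] absolute token positions of the keys.
    q_span = q_pos[:, None]
    k_span = k_pos[None, :]
    mask = k_span <= q_span  # causal: keys up to and including the query
    if sliding_window is not None and strict_sliding_window:
        # Keep only keys strictly within the window of each query.
        mask = jnp.logical_and(mask, q_span - sliding_window < k_span)
    return mask

q_pos = jnp.arange(4) + 4  # queries at positions 4..7
k_pos = jnp.arange(8)      # keys at positions 0..7
keep_mask(q_pos, k_pos, sliding_window=3)  # each query keeps its last 3 keys

Scores are then combined as s = jnp.where(mask, s, mask_value), as in the hunk above.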
@@ -411,15 +415,33 @@ def _ragged_paged_attention_kernel(
  head_m_ref[...] = m_curr
  p = jnp.exp(s - broadcast_minor(m_curr, s.shape))

- pv = jnp.einsum("nm,md->nd", p, kv, preferred_element_type=jnp.float32)
- if v_scale is not None:
- pv *= v_scale
-
  p_rowsum = jnp.sum(p, axis=1, keepdims=True)
  exp_m_diff = jnp.exp(m_prev - m_curr)
  l_prev = load_with_init(head_l_ref, 1.0)
  l_curr = exp_m_diff * l_prev + p_rowsum
  head_l_ref[...] = l_curr
+
+ return p, exp_m_diff
+
+ def flash_attention_step2_pv(
+ q_shape_0,
+ kv, # [bkv_sz, actual_head_dim_x2]
+ p, # from step1
+ exp_m_diff, # from step1
+ *,
+ bkv_idx,
+ kv_head_idx,
+ ):
+ head_acc_ref = acc_ref.at[kv_head_idx, :q_shape_0]
+
+ def load_with_init(ref, init_val):
+ return jnp.where(bkv_idx == bkv_idx_start,
+ jnp.full_like(ref, init_val), ref[...])
+
+ pv = jnp.einsum("nm,md->nd", p, kv, preferred_element_type=jnp.float32)
+ if v_scale is not None:
+ pv *= v_scale
+
  o_prev = load_with_init(head_acc_ref, 0.0)
  o_curr = broadcast_minor(exp_m_diff, o_prev.shape) * o_prev + pv
  head_acc_ref[...] = o_curr
@@ -498,19 +520,16 @@ def _ragged_paged_attention_kernel(
  unroll=False,
  )

- # Fetch kv directly from new kv.
- @pl.when(bkv_sz_frm_new > 0)
- def _fetch_bkv_from_new_kv():
- new_kv_len_start = q_end - kv_left_frm_new
- debug_print("[RPA debug] new_kv_len_start={}",
- new_kv_len_start)
- debug_print("[RPA debug] offset_in_bkv={}", offset)
- _async_copy(
- kv_hbm_ref.at[pl.ds(new_kv_len_start, bkv_sz_frm_new)],
- vmem_ref.at[pl.ds(offset, bkv_sz_frm_new)],
- sem,
- wait,
- )
+ size = lax.select(bkv_sz_frm_new > 0, bkv_sz_frm_new, 0)
+ new_kv_len_start = q_end - kv_left_frm_new
+ debug_print("[RPA debug] new_kv_len_start={}", new_kv_len_start)
+ debug_print("[RPA debug] offset_in_bkv={}", offset)
+ _async_copy(
+ kv_hbm_ref.at[pl.ds(new_kv_len_start, size)],
+ vmem_ref.at[pl.ds(offset, size)],
+ sem,
+ wait,
+ )

  return kv_len_start + offset, bkv_sz_frm_new
  else:
@@ -760,13 +779,17 @@ def _ragged_paged_attention_kernel(
  next_seq_idx = lax.select(is_last_bq, seq_idx + 1, seq_idx)
  next_bkv_sem_idx = lax.select(bkv_sem_idx == 0, 1, 0)

- next_bkv_idx = lax.select(
- is_last_bkv,
- lax.select(
+ if sliding_window is None:
+ next_bkv_start_idx = 0
+ else:
+ next_bkv_start_idx = lax.select(
  is_last_bq,
- next_bkv_idx_start,
+ next_seq_bkv_idx_start,
  bkv_idx_start,
- ), next_bkv_idx)
+ )
+ next_bkv_idx = lax.select(is_last_bkv, next_bkv_start_idx,
+ next_bkv_idx)
+
  return next_seq_idx, next_bq_idx, next_bkv_idx, next_bkv_sem_idx

  def compute_with_bq(bq_idx, _):
@@ -826,6 +849,11 @@ def _ragged_paged_attention_kernel(
  return

  # Flash attention with cur bkv and bq
+ prev_bq_shape_0 = None
+ prev_kv_head_bkv = None
+ prev_kv_head_idx = None
+ prev_kv_head_p = None
+ prev_kv_head_exp_m_diff = None
  for kv_head_start in range(0, actual_num_kv_heads, kv_packing):
  bkv_lst = strided_load_bkv(
  bkv_sem_idx,
@@ -835,20 +863,51 @@ def _ragged_paged_attention_kernel(
  )
  assert len(bkv_lst) == kv_packing
  for i in range(kv_packing):
- kv_head_idx = kv_head_start + i
- if kv_head_idx >= actual_num_kv_heads:
+ cur_kv_head_idx = kv_head_start + i
+ if cur_kv_head_idx >= actual_num_kv_heads:
  break
- bq = load_bq(bq_sem_idx,
- kv_head_idx,
- actual_bq_sz=actual_bq_sz)
- bkv = bkv_lst[i]
- flash_attention(
- bq,
- bkv,
- bq_idx=bq_idx,
- bkv_idx=bkv_idx,
- kv_head_idx=kv_head_idx,
- )
+ cur_kv_head_bq = load_bq(bq_sem_idx,
+ cur_kv_head_idx,
+ actual_bq_sz=actual_bq_sz)
+ cur_kv_head__bkv = bkv_lst[i]
+ # FlashAttention is divided into `flash_attention_step1_qk_softmax`
+ # and `flash_attention_step2_pv` to pipeline the computation.
+ # `step2_pv` for the previous KV head, which depends on the softmax
+ # output, is overlapped with `step1_qk_softmax` for the current KV
+ # head, reducing overall wait times.
+ cur_kv_head_p, cur_kv_head_exp_m_diff = (
+ flash_attention_step1_qk_softmax(
+ cur_kv_head_bq,
+ cur_kv_head__bkv,
+ bq_idx=bq_idx,
+ bkv_idx=bkv_idx,
+ kv_head_idx=cur_kv_head_idx,
+ ))
+ if prev_bq_shape_0 is not None:
+ flash_attention_step2_pv(
+ prev_bq_shape_0,
+ prev_kv_head_bkv,
+ prev_kv_head_p,
+ prev_kv_head_exp_m_diff,
+ bkv_idx=bkv_idx,
+ kv_head_idx=prev_kv_head_idx,
+ )
+ prev_bq_shape_0 = cur_kv_head_bq.shape[0]
+ prev_kv_head_bkv = cur_kv_head__bkv
+ prev_kv_head_p = cur_kv_head_p
+ prev_kv_head_exp_m_diff = cur_kv_head_exp_m_diff
+ prev_kv_head_idx = cur_kv_head_idx
+
+ # Execute pv of last attention head.
+ assert prev_bq_shape_0 is not None
+ flash_attention_step2_pv(
+ prev_bq_shape_0,
+ prev_kv_head_bkv,
+ prev_kv_head_p,
+ prev_kv_head_exp_m_diff,
+ bkv_idx=bkv_idx,
+ kv_head_idx=prev_kv_head_idx,
+ )

  lax.fori_loop(bkv_idx_start,
  num_bkv,
@@ -1249,6 +1308,7 @@ def static_validate_inputs(
  static_argnames=(
  "sm_scale",
  "sliding_window",
+ "strict_sliding_window",
  "soft_cap",
  "mask_value",
  "q_scale",
@@ -1278,6 +1338,7 @@ def ragged_paged_attention_hd64(
  *,
  sm_scale: float = 1.0,
  sliding_window: int | None = None,
+ strict_sliding_window: bool = True,
  soft_cap: float | None = None,
  mask_value: float | None = DEFAULT_MASK_VALUE,
  q_scale: float | None = None,
@@ -1292,42 +1353,41 @@ def ragged_paged_attention_hd64(
  # Debug params.
  debug_mode: bool = False,
  ):
- """A special Ragged paged attention version for head_dim=64 that supports mixed
-
- prefill and decode.
-
- Args:
- queries: concatenated all sequences' queries.
- keys: concatenated all sequences' keys (quantized).
- values: concatenated all sequences' values (quantized).
- kv_cache: paged KV cache with TPU-friendly shape.
- kv_lens: padded kv lengths. Only the first num_seqs values are valid.
- page_indices: flattened page indices look-up table by (seq_id, page_id).
- cu_q_lens: the cumulative sum of the effective query lengths. Similar to
- kv_lens, only the first num_seqs+1 values are valid.
- distribution: (i, j, k) represents that sequences[0:i] are decode-only,
- sequences[i:j] are chunked-prefill-only, and sequences[j:k] are mixed. The
- k is also the total number of sequences.
- attention_sink: optional attention sink for each q head.
- actual_head_dim: the actual head size of the attention. Here we assume k and
- v have the same actual head size.
- sm_scale: the softmax scale which will be applied to the Q@K^T.
- sliding_window: the sliding window size for the attention.
- soft_cap: the logit soft cap for the attention.
- mask_value: mask value for causal mask.
- k_scale: the scale for the key cache.
- v_scale: the scale for the value cache.
- num_kv_pages_per_block: number of kv pages to be processed in one flash
- attention block in the pallas kernel.
- num_queries_per_block: number of kv pages to be processed in one flash
- attention block in the pallas kernel.
- vmem_limit_bytes: the vmem limit for the pallas kernel.
- debug_mode: if true, RPA does not issue any DMAs or run flash attention but
- print debug info. Need to compile with `--xla_tpu_enable_log_recorder`.
-
- Returns:
- The output of the attention.
- """
+ """A variant of ragged paged attention for head_dim=64.
+
+ Args:
+ queries: concatenated all sequences' queries.
+ keys: concatenated all sequences' keys (quantized).
+ values: concatenated all sequences' values (quantized).
+ kv_cache: paged KV cache with TPU-friendly shape.
+ kv_lens: padded kv lengths. Only the first num_seqs values are valid.
+ page_indices: flattened page indices look-up table by (seq_id, page_id).
+ cu_q_lens: the cumulative sum of the effective query lengths. Similar to
+ kv_lens, only the first num_seqs+1 values are valid.
+ distribution: (i, j, k) represents that sequences[0:i] are decode-only,
+ sequences[i:j] are chunked-prefill-only, and sequences[j:k] are mixed. The
+ k is also the total number of sequences.
+ attention_sink: optional attention sink for each q head.
+ sm_scale: the softmax scale which will be applied to the Q@K^T.
+ sliding_window: the sliding window size for the attention.
+ strict_sliding_window: compute tokens that are strictly within the window.
+ soft_cap: the logit soft cap for the attention.
+ mask_value: mask value for causal mask.
+ q_scale: the scale for the query.
+ k_scale: the scale for the key cache.
+ v_scale: the scale for the value cache.
+ chunk_prefill_size: the chunk prefill size for the attention.
+ num_kv_pages_per_block: number of kv pages to be processed in one flash
+ attention block in the pallas kernel.
+ num_queries_per_block: number of kv pages to be processed in one flash
+ attention block in the pallas kernel.
+ vmem_limit_bytes: the vmem limit for the pallas kernel.
+ debug_mode: if true, RPA does not issue any DMAs or run flash attention but
+ print debug info. Need to compile with `--xla_tpu_enable_log_recorder`.
+
+ Returns:
+ The output of the attention.
+ """
  q, k, v = queries, keys, values
  static_validate_inputs(
  q,
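
A concrete (hypothetical) value for the `distribution` argument described in the docstring above:

# Hypothetical: 7 sequences in total.
distribution = (2, 5, 7)
# sequences[0:2] are decode-only, sequences[2:5] are chunked-prefill-only,
# sequences[5:7] are mixed, and distribution[2] is the total sequence count.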
@@ -1397,7 +1457,7 @@ def ragged_paged_attention_hd64(
  pl.BlockSpec(memory_space=pltpu.HBM),
  pl.BlockSpec(memory_space=pltpu.HBM),
  None if attention_sink is None else pl.BlockSpec(
- memory_space=pltpu.VMEM)
+ memory_space=pltpu.VMEM),
  ]

  out_specs = [
@@ -1461,6 +1521,7 @@ def ragged_paged_attention_hd64(
  _ragged_paged_attention_kernel,
  sm_scale=sm_scale,
  sliding_window=sliding_window,
+ strict_sliding_window=strict_sliding_window,
  soft_cap=soft_cap,
  mask_value=mask_value,
  q_scale=q_scale,
@@ -295,7 +295,8 @@ def get_tuned_block_sizes(
  bkv_p, bq = TUNED_BLOCK_SIZES[device][page_size][dtypes][head_dims][
  max_model_len]
  except KeyError:
- print('Couldn`t find tuned sizes for the RPA v3 kernel with %s', keys)
+ logger.warning_once(
+ 'Couldn`t find tuned sizes for the RPA v3 kernel with %s', keys)

  return (min(pages_per_seq, bkv_p), min(max_num_tokens, bq))

@@ -13,7 +13,8 @@ def align_to(x, a):


  def get_dtype_bitwidth(dtype):
- return dtypes.bit_width(dtype)
+ return (dtypes.bit_width(dtype)
+ if hasattr(dtypes, "bit_width") else dtypes.itemsize_bits(dtype))


  def get_dtype_packing(dtype):
@@ -308,7 +308,13 @@ def sharded_ragged_paged_attention(
  args = (q, k, v, kv_cache, kv_lens, page_indices, cu_q_lens, distribution)

  use_hd64 = q.shape[-1] == 64
- func = ragged_paged_attention_hd64 if use_hd64 else ragged_paged_attention
+
+ func = ragged_paged_attention
+ if use_hd64:
+ func = functools.partial(ragged_paged_attention_hd64,
+ strict_sliding_window=True)
+ else:
+ func = ragged_paged_attention

  if attention_sink is not None:
  if not use_hd64:
@@ -1,6 +1,5 @@
  import json
  import math
- import os
  from dataclasses import asdict, dataclass
  from typing import TYPE_CHECKING, List, Optional

@@ -8,7 +7,7 @@ import jax.numpy as jnp
  import numpy as np
  from jax.sharding import Mesh

- from tpu_inference import utils
+ from tpu_inference import envs, utils

  if TYPE_CHECKING:
  from vllm.v1.configs.vllm_config import VllmConfig
@@ -48,7 +47,7 @@ class ShardingAxisName2D:


  try:
- _use_base_sharding = os.getenv("NEW_MODEL_DESIGN", False)
+ _use_base_sharding = envs.NEW_MODEL_DESIGN
  if _use_base_sharding:
  ShardingAxisName = ShardingAxisNameBase
  else:
@@ -120,9 +119,13 @@ class ShardingConfigManager:
  False)
  if enable_dp_attention:
  # Replicate attention layer when num_kv_heads < TP
- num_kv_heads = vllm_config.model_config.get_total_num_kv_heads()
+ num_kv_heads = 1 if vllm_config.model_config.use_mla else vllm_config.model_config.get_total_num_kv_heads(
+ )
+ cache_dtype = vllm_config.cache_config.cache_dtype
+ if cache_dtype == 'auto':
+ cache_dtype = vllm_config.model_config.dtype
  kv_dtype = utils.get_jax_dtype_from_str_dtype(
- vllm_config.cache_config.cache_dtype) or jnp.bfloat16
+ cache_dtype) or jnp.bfloat16
  packing = 4 // jnp.dtype(kv_dtype).itemsize
  # When num_kv_heads * 2 / packing < TP, tensor parallelism would
  # duplicate KV heads across devices, wasting kv cache memory.
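
For context on the packing computation above: `packing` is how many KV elements fit in a 32-bit word once cache_dtype has been resolved from 'auto' to the model dtype. A quick check with two common cache dtypes (assuming a JAX version that exposes float8_e4m3fn):

import jax.numpy as jnp

for kv_dtype in (jnp.bfloat16, jnp.float8_e4m3fn):
    packing = 4 // jnp.dtype(kv_dtype).itemsize
    print(jnp.dtype(kv_dtype).name, packing)
# bfloat16 -> 2, float8_e4m3fn -> 4

A larger packing lowers num_kv_heads * 2 / packing, which makes the duplication condition in the comment above easier to hit.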
@@ -166,9 +169,10 @@ class ShardingConfigManager:
  f"LoRA is not supported with data parallelism "
  f"(DP size: {total_dp_size}). Please disable LoRA or "
  f"set data parallelism to 1.")
- if not os.environ.get("NEW_MODEL_DESIGN", False):
+ if sharding_strategy.attention_data_parallelism > 1:
+ if not envs.NEW_MODEL_DESIGN:
  raise ValueError(
- "Must run DP with NEW_MODEL_DESIGN enabled. Please set the "
+ "Must run Attention DP with NEW_MODEL_DESIGN enabled. Please set the "
  "NEW_MODEL_DESIGN=True.")

  @property