tpu-inference 0.11.1.dev202511180814__py3-none-any.whl → 0.12.0.dev20251213__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (76)
  1. tests/kernels/fused_moe_v1_test.py +303 -34
  2. tests/kernels/mla_v1_test.py +129 -41
  3. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  4. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +3 -1
  5. tests/kernels/ragged_paged_attention_kernel_v3_test.py +3 -1
  6. tests/lora/test_layers.py +4 -7
  7. tests/lora/test_lora_perf.py +53 -0
  8. tests/lora/utils.py +0 -8
  9. tests/test_envs.py +110 -12
  10. tests/test_quantization.py +3 -0
  11. tests/test_utils.py +1 -2
  12. tpu_inference/__init__.py +22 -3
  13. tpu_inference/core/disagg_utils.py +6 -8
  14. tpu_inference/distributed/tpu_connector.py +3 -4
  15. tpu_inference/distributed/utils.py +3 -2
  16. tpu_inference/envs.py +93 -9
  17. tpu_inference/executors/ray_distributed_executor.py +9 -2
  18. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  19. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  20. tpu_inference/kernels/fused_moe/v1/kernel.py +712 -143
  21. tpu_inference/kernels/mla/v1/kernel.py +98 -120
  22. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  23. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  24. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  25. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +140 -67
  26. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +204 -120
  27. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +2 -1
  28. tpu_inference/kernels/ragged_paged_attention/v3/util.py +2 -1
  29. tpu_inference/layers/common/attention_interface.py +7 -1
  30. tpu_inference/layers/common/sharding.py +11 -7
  31. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +232 -64
  32. tpu_inference/layers/jax/attention/gpt_oss_attention.py +5 -5
  33. tpu_inference/layers/vllm/fused_moe.py +170 -208
  34. tpu_inference/layers/vllm/linear_common.py +43 -21
  35. tpu_inference/layers/vllm/quantization/common.py +11 -6
  36. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -3
  37. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +74 -65
  38. tpu_inference/layers/vllm/quantization/mxfp4.py +140 -94
  39. tpu_inference/layers/vllm/quantization/unquantized.py +103 -80
  40. tpu_inference/layers/vllm/sharding.py +2 -2
  41. tpu_inference/lora/torch_punica_tpu.py +1 -2
  42. tpu_inference/models/common/model_loader.py +84 -28
  43. tpu_inference/models/jax/deepseek_v3.py +185 -64
  44. tpu_inference/models/jax/gpt_oss.py +3 -3
  45. tpu_inference/models/jax/llama3.py +2 -1
  46. tpu_inference/models/jax/llama_eagle3.py +8 -5
  47. tpu_inference/models/jax/llama_guard_4.py +361 -0
  48. tpu_inference/models/jax/qwen2.py +2 -1
  49. tpu_inference/models/jax/qwen2_5_vl.py +163 -48
  50. tpu_inference/models/jax/qwen3.py +2 -1
  51. tpu_inference/models/jax/utils/quantization/quantization_utils.py +7 -8
  52. tpu_inference/models/jax/utils/weight_utils.py +205 -144
  53. tpu_inference/models/vllm/vllm_model_wrapper.py +14 -8
  54. tpu_inference/platforms/tpu_platform.py +34 -50
  55. tpu_inference/runner/compilation_manager.py +144 -60
  56. tpu_inference/runner/kv_cache.py +40 -20
  57. tpu_inference/runner/kv_cache_manager.py +48 -33
  58. tpu_inference/runner/persistent_batch_manager.py +40 -2
  59. tpu_inference/runner/structured_decoding_manager.py +2 -3
  60. tpu_inference/runner/tpu_runner.py +280 -149
  61. tpu_inference/runner/utils.py +2 -2
  62. tpu_inference/spec_decode/jax/eagle3.py +71 -21
  63. tpu_inference/tpu_info.py +4 -3
  64. tpu_inference/utils.py +46 -18
  65. tpu_inference/worker/tpu_worker.py +197 -63
  66. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/METADATA +9 -10
  67. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/RECORD +70 -74
  68. tpu_inference/mock/__init__.py +0 -0
  69. tpu_inference/mock/vllm_config_utils.py +0 -28
  70. tpu_inference/mock/vllm_envs.py +0 -1219
  71. tpu_inference/mock/vllm_logger.py +0 -212
  72. tpu_inference/mock/vllm_logging_utils.py +0 -15
  73. tpu_inference/models/jax/phi3.py +0 -376
  74. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/WHEEL +0 -0
  75. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/licenses/LICENSE +0 -0
  76. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/top_level.txt +0 -0
@@ -267,6 +267,7 @@ def _ragged_paged_attention_kernel(
  *,
  sm_scale: float,
  sliding_window: int | None = None,
+ strict_sliding_window: bool = True,
  soft_cap: float | None = None,
  mask_value: float = DEFAULT_MASK_VALUE,
  q_scale: float | None = None,
@@ -317,19 +318,20 @@ def _ragged_paged_attention_kernel(
  q_len = q_end - q_start
  kv_len = kv_lens_ref[seq_idx]

- bkv_idx_start = 0 if sliding_window is None else jnp.maximum(
- kv_len - sliding_window, 0) // bkv_sz
-
  if sliding_window is None:
- next_bkv_idx_start = 0
+ bkv_idx_start = next_seq_bkv_idx_start = 0
  else:
+ bkv_idx_start = jnp.maximum(kv_len - q_len - sliding_window,
+ 0) // bkv_sz

  def get_next_bkv_idx_start():
  next_kv_len = kv_lens_ref[seq_idx + 1]
- return jnp.maximum(next_kv_len - sliding_window, 0) // bkv_sz
+ next_q_len = cu_q_lens_ref[seq_idx + 2] - q_end
+ return jnp.maximum(next_kv_len - next_q_len - sliding_window,
+ 0) // bkv_sz

- next_bkv_idx_start = lax.cond(seq_idx + 1 < num_seqs,
- get_next_bkv_idx_start, lambda: 0)
+ next_seq_bkv_idx_start = lax.cond(seq_idx + 1 < num_seqs,
+ get_next_bkv_idx_start, lambda: 0)

  def debug_print(msg, *args):
  if debug_mode:
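
The reworked start index anchors the sliding window at the first query token of the sequence (kv_len - q_len) instead of at kv_len. A quick arithmetic sketch with made-up sizes shows why the q_len term matters:

    # Hypothetical sizes, not taken from any tuned config.
    kv_len, q_len, sliding_window, bkv_sz = 1000, 100, 128, 64

    first_q_pos = kv_len - q_len                                   # 900
    old_start = max(kv_len - sliding_window, 0) // bkv_sz          # 872 // 64 = 13
    new_start = max(kv_len - q_len - sliding_window, 0) // bkv_sz  # 772 // 64 = 12

    # The earliest query token sits at position 900 and, with a strict window,
    # still attends to keys at positions 773..900. Position 773 lives in KV
    # block 12, so iteration has to start there rather than at block 13.
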
@@ -350,7 +352,7 @@ def _ragged_paged_attention_kernel(
  debug_print("[RPA debug] q_len={}", q_len)
  debug_print("[RPA debug] kv_len={}", kv_len)

- def flash_attention(
+ def flash_attention_step1_qk_softmax(
  q, # [actual_bq_sz * num_q_heads_per_kv_head, actual_head_dim_x2]
  kv, # [bkv_sz, actual_head_dim_x2]
  *,
@@ -364,7 +366,6 @@ def _ragged_paged_attention_kernel(
  assert kv.shape == (bkv_sz, actual_head_dim_x2)
  head_l_ref = l_ref.at[kv_head_idx, :q.shape[0]]
  head_m_ref = m_ref.at[kv_head_idx, :q.shape[0]]
- head_acc_ref = acc_ref.at[kv_head_idx, :q.shape[0]]

  def load_with_init(ref, init_val):
  return jnp.where(bkv_idx == bkv_idx_start,
@@ -386,16 +387,19 @@ def _ragged_paged_attention_kernel(
  s *= k_scale
  if q_scale is not None:
  s *= q_scale
+ if soft_cap is not None:
+ s = soft_cap * jnp.tanh(s / soft_cap)

  q_span = (kv_len - q_len + bq_idx * bq_sz +
  lax.broadcasted_iota(jnp.int32, s.shape, 0) //
  num_q_heads_per_kv_head)
  k_span = bkv_idx * bkv_sz + lax.broadcasted_iota(jnp.int32, s.shape, 1)
- mask = q_span < k_span
+ mask = k_span <= q_span

- if soft_cap is not None:
- s = soft_cap * jnp.tanh(s / soft_cap)
- s += jnp.where(mask, mask_value, 0.0)
+ if sliding_window is not None and strict_sliding_window:
+ mask = jnp.logical_and(mask, q_span - sliding_window < k_span)
+
+ s = jnp.where(mask, s, mask_value)
  s_rowmax = jnp.max(s, axis=1, keepdims=True)

  if attention_sink_ref is not None:
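
For reference, the combined causal and strict-sliding-window mask built above can be reproduced with plain jax.numpy. The toy sizes below are invented and assume one query row per token (no num_q_heads_per_kv_head packing):

    import jax.numpy as jnp

    kv_len, q_len, sliding_window = 8, 4, 3
    q_span = kv_len - q_len + jnp.arange(q_len)[:, None]  # absolute query positions
    k_span = jnp.arange(kv_len)[None, :]                  # absolute key positions

    mask = k_span <= q_span                                          # causal
    mask = jnp.logical_and(mask, q_span - sliding_window < k_span)   # strict window
    print(mask.astype(jnp.int32))
    # [[0 0 1 1 1 0 0 0]
    #  [0 0 0 1 1 1 0 0]
    #  [0 0 0 0 1 1 1 0]
    #  [0 0 0 0 0 1 1 1]]

Note that the kernel now keeps valid scores and writes mask_value into the masked positions (s = jnp.where(mask, s, mask_value)), whereas the old code added mask_value to the positions to be discarded.
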
@@ -411,15 +415,33 @@ def _ragged_paged_attention_kernel(
  head_m_ref[...] = m_curr
  p = jnp.exp(s - broadcast_minor(m_curr, s.shape))

- pv = jnp.einsum("nm,md->nd", p, kv, preferred_element_type=jnp.float32)
- if v_scale is not None:
- pv *= v_scale
-
  p_rowsum = jnp.sum(p, axis=1, keepdims=True)
  exp_m_diff = jnp.exp(m_prev - m_curr)
  l_prev = load_with_init(head_l_ref, 1.0)
  l_curr = exp_m_diff * l_prev + p_rowsum
  head_l_ref[...] = l_curr
+
+ return p, exp_m_diff
+
+ def flash_attention_step2_pv(
+ q_shape_0,
+ kv, # [bkv_sz, actual_head_dim_x2]
+ p, # from step1
+ exp_m_diff, # from step1
+ *,
+ bkv_idx,
+ kv_head_idx,
+ ):
+ head_acc_ref = acc_ref.at[kv_head_idx, :q_shape_0]
+
+ def load_with_init(ref, init_val):
+ return jnp.where(bkv_idx == bkv_idx_start,
+ jnp.full_like(ref, init_val), ref[...])
+
+ pv = jnp.einsum("nm,md->nd", p, kv, preferred_element_type=jnp.float32)
+ if v_scale is not None:
+ pv *= v_scale
+
  o_prev = load_with_init(head_acc_ref, 0.0)
  o_curr = broadcast_minor(exp_m_diff, o_prev.shape) * o_prev + pv
  head_acc_ref[...] = o_curr
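
The split leaves the standard online-softmax recurrence unchanged: step1 advances the running max m and denominator l, step2 rescales the accumulator and adds p@v. A plain-JAX reference (illustrative only: separate K/V blocks instead of the kernel's packed KV buffer, and no scaling, masking or attention sinks):

    import jax.numpy as jnp

    def online_softmax_attention(q, k_blocks, v_blocks):
        # q: [n, d]; k_blocks / v_blocks: same-length lists of [b, d] blocks.
        m = jnp.full((q.shape[0], 1), -jnp.inf)               # running row max (m_ref)
        l = jnp.zeros((q.shape[0], 1))                        # running denominator (l_ref)
        acc = jnp.zeros((q.shape[0], v_blocks[0].shape[-1]))  # running output (acc_ref)
        for k, v in zip(k_blocks, v_blocks):
            s = q @ k.T                                        # step1: scores
            m_curr = jnp.maximum(m, s.max(axis=1, keepdims=True))
            p = jnp.exp(s - m_curr)
            exp_m_diff = jnp.exp(m - m_curr)
            l = exp_m_diff * l + p.sum(axis=1, keepdims=True)
            m = m_curr
            acc = exp_m_diff * acc + p @ v                     # step2: pv + rescale
        return acc / l
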
@@ -475,42 +497,51 @@ def _ragged_paged_attention_kernel(
  debug_print("[RPA debug] bkv_sz_frm_new={}", bkv_sz_frm_new)
  debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)

- # Fetch effective kv from kv cache.
- def loop_body(i, offset):
- sz = jnp.minimum(page_size, kv_left_frm_cache - i * page_size)
- _async_copy(
- cache_hbm_ref.at[pl.ds(
- page_indices_ref[page_indices_offset + i] * page_size,
- sz)],
- vmem_ref.at[pl.ds(i * page_size, sz)],
- sem,
- wait,
+ if not wait:
+ # Fetch effective kv from kv cache.
+ def loop_body(i, offset):
+ sz = jnp.minimum(page_size, kv_left_frm_cache - i * page_size)
+ _async_copy(
+ cache_hbm_ref.at[pl.ds(
+ page_indices_ref[page_indices_offset + i] * page_size,
+ sz)],
+ vmem_ref.at[pl.ds(i * page_size, sz)],
+ sem,
+ wait=False,
+ )
+ debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
+ return offset + sz
+
+ offset = lax.fori_loop(
+ 0,
+ bkv_p_frm_cache,
+ loop_body,
+ 0, # offset
+ unroll=False,
  )
- debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
- return offset + sz
-
- offset = lax.fori_loop(
- 0,
- bkv_p_frm_cache,
- loop_body,
- 0, # offset
- unroll=False,
- )

- # Fetch kv directly from new kv.
- @pl.when(bkv_sz_frm_new > 0)
- def _fetch_bkv_from_new_kv():
+ size = lax.select(bkv_sz_frm_new > 0, bkv_sz_frm_new, 0)
  new_kv_len_start = q_end - kv_left_frm_new
  debug_print("[RPA debug] new_kv_len_start={}", new_kv_len_start)
  debug_print("[RPA debug] offset_in_bkv={}", offset)
  _async_copy(
- kv_hbm_ref.at[pl.ds(new_kv_len_start, bkv_sz_frm_new)],
- vmem_ref.at[pl.ds(offset, bkv_sz_frm_new)],
+ kv_hbm_ref.at[pl.ds(new_kv_len_start, size)],
+ vmem_ref.at[pl.ds(offset, size)],
  sem,
  wait,
  )

- return kv_len_start + offset, bkv_sz_frm_new
+ return kv_len_start + offset, bkv_sz_frm_new
+ else:
+ offset = jnp.minimum(kv_left_frm_cache, page_size * bkv_p)
+ dst = vmem_ref.at[pl.ds(0, offset + bkv_sz_frm_new)]
+ _async_copy(
+ src=dst,
+ dst=dst,
+ sem=sem,
+ wait=True,
+ )
+ return kv_len_start + offset, bkv_sz_frm_new

  def _update_kv_cache(seq_idx,
  bkv_sem_idx,
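
Both branches above go through the kernel's _async_copy helper, which is not shown in this hunk. A minimal sketch of how such a helper is commonly written on top of the Pallas TPU API, assuming it simply switches between starting a DMA and waiting on its semaphore:

    from jax.experimental.pallas import tpu as pltpu

    def _async_copy(src, dst, sem, wait):
        # Assumed behaviour: wait=False issues the copy, wait=True blocks until
        # the semaphore has received all bytes covered by (src, dst).
        copy = pltpu.make_async_copy(src, dst, sem)
        if wait:
            copy.wait()
        else:
            copy.start()

Under that reading, the wait path in the diff only needs one descriptor spanning the whole destination region, because the semaphore accounts for every byte previously issued into it.
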
@@ -546,30 +577,41 @@ def _ragged_paged_attention_kernel(
  debug_print("[RPA debug] p_ignore={}", p_ignore)
  debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)

- def loop_body(i, states):
- update_sz, ignore = states
- sz = jnp.minimum(page_size - ignore, update_sz)
-
+ if not wait:
+
+ def loop_body(i, states):
+ update_sz, ignore = states
+ sz = jnp.minimum(page_size - ignore, update_sz)
+
+ _async_copy(
+ vmem_ref.at[pl.ds((p_ignore + i) * page_size + ignore,
+ sz)],
+ cache_hbm_ref.at[pl.ds(
+ page_indices_ref[page_indices_offset + i] * page_size +
+ ignore,
+ sz,
+ )],
+ sem,
+ wait=False,
+ )
+ debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
+ return update_sz - sz, 0
+
+ lax.fori_loop(
+ 0,
+ kv_p_end - kv_p_start,
+ loop_body,
+ (update_sz, ignore), # total transfer size
+ unroll=False,
+ )
+ else:
+ dst = cache_hbm_ref.at[pl.ds(0, update_sz)]
  _async_copy(
- vmem_ref.at[pl.ds((p_ignore + i) * page_size + ignore, sz)],
- cache_hbm_ref.at[pl.ds(
- page_indices_ref[page_indices_offset + i] * page_size +
- ignore,
- sz,
- )],
- sem,
- wait,
+ src=dst,
+ dst=dst,
+ sem=sem,
+ wait=True,
  )
- debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
- return update_sz - sz, 0
-
- lax.fori_loop(
- 0,
- kv_p_end - kv_p_start,
- loop_body,
- (update_sz, ignore), # total transfer size
- unroll=False,
- )

  def _fetch_bq(seq_idx, bq_idx, bq_sem_idx, *, wait=False):
  sem = sems.at[1, bq_sem_idx]
@@ -737,13 +779,17 @@ def _ragged_paged_attention_kernel(
  next_seq_idx = lax.select(is_last_bq, seq_idx + 1, seq_idx)
  next_bkv_sem_idx = lax.select(bkv_sem_idx == 0, 1, 0)

- next_bkv_idx = lax.select(
- is_last_bkv,
- lax.select(
+ if sliding_window is None:
+ next_bkv_start_idx = 0
+ else:
+ next_bkv_start_idx = lax.select(
  is_last_bq,
- next_bkv_idx_start,
+ next_seq_bkv_idx_start,
  bkv_idx_start,
- ), next_bkv_idx)
+ )
+ next_bkv_idx = lax.select(is_last_bkv, next_bkv_start_idx,
+ next_bkv_idx)
+
  return next_seq_idx, next_bq_idx, next_bkv_idx, next_bkv_sem_idx

  def compute_with_bq(bq_idx, _):
@@ -803,6 +849,11 @@ def _ragged_paged_attention_kernel(
  return

  # Flash attention with cur bkv and bq
+ prev_bq_shape_0 = None
+ prev_kv_head_bkv = None
+ prev_kv_head_idx = None
+ prev_kv_head_p = None
+ prev_kv_head_exp_m_diff = None
  for kv_head_start in range(0, actual_num_kv_heads, kv_packing):
  bkv_lst = strided_load_bkv(
  bkv_sem_idx,
@@ -812,20 +863,51 @@ def _ragged_paged_attention_kernel(
  )
  assert len(bkv_lst) == kv_packing
  for i in range(kv_packing):
- kv_head_idx = kv_head_start + i
- if kv_head_idx >= actual_num_kv_heads:
+ cur_kv_head_idx = kv_head_start + i
+ if cur_kv_head_idx >= actual_num_kv_heads:
  break
- bq = load_bq(bq_sem_idx,
- kv_head_idx,
- actual_bq_sz=actual_bq_sz)
- bkv = bkv_lst[i]
- flash_attention(
- bq,
- bkv,
- bq_idx=bq_idx,
- bkv_idx=bkv_idx,
- kv_head_idx=kv_head_idx,
- )
+ cur_kv_head_bq = load_bq(bq_sem_idx,
+ cur_kv_head_idx,
+ actual_bq_sz=actual_bq_sz)
+ cur_kv_head__bkv = bkv_lst[i]
+ # FlashAttention is divided into `flash_attention_step1_qk_softmax`
+ # and `flash_attention_step2_pv` to pipeline the computation.
+ # `step2_pv` for the previous KV head, which depends on the softmax
+ # output, is overlapped with `step1_qk_softmax` for the current KV
+ # head, reducing overall wait times.
+ cur_kv_head_p, cur_kv_head_exp_m_diff = (
+ flash_attention_step1_qk_softmax(
+ cur_kv_head_bq,
+ cur_kv_head__bkv,
+ bq_idx=bq_idx,
+ bkv_idx=bkv_idx,
+ kv_head_idx=cur_kv_head_idx,
+ ))
+ if prev_bq_shape_0 is not None:
+ flash_attention_step2_pv(
+ prev_bq_shape_0,
+ prev_kv_head_bkv,
+ prev_kv_head_p,
+ prev_kv_head_exp_m_diff,
+ bkv_idx=bkv_idx,
+ kv_head_idx=prev_kv_head_idx,
+ )
+ prev_bq_shape_0 = cur_kv_head_bq.shape[0]
+ prev_kv_head_bkv = cur_kv_head__bkv
+ prev_kv_head_p = cur_kv_head_p
+ prev_kv_head_exp_m_diff = cur_kv_head_exp_m_diff
+ prev_kv_head_idx = cur_kv_head_idx
+
+ # Execute pv of last attention head.
+ assert prev_bq_shape_0 is not None
+ flash_attention_step2_pv(
+ prev_bq_shape_0,
+ prev_kv_head_bkv,
+ prev_kv_head_p,
+ prev_kv_head_exp_m_diff,
+ bkv_idx=bkv_idx,
+ kv_head_idx=prev_kv_head_idx,
+ )

  lax.fori_loop(bkv_idx_start,
  num_bkv,
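
The comment in the hunk describes a two-stage software pipeline across KV heads: step1 (QK plus softmax) is issued for the current head while step2 (PV) is executed for the previous one, with a final drain for the last head. Stripped of the kernel details, the control flow reduces to the sketch below (step1 and step2 are stand-in callables, not the kernel functions):

    def pipelined_over_heads(heads, step1, step2):
        # Run step1 per head, overlapping step2 of the previous head.
        prev = None
        for h in heads:
            cur = (h, step1(h))   # issue step1 for the current head
            if prev is not None:
                step2(*prev)      # finish step2 for the previous head
            prev = cur
        assert prev is not None
        step2(*prev)              # drain: step2 for the last head
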
@@ -1226,6 +1308,7 @@ def static_validate_inputs(
  static_argnames=(
  "sm_scale",
  "sliding_window",
+ "strict_sliding_window",
  "soft_cap",
  "mask_value",
  "q_scale",
@@ -1255,6 +1338,7 @@ def ragged_paged_attention_hd64(
  *,
  sm_scale: float = 1.0,
  sliding_window: int | None = None,
+ strict_sliding_window: bool = True,
  soft_cap: float | None = None,
  mask_value: float | None = DEFAULT_MASK_VALUE,
  q_scale: float | None = None,
@@ -1269,42 +1353,41 @@ def ragged_paged_attention_hd64(
  # Debug params.
  debug_mode: bool = False,
  ):
- """A special Ragged paged attention version for head_dim=64 that supports mixed
-
- prefill and decode.
-
- Args:
- queries: concatenated all sequences' queries.
- keys: concatenated all sequences' keys (quantized).
- values: concatenated all sequences' values (quantized).
- kv_cache: paged KV cache with TPU-friendly shape.
- kv_lens: padded kv lengths. Only the first num_seqs values are valid.
- page_indices: flattened page indices look-up table by (seq_id, page_id).
- cu_q_lens: the cumulative sum of the effective query lengths. Similar to
- kv_lens, only the first num_seqs+1 values are valid.
- distribution: (i, j, k) represents that sequences[0:i] are decode-only,
- sequences[i:j] are chunked-prefill-only, and sequences[j:k] are mixed. The
- k is also the total number of sequences.
- attention_sink: optional attention sink for each q head.
- actual_head_dim: the actual head size of the attention. Here we assume k and
- v have the same actual head size.
- sm_scale: the softmax scale which will be applied to the Q@K^T.
- sliding_window: the sliding window size for the attention.
- soft_cap: the logit soft cap for the attention.
- mask_value: mask value for causal mask.
- k_scale: the scale for the key cache.
- v_scale: the scale for the value cache.
- num_kv_pages_per_block: number of kv pages to be processed in one flash
- attention block in the pallas kernel.
- num_queries_per_block: number of kv pages to be processed in one flash
- attention block in the pallas kernel.
- vmem_limit_bytes: the vmem limit for the pallas kernel.
- debug_mode: if true, RPA does not issue any DMAs or run flash attention but
- print debug info. Need to compile with `--xla_tpu_enable_log_recorder`.
-
- Returns:
- The output of the attention.
- """
+ """A variant of ragged paged attention for head_dim=64.
+
+ Args:
+ queries: concatenated all sequences' queries.
+ keys: concatenated all sequences' keys (quantized).
+ values: concatenated all sequences' values (quantized).
+ kv_cache: paged KV cache with TPU-friendly shape.
+ kv_lens: padded kv lengths. Only the first num_seqs values are valid.
+ page_indices: flattened page indices look-up table by (seq_id, page_id).
+ cu_q_lens: the cumulative sum of the effective query lengths. Similar to
+ kv_lens, only the first num_seqs+1 values are valid.
+ distribution: (i, j, k) represents that sequences[0:i] are decode-only,
+ sequences[i:j] are chunked-prefill-only, and sequences[j:k] are mixed. The
+ k is also the total number of sequences.
+ attention_sink: optional attention sink for each q head.
+ sm_scale: the softmax scale which will be applied to the Q@K^T.
+ sliding_window: the sliding window size for the attention.
+ strict_sliding_window: compute tokens that are strictly within the window.
+ soft_cap: the logit soft cap for the attention.
+ mask_value: mask value for causal mask.
+ q_scale: the scale for the query.
+ k_scale: the scale for the key cache.
+ v_scale: the scale for the value cache.
+ chunk_prefill_size: the chunk prefill size for the attention.
+ num_kv_pages_per_block: number of kv pages to be processed in one flash
+ attention block in the pallas kernel.
+ num_queries_per_block: number of kv pages to be processed in one flash
+ attention block in the pallas kernel.
+ vmem_limit_bytes: the vmem limit for the pallas kernel.
+ debug_mode: if true, RPA does not issue any DMAs or run flash attention but
+ print debug info. Need to compile with `--xla_tpu_enable_log_recorder`.
+
+ Returns:
+ The output of the attention.
+ """
  q, k, v = queries, keys, values
  static_validate_inputs(
  q,
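
To make the docstring's metadata concrete, here is a toy batch (all values invented) laid out the way kv_lens, cu_q_lens and distribution describe it:

    import jax.numpy as jnp

    # Two decode-only sequences (1 new token each) followed by one
    # chunked-prefill sequence with 128 new tokens; no mixed sequences.
    q_lens = [1, 1, 128]
    cu_q_lens = jnp.array([0, 1, 2, 130])  # cumulative q lengths; first num_seqs+1 entries valid
    kv_lens = jnp.array([517, 42, 128])    # total kv length per sequence; first num_seqs entries valid
    distribution = (2, 3, 3)               # seqs[0:2] decode-only, seqs[2:3] prefill-only, none mixed
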
@@ -1374,7 +1457,7 @@ def ragged_paged_attention_hd64(
  pl.BlockSpec(memory_space=pltpu.HBM),
  pl.BlockSpec(memory_space=pltpu.HBM),
  None if attention_sink is None else pl.BlockSpec(
- memory_space=pltpu.VMEM)
+ memory_space=pltpu.VMEM),
  ]

  out_specs = [
@@ -1438,6 +1521,7 @@ def ragged_paged_attention_hd64(
  _ragged_paged_attention_kernel,
  sm_scale=sm_scale,
  sliding_window=sliding_window,
+ strict_sliding_window=strict_sliding_window,
  soft_cap=soft_cap,
  mask_value=mask_value,
  q_scale=q_scale,
@@ -295,7 +295,8 @@ def get_tuned_block_sizes(
  bkv_p, bq = TUNED_BLOCK_SIZES[device][page_size][dtypes][head_dims][
  max_model_len]
  except KeyError:
- print('Couldn`t find tuned sizes for the RPA v3 kernel with %s', keys)
+ logger.warning_once(
+ 'Couldn`t find tuned sizes for the RPA v3 kernel with %s', keys)

  return (min(pages_per_seq, bkv_p), min(max_num_tokens, bq))

@@ -13,7 +13,8 @@ def align_to(x, a):


  def get_dtype_bitwidth(dtype):
- return dtypes.bit_width(dtype)
+ return (dtypes.bit_width(dtype)
+ if hasattr(dtypes, "bit_width") else dtypes.itemsize_bits(dtype))


  def get_dtype_packing(dtype):
@@ -308,7 +308,13 @@ def sharded_ragged_paged_attention(
  args = (q, k, v, kv_cache, kv_lens, page_indices, cu_q_lens, distribution)

  use_hd64 = q.shape[-1] == 64
- func = ragged_paged_attention_hd64 if use_hd64 else ragged_paged_attention
+
+ func = ragged_paged_attention
+ if use_hd64:
+ func = functools.partial(ragged_paged_attention_hd64,
+ strict_sliding_window=True)
+ else:
+ func = ragged_paged_attention

  if attention_sink is not None:
  if not use_hd64:
@@ -1,6 +1,5 @@
  import json
  import math
- import os
  from dataclasses import asdict, dataclass
  from typing import TYPE_CHECKING, List, Optional

@@ -8,7 +7,7 @@ import jax.numpy as jnp
  import numpy as np
  from jax.sharding import Mesh

- from tpu_inference import utils
+ from tpu_inference import envs, utils

  if TYPE_CHECKING:
  from vllm.v1.configs.vllm_config import VllmConfig
@@ -48,7 +47,7 @@ class ShardingAxisName2D:


  try:
- _use_base_sharding = os.getenv("NEW_MODEL_DESIGN", False)
+ _use_base_sharding = envs.NEW_MODEL_DESIGN
  if _use_base_sharding:
  ShardingAxisName = ShardingAxisNameBase
  else:
@@ -120,9 +119,13 @@ class ShardingConfigManager:
  False)
  if enable_dp_attention:
  # Replicate attention layer when num_kv_heads < TP
- num_kv_heads = vllm_config.model_config.get_total_num_kv_heads()
+ num_kv_heads = 1 if vllm_config.model_config.use_mla else vllm_config.model_config.get_total_num_kv_heads(
+ )
+ cache_dtype = vllm_config.cache_config.cache_dtype
+ if cache_dtype == 'auto':
+ cache_dtype = vllm_config.model_config.dtype
  kv_dtype = utils.get_jax_dtype_from_str_dtype(
- vllm_config.cache_config.cache_dtype) or jnp.bfloat16
+ cache_dtype) or jnp.bfloat16
  packing = 4 // jnp.dtype(kv_dtype).itemsize
  # When num_kv_heads * 2 / packing < TP, tensor parallelism would
  # duplicate KV heads across devices, wasting kv cache memory.
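
The packing arithmetic above decides whether plain tensor parallelism would end up duplicating KV heads across devices. A small illustration of that check, with the threshold paraphrased from the comment and the example dtypes chosen arbitrarily:

    import jax.numpy as jnp

    def kv_heads_would_duplicate(num_kv_heads, tp_size, kv_dtype):
        # A 32-bit lane holds `packing` KV elements: bfloat16 -> 2, fp8 -> 4.
        packing = 4 // jnp.dtype(kv_dtype).itemsize
        return num_kv_heads * 2 / packing < tp_size

    print(kv_heads_would_duplicate(8, 8, jnp.bfloat16))       # 8 * 2 / 2 = 8  -> False
    print(kv_heads_would_duplicate(8, 8, jnp.float8_e4m3fn))  # 8 * 2 / 4 = 4  -> True
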
@@ -166,9 +169,10 @@ class ShardingConfigManager:
  f"LoRA is not supported with data parallelism "
  f"(DP size: {total_dp_size}). Please disable LoRA or "
  f"set data parallelism to 1.")
- if not os.environ.get("NEW_MODEL_DESIGN", False):
+ if sharding_strategy.attention_data_parallelism > 1:
+ if not envs.NEW_MODEL_DESIGN:
  raise ValueError(
- "Must run DP with NEW_MODEL_DESIGN enabled. Please set the "
+ "Must run Attention DP with NEW_MODEL_DESIGN enabled. Please set the "
  "NEW_MODEL_DESIGN=True.")

  @property