tpu-inference 0.11.1.dev202511180814__py3-none-any.whl → 0.12.0.dev20251213__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (76)
  1. tests/kernels/fused_moe_v1_test.py +303 -34
  2. tests/kernels/mla_v1_test.py +129 -41
  3. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  4. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +3 -1
  5. tests/kernels/ragged_paged_attention_kernel_v3_test.py +3 -1
  6. tests/lora/test_layers.py +4 -7
  7. tests/lora/test_lora_perf.py +53 -0
  8. tests/lora/utils.py +0 -8
  9. tests/test_envs.py +110 -12
  10. tests/test_quantization.py +3 -0
  11. tests/test_utils.py +1 -2
  12. tpu_inference/__init__.py +22 -3
  13. tpu_inference/core/disagg_utils.py +6 -8
  14. tpu_inference/distributed/tpu_connector.py +3 -4
  15. tpu_inference/distributed/utils.py +3 -2
  16. tpu_inference/envs.py +93 -9
  17. tpu_inference/executors/ray_distributed_executor.py +9 -2
  18. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  19. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  20. tpu_inference/kernels/fused_moe/v1/kernel.py +712 -143
  21. tpu_inference/kernels/mla/v1/kernel.py +98 -120
  22. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  23. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  24. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  25. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +140 -67
  26. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +204 -120
  27. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +2 -1
  28. tpu_inference/kernels/ragged_paged_attention/v3/util.py +2 -1
  29. tpu_inference/layers/common/attention_interface.py +7 -1
  30. tpu_inference/layers/common/sharding.py +11 -7
  31. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +232 -64
  32. tpu_inference/layers/jax/attention/gpt_oss_attention.py +5 -5
  33. tpu_inference/layers/vllm/fused_moe.py +170 -208
  34. tpu_inference/layers/vllm/linear_common.py +43 -21
  35. tpu_inference/layers/vllm/quantization/common.py +11 -6
  36. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -3
  37. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +74 -65
  38. tpu_inference/layers/vllm/quantization/mxfp4.py +140 -94
  39. tpu_inference/layers/vllm/quantization/unquantized.py +103 -80
  40. tpu_inference/layers/vllm/sharding.py +2 -2
  41. tpu_inference/lora/torch_punica_tpu.py +1 -2
  42. tpu_inference/models/common/model_loader.py +84 -28
  43. tpu_inference/models/jax/deepseek_v3.py +185 -64
  44. tpu_inference/models/jax/gpt_oss.py +3 -3
  45. tpu_inference/models/jax/llama3.py +2 -1
  46. tpu_inference/models/jax/llama_eagle3.py +8 -5
  47. tpu_inference/models/jax/llama_guard_4.py +361 -0
  48. tpu_inference/models/jax/qwen2.py +2 -1
  49. tpu_inference/models/jax/qwen2_5_vl.py +163 -48
  50. tpu_inference/models/jax/qwen3.py +2 -1
  51. tpu_inference/models/jax/utils/quantization/quantization_utils.py +7 -8
  52. tpu_inference/models/jax/utils/weight_utils.py +205 -144
  53. tpu_inference/models/vllm/vllm_model_wrapper.py +14 -8
  54. tpu_inference/platforms/tpu_platform.py +34 -50
  55. tpu_inference/runner/compilation_manager.py +144 -60
  56. tpu_inference/runner/kv_cache.py +40 -20
  57. tpu_inference/runner/kv_cache_manager.py +48 -33
  58. tpu_inference/runner/persistent_batch_manager.py +40 -2
  59. tpu_inference/runner/structured_decoding_manager.py +2 -3
  60. tpu_inference/runner/tpu_runner.py +280 -149
  61. tpu_inference/runner/utils.py +2 -2
  62. tpu_inference/spec_decode/jax/eagle3.py +71 -21
  63. tpu_inference/tpu_info.py +4 -3
  64. tpu_inference/utils.py +46 -18
  65. tpu_inference/worker/tpu_worker.py +197 -63
  66. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/METADATA +9 -10
  67. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/RECORD +70 -74
  68. tpu_inference/mock/__init__.py +0 -0
  69. tpu_inference/mock/vllm_config_utils.py +0 -28
  70. tpu_inference/mock/vllm_envs.py +0 -1219
  71. tpu_inference/mock/vllm_logger.py +0 -212
  72. tpu_inference/mock/vllm_logging_utils.py +0 -15
  73. tpu_inference/models/jax/phi3.py +0 -376
  74. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/WHEEL +0 -0
  75. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/licenses/LICENSE +0 -0
  76. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/top_level.txt +0 -0
@@ -319,7 +319,7 @@ def _ragged_paged_attention_kernel(
   debug_print("[RPA debug] q_len={}", q_len)
   debug_print("[RPA debug] kv_len={}", kv_len)

- def flash_attention(
+ def flash_attention_step1_qk_softmax(
   q, # [actual_bq_sz * num_q_heads_per_kv_head, head_dim]
   k, # [bkv_sz, head_dim]
   v, # [bkv_sz, head_dim]
@@ -335,7 +335,6 @@ def _ragged_paged_attention_kernel(
   assert k.dtype == v.dtype
   head_l_ref = l_ref.at[kv_head_idx, :q.shape[0]]
   head_m_ref = m_ref.at[kv_head_idx, :q.shape[0]]
- head_acc_ref = acc_ref.at[kv_head_idx, :q.shape[0]]

   def load_with_init(ref, init_val):
     return jnp.where(bkv_idx == 0, jnp.full_like(ref, init_val),
@@ -376,15 +375,32 @@ def _ragged_paged_attention_kernel(
   head_m_ref[...] = m_curr
   p = jnp.exp(s - broadcast_minor(m_curr, s.shape))

- pv = jnp.einsum("nm,md->nd", p, v, preferred_element_type=jnp.float32)
- if v_scale is not None:
-   pv *= v_scale
-
   p_rowsum = jnp.sum(p, axis=1, keepdims=True)
   exp_m_diff = jnp.exp(m_prev - m_curr)
   l_prev = load_with_init(head_l_ref, 0.0)
   l_curr = exp_m_diff * l_prev + p_rowsum
   head_l_ref[...] = l_curr
+
+ return p, exp_m_diff
+
+ def flash_attention_step2_pv(
+     q_shape_0,
+     v, # [bkv_sz, head_dim]
+     p, # from step1
+     exp_m_diff, # from step1
+     *,
+     bkv_idx,
+     kv_head_idx,
+ ):
+   head_acc_ref = acc_ref.at[kv_head_idx, :q_shape_0]
+
+   def load_with_init(ref, init_val):
+     return jnp.where(bkv_idx == 0, jnp.full_like(ref, init_val),
+                      ref[...])
+
+   pv = jnp.einsum("nm,md->nd", p, v, preferred_element_type=jnp.float32)
+   if v_scale is not None:
+     pv *= v_scale
   o_prev = load_with_init(head_acc_ref, 0.0)
   o_curr = broadcast_minor(exp_m_diff, o_prev.shape) * o_prev + pv
   head_acc_ref[...] = o_curr
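
The split above implements the standard online-softmax recurrence: step 1 produces the unnormalized probabilities `p` and the rescale factor `exp_m_diff` while updating the running max `m` and row-sum `l`; step 2 rescales the output accumulator and adds `p @ v`. A minimal sketch of the same recurrence in plain JAX (not Pallas; all names below are illustrative, not the kernel's):

import jax
import jax.numpy as jnp

def step1_qk_softmax(q, k, m_prev, l_prev):
    s = q @ k.T                                   # [n_q, bkv] attention scores
    m_curr = jnp.maximum(m_prev, s.max(axis=1, keepdims=True))
    p = jnp.exp(s - m_curr)                       # unnormalized probabilities
    exp_m_diff = jnp.exp(m_prev - m_curr)         # rescale factor for old state
    l_curr = exp_m_diff * l_prev + p.sum(axis=1, keepdims=True)
    return p, exp_m_diff, m_curr, l_curr

def step2_pv(p, v, exp_m_diff, acc_prev):
    # Rescale the old accumulator and add the new p @ v contribution.
    return exp_m_diff * acc_prev + p @ v

def attend(q, k, v, block=128):
    n_q, d = q.shape
    m = jnp.full((n_q, 1), -jnp.inf)
    l = jnp.zeros((n_q, 1))
    acc = jnp.zeros((n_q, d))
    for start in range(0, k.shape[0], block):   # one KV block per iteration
        kb, vb = k[start:start + block], v[start:start + block]
        p, exp_m_diff, m, l = step1_qk_softmax(q, kb, m, l)
        acc = step2_pv(p, vb, exp_m_diff, acc)
    return acc / l                              # normalize once at the end

# Matches the reference softmax(q k^T) v up to float32 numerics:
key0, key1, key2 = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(key0, (8, 64))
k = jax.random.normal(key1, (256, 64))
v = jax.random.normal(key2, (256, 64))
ref = jax.nn.softmax(q @ k.T, axis=-1) @ v
assert jnp.allclose(attend(q, k, v), ref, atol=1e-4)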
@@ -440,42 +456,51 @@ def _ragged_paged_attention_kernel(
   debug_print("[RPA debug] bkv_sz_frm_new={}", bkv_sz_frm_new)
   debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)

- # Fetch effective kv from kv cache.
- def loop_body(i, offset):
-   sz = jnp.minimum(page_size, kv_left_frm_cache - i * page_size)
-   _async_copy(
-       cache_hbm_ref.at[pl.ds(
-           page_indices_ref[page_indices_offset + i] * page_size,
-           sz)],
-       vmem_ref.at[pl.ds(i * page_size, sz)],
-       sem,
-       wait,
+ if not wait:
+   # Fetch effective kv from kv cache.
+   def loop_body(i, offset):
+     sz = jnp.minimum(page_size, kv_left_frm_cache - i * page_size)
+     _async_copy(
+         cache_hbm_ref.at[pl.ds(
+             page_indices_ref[page_indices_offset + i] * page_size,
+             sz)],
+         vmem_ref.at[pl.ds(i * page_size, sz)],
+         sem,
+         wait=False,
+     )
+     debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
+     return offset + sz
+
+   offset = lax.fori_loop(
+       0,
+       bkv_p_frm_cache,
+       loop_body,
+       0, # offset
+       unroll=False,
   )
-   debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
-   return offset + sz
-
- offset = lax.fori_loop(
-     0,
-     bkv_p_frm_cache,
-     loop_body,
-     0, # offset
-     unroll=False,
- )

- # Fetch kv directly from new kv.
- @pl.when(bkv_sz_frm_new > 0)
- def _fetch_bkv_from_new_kv():
+   size = lax.select(bkv_sz_frm_new > 0, bkv_sz_frm_new, 0)
   new_kv_len_start = q_end - kv_left_frm_new
   debug_print("[RPA debug] new_kv_len_start={}", new_kv_len_start)
   debug_print("[RPA debug] offset_in_bkv={}", offset)
   _async_copy(
-     kv_hbm_ref.at[pl.ds(new_kv_len_start, bkv_sz_frm_new)],
-     vmem_ref.at[pl.ds(offset, bkv_sz_frm_new)],
+     kv_hbm_ref.at[pl.ds(new_kv_len_start, size)],
+     vmem_ref.at[pl.ds(offset, size)],
     sem,
     wait,
   )

- return kv_len_start + offset, bkv_sz_frm_new
+   return kv_len_start + offset, bkv_sz_frm_new
+ else:
+   offset = jnp.minimum(kv_left_frm_cache, page_size * bkv_p)
+   dst = vmem_ref.at[pl.ds(0, offset + bkv_sz_frm_new)]
+   _async_copy(
+       src=dst,
+       dst=dst,
+       sem=sem,
+       wait=True,
+   )
+   return kv_len_start + offset, bkv_sz_frm_new

 def _update_kv_cache(seq_idx,
                      bkv_sem_idx,
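
Both this fetch path and the cache-update path in the next hunk now branch on `wait`: the `wait=False` calls only enqueue per-page DMAs, and the single `wait=True` call blocks on the shared semaphore, passing a `src=dst` descriptor whose only role is to carry the total transfer size. A minimal sketch of such a helper, assuming Pallas TPU's `make_async_copy` API (this `_async_copy` is a hypothetical stand-in, not the kernel's actual definition):

from jax.experimental.pallas import tpu as pltpu

def _async_copy(src, dst, sem, wait):
    # make_async_copy builds a DMA descriptor tied to the semaphore `sem`.
    copy = pltpu.make_async_copy(src, dst, sem)
    if wait:
        copy.wait()   # block until the DMAs tracked by `sem` have landed
    else:
        copy.start()  # enqueue the DMA and return immediately

Since a TPU DMA semaphore accumulates the amount of data copied, a wait only needs a descriptor describing the right total size, which would explain why waiting on a `src=dst, dst=dst` copy over the full transferred range suffices here.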
@@ -511,30 +536,41 @@ def _ragged_paged_attention_kernel(
   debug_print("[RPA debug] p_ignore={}", p_ignore)
   debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)

- def loop_body(i, states):
-   update_sz, ignore = states
-   sz = jnp.minimum(page_size - ignore, update_sz)
-
+ if not wait:
+
+   def loop_body(i, states):
+     update_sz, ignore = states
+     sz = jnp.minimum(page_size - ignore, update_sz)
+
+     _async_copy(
+         vmem_ref.at[pl.ds((p_ignore + i) * page_size + ignore,
+                           sz)],
+         cache_hbm_ref.at[pl.ds(
+             page_indices_ref[page_indices_offset + i] * page_size +
+             ignore,
+             sz,
+         )],
+         sem,
+         wait=False,
+     )
+     debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
+     return update_sz - sz, 0
+
+   lax.fori_loop(
+       0,
+       kv_p_end - kv_p_start,
+       loop_body,
+       (update_sz, ignore), # total transfer size
+       unroll=False,
+   )
+ else:
+   dst = cache_hbm_ref.at[pl.ds(0, update_sz)]
   _async_copy(
-     vmem_ref.at[pl.ds((p_ignore + i) * page_size + ignore, sz)],
-     cache_hbm_ref.at[pl.ds(
-         page_indices_ref[page_indices_offset + i] * page_size +
-         ignore,
-         sz,
-     )],
-     sem,
-     wait,
+     src=dst,
+     dst=dst,
+     sem=sem,
+     wait=True,
   )
-   debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
-   return update_sz - sz, 0
-
- lax.fori_loop(
-     0,
-     kv_p_end - kv_p_start,
-     loop_body,
-     (update_sz, ignore), # total transfer size
-     unroll=False,
- )

 def _fetch_bq(seq_idx, bq_idx, bq_sem_idx, *, wait=False):
   sem = sems.at[1, bq_sem_idx]
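
The `(update_sz, ignore)` carry in the update loop exists because the first page of the written range may start mid-page; the loop zeroes `ignore` after the first iteration so every later page starts at offset 0. A small standalone sketch of that carry pattern (the function name and the omitted DMA are illustrative):

import jax.numpy as jnp
from jax import lax

def walk_pages(update_sz, ignore, page_size, num_pages):
    # Carry (remaining, ignore): chunk 0 may start `ignore` elements into
    # its page; every subsequent page starts at offset 0.
    def body(i, state):
        remaining, ig = state
        sz = jnp.minimum(page_size - ig, remaining)  # size of chunk i
        # ... a DMA of `sz` elements would be issued here ...
        return remaining - sz, jnp.zeros_like(ig)
    return lax.fori_loop(0, num_pages, body, (update_sz, ignore))

For example, `walk_pages(300, 40, 128, 3)` walks chunks of 88, 128, and 84 elements.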
@@ -819,6 +855,11 @@ def _ragged_paged_attention_kernel(

   # Flash attention with cur bkv and bq
   # NOTE: kv_packing is divided by 2 because k and v are packed together.
+ prev_bq_shape_0 = None
+ prev_kv_head_bv = None
+ prev_kv_head_idx = None
+ prev_kv_head_p = None
+ prev_kv_head_exp_m_diff = None
   heads_per_load = max(1, kv_packing // 2)
   for kv_head_start in range(0, actual_num_kv_heads,
                              heads_per_load):
@@ -830,21 +871,53 @@ def _ragged_paged_attention_kernel(
   )
   assert len(bkv_lst) == heads_per_load
   for i in range(heads_per_load):
-   kv_head_idx = kv_head_start + i
-   if kv_head_idx >= actual_num_kv_heads:
+   cur_kv_head_idx = kv_head_start + i
+   if cur_kv_head_idx >= actual_num_kv_heads:
       break
-   bq = load_bq(bq_sem_idx,
-                kv_head_idx,
-                actual_bq_sz=actual_bq_sz)
+
+   cur_kv_head_bq = load_bq(bq_sem_idx,
+                            cur_kv_head_idx,
+                            actual_bq_sz=actual_bq_sz)
   bk, bv = bkv_lst[i]
-   flash_attention(
-       bq,
-       bk,
-       bv,
-       bq_idx=bq_idx,
-       bkv_idx=bkv_idx,
-       kv_head_idx=kv_head_idx,
-   )
+   # FlashAttention is divided into `flash_attention_step1_qk_softmax`
+   # and `flash_attention_step2_pv` to pipeline the computation.
+   # `step2_pv` for the previous KV head, which depends on the softmax
+   # output, is overlapped with `step1_qk_softmax` for the current KV
+   # head, reducing overall wait times.
+   cur_kv_head_p, cur_kv_head_exp_m_diff = (
+       flash_attention_step1_qk_softmax(
+           cur_kv_head_bq,
+           bk,
+           bv,
+           bq_idx=bq_idx,
+           bkv_idx=bkv_idx,
+           kv_head_idx=cur_kv_head_idx,
+       ))
+   if prev_bq_shape_0 is not None:
+     flash_attention_step2_pv(
+         prev_bq_shape_0,
+         prev_kv_head_bv,
+         prev_kv_head_p,
+         prev_kv_head_exp_m_diff,
+         bkv_idx=bkv_idx,
+         kv_head_idx=prev_kv_head_idx,
+     )
+   prev_bq_shape_0 = cur_kv_head_bq.shape[0]
+   prev_kv_head_bv = bv
+   prev_kv_head_p = cur_kv_head_p
+   prev_kv_head_exp_m_diff = cur_kv_head_exp_m_diff
+   prev_kv_head_idx = cur_kv_head_idx
+
+ # Execute pv of last attention head.
+ assert prev_bq_shape_0 is not None
+ flash_attention_step2_pv(
+     prev_bq_shape_0,
+     prev_kv_head_bv,
+     prev_kv_head_p,
+     prev_kv_head_exp_m_diff,
+     bkv_idx=bkv_idx,
+     kv_head_idx=prev_kv_head_idx,
+ )

   lax.fori_loop(0, num_bkv, compute_with_bkv, None, unroll=False)
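
The rewritten loop is a one-stage software pipeline across KV heads: `step2_pv` for head i-1, which depends only on already-computed softmax output, is issued while `step1_qk_softmax` for head i runs, and a final drain call handles the last head. Reduced to its control-flow skeleton in plain Python (`step1`/`step2` stand in for the two kernel phases):

def pipelined_heads(heads, step1, step2):
    # One-stage software pipeline: step2 for head i-1 is interleaved with
    # step1 for head i, then a drain call finishes the last head.
    prev = None
    for h in heads:
        cur = step1(h)       # softmax stats for the current head
        if prev is not None:
            step2(*prev)     # pv for the previous head, now overlapped
        prev = (h, cur)
    if prev is not None:
        step2(*prev)         # drain: pv for the final head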