tpu-inference 0.11.1.dev202511180814__py3-none-any.whl → 0.12.0.dev20251213__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- tests/kernels/fused_moe_v1_test.py +303 -34
- tests/kernels/mla_v1_test.py +129 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +3 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +3 -1
- tests/lora/test_layers.py +4 -7
- tests/lora/test_lora_perf.py +53 -0
- tests/lora/utils.py +0 -8
- tests/test_envs.py +110 -12
- tests/test_quantization.py +3 -0
- tests/test_utils.py +1 -2
- tpu_inference/__init__.py +22 -3
- tpu_inference/core/disagg_utils.py +6 -8
- tpu_inference/distributed/tpu_connector.py +3 -4
- tpu_inference/distributed/utils.py +3 -2
- tpu_inference/envs.py +93 -9
- tpu_inference/executors/ray_distributed_executor.py +9 -2
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/fused_moe/v1/kernel.py +712 -143
- tpu_inference/kernels/mla/v1/kernel.py +98 -120
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +140 -67
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +204 -120
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +2 -1
- tpu_inference/layers/common/attention_interface.py +7 -1
- tpu_inference/layers/common/sharding.py +11 -7
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +232 -64
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +5 -5
- tpu_inference/layers/vllm/fused_moe.py +170 -208
- tpu_inference/layers/vllm/linear_common.py +43 -21
- tpu_inference/layers/vllm/quantization/common.py +11 -6
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +74 -65
- tpu_inference/layers/vllm/quantization/mxfp4.py +140 -94
- tpu_inference/layers/vllm/quantization/unquantized.py +103 -80
- tpu_inference/layers/vllm/sharding.py +2 -2
- tpu_inference/lora/torch_punica_tpu.py +1 -2
- tpu_inference/models/common/model_loader.py +84 -28
- tpu_inference/models/jax/deepseek_v3.py +185 -64
- tpu_inference/models/jax/gpt_oss.py +3 -3
- tpu_inference/models/jax/llama3.py +2 -1
- tpu_inference/models/jax/llama_eagle3.py +8 -5
- tpu_inference/models/jax/llama_guard_4.py +361 -0
- tpu_inference/models/jax/qwen2.py +2 -1
- tpu_inference/models/jax/qwen2_5_vl.py +163 -48
- tpu_inference/models/jax/qwen3.py +2 -1
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +7 -8
- tpu_inference/models/jax/utils/weight_utils.py +205 -144
- tpu_inference/models/vllm/vllm_model_wrapper.py +14 -8
- tpu_inference/platforms/tpu_platform.py +34 -50
- tpu_inference/runner/compilation_manager.py +144 -60
- tpu_inference/runner/kv_cache.py +40 -20
- tpu_inference/runner/kv_cache_manager.py +48 -33
- tpu_inference/runner/persistent_batch_manager.py +40 -2
- tpu_inference/runner/structured_decoding_manager.py +2 -3
- tpu_inference/runner/tpu_runner.py +280 -149
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/jax/eagle3.py +71 -21
- tpu_inference/tpu_info.py +4 -3
- tpu_inference/utils.py +46 -18
- tpu_inference/worker/tpu_worker.py +197 -63
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/METADATA +9 -10
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/RECORD +70 -74
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +0 -28
- tpu_inference/mock/vllm_envs.py +0 -1219
- tpu_inference/mock/vllm_logger.py +0 -212
- tpu_inference/mock/vllm_logging_utils.py +0 -15
- tpu_inference/models/jax/phi3.py +0 -376
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/top_level.txt +0 -0
tpu_inference/kernels/ragged_paged_attention/v3/kernel.py

@@ -319,7 +319,7 @@ def _ragged_paged_attention_kernel(
     debug_print("[RPA debug] q_len={}", q_len)
     debug_print("[RPA debug] kv_len={}", kv_len)
 
-  def flash_attention(
+  def flash_attention_step1_qk_softmax(
       q,  # [actual_bq_sz * num_q_heads_per_kv_head, head_dim]
       k,  # [bkv_sz, head_dim]
      v,  # [bkv_sz, head_dim]
@@ -335,7 +335,6 @@ def _ragged_paged_attention_kernel(
     assert k.dtype == v.dtype
     head_l_ref = l_ref.at[kv_head_idx, :q.shape[0]]
     head_m_ref = m_ref.at[kv_head_idx, :q.shape[0]]
-    head_acc_ref = acc_ref.at[kv_head_idx, :q.shape[0]]
 
     def load_with_init(ref, init_val):
       return jnp.where(bkv_idx == 0, jnp.full_like(ref, init_val),
@@ -376,15 +375,32 @@ def _ragged_paged_attention_kernel(
     head_m_ref[...] = m_curr
     p = jnp.exp(s - broadcast_minor(m_curr, s.shape))
 
-    pv = jnp.einsum("nm,md->nd", p, v, preferred_element_type=jnp.float32)
-    if v_scale is not None:
-      pv *= v_scale
-
     p_rowsum = jnp.sum(p, axis=1, keepdims=True)
     exp_m_diff = jnp.exp(m_prev - m_curr)
     l_prev = load_with_init(head_l_ref, 0.0)
     l_curr = exp_m_diff * l_prev + p_rowsum
     head_l_ref[...] = l_curr
+
+    return p, exp_m_diff
+
+  def flash_attention_step2_pv(
+      q_shape_0,
+      v,  # [bkv_sz, head_dim]
+      p,  # from step1
+      exp_m_diff,  # from step1
+      *,
+      bkv_idx,
+      kv_head_idx,
+  ):
+    head_acc_ref = acc_ref.at[kv_head_idx, :q_shape_0]
+
+    def load_with_init(ref, init_val):
+      return jnp.where(bkv_idx == 0, jnp.full_like(ref, init_val),
+                       ref[...])
+
+    pv = jnp.einsum("nm,md->nd", p, v, preferred_element_type=jnp.float32)
+    if v_scale is not None:
+      pv *= v_scale
     o_prev = load_with_init(head_acc_ref, 0.0)
     o_curr = broadcast_minor(exp_m_diff, o_prev.shape) * o_prev + pv
     head_acc_ref[...] = o_curr
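For context on the hunk above: the old monolithic flash-attention body computed the PV matmul inline, while the new code returns `p` and `exp_m_diff` from the QK-softmax step and defers PV to a second step. Below is a minimal plain-JAX sketch of the same two-step online-softmax decomposition; the helper names and the explicitly threaded running state are illustrative (the kernel keeps that state in VMEM refs instead):

import jax.numpy as jnp

def step1_qk_softmax(q, k, row_max, denom):
  # Scores for this KV block; update the running row max and softmax denominator.
  s = jnp.einsum("nd,md->nm", q, k)
  new_max = jnp.maximum(row_max, s.max(axis=1, keepdims=True))
  p = jnp.exp(s - new_max)                   # unnormalized probabilities
  exp_m_diff = jnp.exp(row_max - new_max)    # rescale factor for the old state
  new_denom = exp_m_diff * denom + p.sum(axis=1, keepdims=True)
  return p, exp_m_diff, new_max, new_denom

def step2_pv(p, v, exp_m_diff, acc):
  # The deferred PV matmul: rescale the old accumulator, then add p @ v.
  pv = jnp.einsum("nm,md->nd", p, v, preferred_element_type=jnp.float32)
  return exp_m_diff * acc + pv

def attend(q, k_blocks, v_blocks):
  n, d = q.shape
  row_max = jnp.full((n, 1), -jnp.inf, dtype=jnp.float32)
  denom = jnp.zeros((n, 1), dtype=jnp.float32)
  acc = jnp.zeros((n, d), dtype=jnp.float32)
  for k, v in zip(k_blocks, v_blocks):
    p, exp_m_diff, row_max, denom = step1_qk_softmax(q, k, row_max, denom)
    acc = step2_pv(p, v, exp_m_diff, acc)
  return acc / denom  # softmax(q @ k.T) @ v, one KV block at a time

Splitting at this boundary is what lets the kernel overlap the PV matmul of one KV head with the QK softmax of the next (see the per-head loop later in this diff).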
@@ -440,42 +456,51 @@ def _ragged_paged_attention_kernel(
     debug_print("[RPA debug] bkv_sz_frm_new={}", bkv_sz_frm_new)
     debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)
 
-    # Fetch effective kv from kv cache.
-    def loop_body(i, offset):
-      sz = jnp.minimum(page_size, kv_left_frm_cache - i * page_size)
-      _async_copy(
-          cache_hbm_ref.at[pl.ds(
-              page_indices_ref[page_indices_offset + i] * page_size,
-              sz)],
-          vmem_ref.at[pl.ds(i * page_size, sz)],
-          sem,
-          wait,
+    if not wait:
+      # Fetch effective kv from kv cache.
+      def loop_body(i, offset):
+        sz = jnp.minimum(page_size, kv_left_frm_cache - i * page_size)
+        _async_copy(
+            cache_hbm_ref.at[pl.ds(
+                page_indices_ref[page_indices_offset + i] * page_size,
+                sz)],
+            vmem_ref.at[pl.ds(i * page_size, sz)],
+            sem,
+            wait=False,
+        )
+        debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
+        return offset + sz
+
+      offset = lax.fori_loop(
+          0,
+          bkv_p_frm_cache,
+          loop_body,
+          0,  # offset
+          unroll=False,
       )
-      debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
-      return offset + sz
-
-    offset = lax.fori_loop(
-        0,
-        bkv_p_frm_cache,
-        loop_body,
-        0,  # offset
-        unroll=False,
-    )
 
-
-    @pl.when(bkv_sz_frm_new > 0)
-    def _fetch_bkv_from_new_kv():
+      size = lax.select(bkv_sz_frm_new > 0, bkv_sz_frm_new, 0)
       new_kv_len_start = q_end - kv_left_frm_new
       debug_print("[RPA debug] new_kv_len_start={}", new_kv_len_start)
       debug_print("[RPA debug] offset_in_bkv={}", offset)
       _async_copy(
-          kv_hbm_ref.at[pl.ds(new_kv_len_start, bkv_sz_frm_new)],
-          vmem_ref.at[pl.ds(offset, bkv_sz_frm_new)],
+          kv_hbm_ref.at[pl.ds(new_kv_len_start, size)],
+          vmem_ref.at[pl.ds(offset, size)],
           sem,
           wait,
       )
 
-
+      return kv_len_start + offset, bkv_sz_frm_new
+    else:
+      offset = jnp.minimum(kv_left_frm_cache, page_size * bkv_p)
+      dst = vmem_ref.at[pl.ds(0, offset + bkv_sz_frm_new)]
+      _async_copy(
+          src=dst,
+          dst=dst,
+          sem=sem,
+          wait=True,
+      )
+      return kv_len_start + offset, bkv_sz_frm_new
 
   def _update_kv_cache(seq_idx,
                        bkv_sem_idx,
@@ -511,30 +536,41 @@ def _ragged_paged_attention_kernel(
     debug_print("[RPA debug] p_ignore={}", p_ignore)
     debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)
 
-    def loop_body(i, states):
-      update_sz, ignore = states
-      sz = jnp.minimum(page_size - ignore, update_sz)
-
+    if not wait:
+
+      def loop_body(i, states):
+        update_sz, ignore = states
+        sz = jnp.minimum(page_size - ignore, update_sz)
+
+        _async_copy(
+            vmem_ref.at[pl.ds((p_ignore + i) * page_size + ignore,
+                              sz)],
+            cache_hbm_ref.at[pl.ds(
+                page_indices_ref[page_indices_offset + i] * page_size +
+                ignore,
+                sz,
+            )],
+            sem,
+            wait=False,
+        )
+        debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
+        return update_sz - sz, 0
+
+      lax.fori_loop(
+          0,
+          kv_p_end - kv_p_start,
+          loop_body,
+          (update_sz, ignore),  # total transfer size
+          unroll=False,
+      )
+    else:
+      dst = cache_hbm_ref.at[pl.ds(0, update_sz)]
       _async_copy(
-          vmem_ref.at[pl.ds((p_ignore + i) * page_size + ignore,
-                            sz)],
-          cache_hbm_ref.at[pl.ds(
-              page_indices_ref[page_indices_offset + i] * page_size + ignore,
-              sz,
-          )],
-          sem,
-          wait,
+          src=dst,
+          dst=dst,
+          sem=sem,
+          wait=True,
       )
-      debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
-      return update_sz - sz, 0
-
-    lax.fori_loop(
-        0,
-        kv_p_end - kv_p_start,
-        loop_body,
-        (update_sz, ignore),  # total transfer size
-        unroll=False,
-    )
 
   def _fetch_bq(seq_idx, bq_idx, bq_sem_idx, *, wait=False):
     sem = sems.at[1, bq_sem_idx]
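In both hunks above, the copy helpers now branch on their `wait` flag: called with `wait=False` they only start the per-page DMAs, and called with `wait=True` they block on the same semaphore via a single wait-only copy over the combined size. That is the usual start/wait split behind double buffering. A minimal sketch of the pattern using Pallas TPU's async-copy API, adapted from the JAX Pallas pipelining guide (the kernel name, shapes, and the `TPUMemorySpace.ANY` spelling are illustrative and can differ across JAX versions):

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def copy_kernel(x_hbm_ref, o_ref, scratch_ref, sem):
  # Start the HBM -> VMEM copy, overlap unrelated compute, then block.
  copy = pltpu.make_async_copy(x_hbm_ref, scratch_ref, sem)
  copy.start()   # analogous to _fetch_bkv(..., wait=False)
  copy.wait()    # analogous to _fetch_bkv(..., wait=True)
  o_ref[...] = scratch_ref[...]

x = jnp.arange(8 * 128, dtype=jnp.float32).reshape(8, 128)
y = pl.pallas_call(
    copy_kernel,
    in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)],
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    scratch_shapes=[pltpu.VMEM(x.shape, jnp.float32), pltpu.SemaphoreType.DMA],
)(x)

Separating start from wait lets the runner issue the next block's fetch while computing on the current one; the `dst = ...; _async_copy(src=dst, dst=dst, ..., wait=True)` idiom above builds a descriptor of matching size purely to wait on the already-started transfers.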
@@ -819,6 +855,11 @@ def _ragged_paged_attention_kernel(
 
     # Flash attention with cur bkv and bq
     # NOTE: kv_packing is divided by 2 because k and v are packed together.
+    prev_bq_shape_0 = None
+    prev_kv_head_bv = None
+    prev_kv_head_idx = None
+    prev_kv_head_p = None
+    prev_kv_head_exp_m_diff = None
     heads_per_load = max(1, kv_packing // 2)
     for kv_head_start in range(0, actual_num_kv_heads,
                                heads_per_load):
@@ -830,21 +871,53 @@ def _ragged_paged_attention_kernel(
     )
     assert len(bkv_lst) == heads_per_load
     for i in range(heads_per_load):
-      kv_head_idx = kv_head_start + i
-      if kv_head_idx >= actual_num_kv_heads:
+      cur_kv_head_idx = kv_head_start + i
+      if cur_kv_head_idx >= actual_num_kv_heads:
         break
-
-      bq = load_bq(bq_sem_idx, kv_head_idx,
-                   actual_bq_sz=actual_bq_sz)
+
+      cur_kv_head_bq = load_bq(bq_sem_idx,
+                               cur_kv_head_idx,
+                               actual_bq_sz=actual_bq_sz)
       bk, bv = bkv_lst[i]
-      flash_attention(
-          bq,
-          bk,
-          bv,
-          bq_idx=bq_idx,
-          bkv_idx=bkv_idx,
-          kv_head_idx=kv_head_idx,
-      )
+      # FlashAttention is divided into `flash_attention_step1_qk_softmax`
+      # and `flash_attention_step2_pv` to pipeline the computation.
+      # `step2_pv` for the previous KV head, which depends on the softmax
+      # output, is overlapped with `step1_qk_softmax` for the current KV
+      # head, reducing overall wait times.
+      cur_kv_head_p, cur_kv_head_exp_m_diff = (
+          flash_attention_step1_qk_softmax(
+              cur_kv_head_bq,
+              bk,
+              bv,
+              bq_idx=bq_idx,
+              bkv_idx=bkv_idx,
+              kv_head_idx=cur_kv_head_idx,
+          ))
+      if prev_bq_shape_0 is not None:
+        flash_attention_step2_pv(
+            prev_bq_shape_0,
+            prev_kv_head_bv,
+            prev_kv_head_p,
+            prev_kv_head_exp_m_diff,
+            bkv_idx=bkv_idx,
+            kv_head_idx=prev_kv_head_idx,
+        )
+      prev_bq_shape_0 = cur_kv_head_bq.shape[0]
+      prev_kv_head_bv = bv
+      prev_kv_head_p = cur_kv_head_p
+      prev_kv_head_exp_m_diff = cur_kv_head_exp_m_diff
+      prev_kv_head_idx = cur_kv_head_idx
+
+    # Execute pv of last attention head.
+    assert prev_bq_shape_0 is not None
+    flash_attention_step2_pv(
+        prev_bq_shape_0,
+        prev_kv_head_bv,
+        prev_kv_head_p,
+        prev_kv_head_exp_m_diff,
+        bkv_idx=bkv_idx,
+        kv_head_idx=prev_kv_head_idx,
+    )
 
   lax.fori_loop(0, num_bkv, compute_with_bkv, None, unroll=False)
 
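The per-head loop above is a one-stage software pipeline: step2 for head i-1 is issued between step1 and step2 of head i, with an explicit drain after the loop. A generic sketch of that schedule, where `step1` and `step2` stand in for the kernel's split helpers:

def pipelined_heads(heads, step1, step2):
  # One-stage software pipeline: retire step2 for head i-1 between
  # step1 and step2 of head i, then drain the last head after the loop.
  prev_head, prev_state = None, None
  for head in heads:
    state = step1(head)               # e.g. (p, exp_m_diff) above
    if prev_head is not None:
      step2(prev_head, prev_state)    # previous head's deferred PV matmul
    prev_head, prev_state = head, state
  if prev_head is not None:
    step2(prev_head, prev_state)      # drain

Because the kernel's head loop is a Python `for` unrolled at trace time, the overlap is realized by the compiler scheduling the two independent matmul chains, not by any runtime threading.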