tpu-inference 0.0.1rc1__py3-none-any.whl → 0.11.1.dev202511180814__py3-none-any.whl
This diff covers the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/kernels/fused_moe_v1_test.py +34 -303
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +2 -2
- tests/lora/test_layers.py +6 -0
- tests/lora/utils.py +8 -0
- tests/test_envs.py +11 -32
- tests/test_utils.py +2 -1
- tpu_inference/__init__.py +3 -22
- tpu_inference/core/disagg_utils.py +8 -6
- tpu_inference/distributed/tpu_connector.py +4 -3
- tpu_inference/distributed/utils.py +2 -3
- tpu_inference/envs.py +8 -61
- tpu_inference/executors/ray_distributed_executor.py +2 -9
- tpu_inference/kernels/fused_moe/v1/kernel.py +110 -641
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +54 -77
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +145 -266
- tpu_inference/layers/common/attention_interface.py +1 -7
- tpu_inference/layers/common/sharding.py +5 -5
- tpu_inference/layers/vllm/fused_moe.py +208 -170
- tpu_inference/layers/vllm/quantization/common.py +1 -6
- tpu_inference/layers/vllm/quantization/mxfp4.py +73 -138
- tpu_inference/layers/vllm/quantization/unquantized.py +64 -58
- tpu_inference/layers/vllm/sharding.py +2 -2
- tpu_inference/lora/torch_punica_tpu.py +2 -1
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1219 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/common/model_loader.py +10 -43
- tpu_inference/models/jax/llama3.py +1 -2
- tpu_inference/models/jax/llama_eagle3.py +5 -8
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +1 -2
- tpu_inference/models/jax/qwen2_5_vl.py +48 -163
- tpu_inference/models/jax/qwen3.py +1 -2
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +6 -3
- tpu_inference/models/jax/utils/weight_utils.py +143 -198
- tpu_inference/models/vllm/vllm_model_wrapper.py +8 -14
- tpu_inference/platforms/tpu_platform.py +31 -37
- tpu_inference/runner/compilation_manager.py +58 -141
- tpu_inference/runner/kv_cache.py +1 -1
- tpu_inference/runner/kv_cache_manager.py +18 -17
- tpu_inference/runner/persistent_batch_manager.py +2 -40
- tpu_inference/runner/structured_decoding_manager.py +3 -2
- tpu_inference/runner/tpu_runner.py +147 -271
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/jax/eagle3.py +21 -71
- tpu_inference/tpu_info.py +3 -4
- tpu_inference/utils.py +13 -36
- tpu_inference/worker/tpu_worker.py +25 -162
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/METADATA +3 -4
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/RECORD +55 -50
- tpu_inference/models/jax/llama_guard_4.py +0 -361
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/WHEEL +0 -0
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/top_level.txt +0 -0
@@ -440,54 +440,42 @@ def _ragged_paged_attention_kernel(
         debug_print("[RPA debug] bkv_sz_frm_new={}", bkv_sz_frm_new)
         debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)
 
-
-
-
-
-
-
-
-
-
-
-                    wait=False,
-                )
-                debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
-                return offset + sz
-
-            offset = lax.fori_loop(
-                0,
-                bkv_p_frm_cache,
-                loop_body,
-                0,  # offset
-                unroll=False,
+        # Fetch effective kv from kv cache.
+        def loop_body(i, offset):
+            sz = jnp.minimum(page_size, kv_left_frm_cache - i * page_size)
+            _async_copy(
+                cache_hbm_ref.at[pl.ds(
+                    page_indices_ref[page_indices_offset + i] * page_size,
+                    sz)],
+                vmem_ref.at[pl.ds(i * page_size, sz)],
+                sem,
+                wait,
             )
+            debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
+            return offset + sz
+
+        offset = lax.fori_loop(
+            0,
+            bkv_p_frm_cache,
+            loop_body,
+            0,  # offset
+            unroll=False,
+        )
 
-
-
-
-
-
-
-            debug_print("[RPA debug] offset_in_bkv={}", offset)
-            _async_copy(
-                kv_hbm_ref.at[pl.ds(new_kv_len_start, bkv_sz_frm_new)],
-                vmem_ref.at[pl.ds(offset, bkv_sz_frm_new)],
-                sem,
-                wait,
-            )
-
-            return kv_len_start + offset, bkv_sz_frm_new
-        else:
-            offset = jnp.minimum(kv_left_frm_cache, page_size * bkv_p)
-            dst = vmem_ref.at[pl.ds(0, offset + bkv_sz_frm_new)]
+        # Fetch kv directly from new kv.
+        @pl.when(bkv_sz_frm_new > 0)
+        def _fetch_bkv_from_new_kv():
+            new_kv_len_start = q_end - kv_left_frm_new
+            debug_print("[RPA debug] new_kv_len_start={}", new_kv_len_start)
+            debug_print("[RPA debug] offset_in_bkv={}", offset)
             _async_copy(
-
-
-                sem
-                wait
+                kv_hbm_ref.at[pl.ds(new_kv_len_start, bkv_sz_frm_new)],
+                vmem_ref.at[pl.ds(offset, bkv_sz_frm_new)],
+                sem,
+                wait,
             )
-
+
+        return kv_len_start + offset, bkv_sz_frm_new
 
     def _update_kv_cache(seq_idx,
                          bkv_sem_idx,
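The refactor shown above replaces the kernel's Python `if`/`else` on the kv source with Pallas control flow: the cache pages are gathered by a `lax.fori_loop` that threads the running write offset through its carry, and the tail read from new kv is gated behind `@pl.when`. Below is a minimal sketch of those two constructs, runnable in interpret mode; `gather_kv_kernel`, `PAGE_SIZE`, and `HEAD_DIM` are illustrative names and sizes, and synchronous `pl.load`/`pl.store` stand in for the kernel's `_async_copy` DMA transfers.

import jax
import jax.numpy as jnp
from jax import lax
from jax.experimental import pallas as pl

PAGE_SIZE = 4   # illustrative; the real kernel derives page_size from its refs
HEAD_DIM = 8

def gather_kv_kernel(page_indices_ref, cache_ref, new_kv_ref, out_ref):
    num_pages = page_indices_ref.shape[0]
    new_len = new_kv_ref.shape[0]

    # One cache page per step; the loop carry is the running write offset,
    # like `offset` in the fori_loop above.
    def loop_body(i, offset):
        page = pl.load(page_indices_ref, (i,))
        block = pl.load(cache_ref,
                        (pl.ds(page * PAGE_SIZE, PAGE_SIZE), slice(None)))
        pl.store(out_ref, (pl.ds(offset, PAGE_SIZE), slice(None)), block)
        return offset + PAGE_SIZE

    offset = lax.fori_loop(0, num_pages, loop_body, 0, unroll=False)

    # pl.when gates a block on a predicate, standing in for the Python
    # `if`/`else` the diff removes (cf. `@pl.when(bkv_sz_frm_new > 0)`).
    @pl.when(new_len > 0)
    def _append_new_kv():
        pl.store(out_ref, (pl.ds(offset, new_len), slice(None)),
                 new_kv_ref[...])

cache = jnp.arange(6 * PAGE_SIZE * HEAD_DIM,
                   dtype=jnp.float32).reshape(6 * PAGE_SIZE, HEAD_DIM)
new_kv = jnp.ones((PAGE_SIZE, HEAD_DIM), jnp.float32)
page_indices = jnp.array([3, 1], dtype=jnp.int32)
out = pl.pallas_call(
    gather_kv_kernel,
    out_shape=jax.ShapeDtypeStruct((3 * PAGE_SIZE, HEAD_DIM), jnp.float32),
    interpret=True,  # CPU interpreter so the sketch runs without a TPU
)(page_indices, cache, new_kv)

In the real kernel the copies are asynchronous DMAs guarded by semaphores (note `sems.at[1, bq_sem_idx]` in the second hunk), so HBM-to-VMEM transfers can overlap compute; the synchronous loads here only keep the sketch self-contained.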
@@ -523,41 +511,30 @@ def _ragged_paged_attention_kernel(
|
|
|
523
511
|
debug_print("[RPA debug] p_ignore={}", p_ignore)
|
|
524
512
|
debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)
|
|
525
513
|
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
|
|
529
|
-
|
|
530
|
-
sz = jnp.minimum(page_size - ignore, update_sz)
|
|
531
|
-
|
|
532
|
-
_async_copy(
|
|
533
|
-
vmem_ref.at[pl.ds((p_ignore + i) * page_size + ignore,
|
|
534
|
-
sz)],
|
|
535
|
-
cache_hbm_ref.at[pl.ds(
|
|
536
|
-
page_indices_ref[page_indices_offset + i] * page_size +
|
|
537
|
-
ignore,
|
|
538
|
-
sz,
|
|
539
|
-
)],
|
|
540
|
-
sem,
|
|
541
|
-
wait=False,
|
|
542
|
-
)
|
|
543
|
-
debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
|
|
544
|
-
return update_sz - sz, 0
|
|
545
|
-
|
|
546
|
-
lax.fori_loop(
|
|
547
|
-
0,
|
|
548
|
-
kv_p_end - kv_p_start,
|
|
549
|
-
loop_body,
|
|
550
|
-
(update_sz, ignore), # total transfer size
|
|
551
|
-
unroll=False,
|
|
552
|
-
)
|
|
553
|
-
else:
|
|
554
|
-
dst = cache_hbm_ref.at[pl.ds(0, update_sz)]
|
|
514
|
+
def loop_body(i, states):
|
|
515
|
+
update_sz, ignore = states
|
|
516
|
+
sz = jnp.minimum(page_size - ignore, update_sz)
|
|
517
|
+
|
|
555
518
|
_async_copy(
|
|
556
|
-
|
|
557
|
-
|
|
558
|
-
|
|
559
|
-
|
|
519
|
+
vmem_ref.at[pl.ds((p_ignore + i) * page_size + ignore, sz)],
|
|
520
|
+
cache_hbm_ref.at[pl.ds(
|
|
521
|
+
page_indices_ref[page_indices_offset + i] * page_size +
|
|
522
|
+
ignore,
|
|
523
|
+
sz,
|
|
524
|
+
)],
|
|
525
|
+
sem,
|
|
526
|
+
wait,
|
|
560
527
|
)
|
|
528
|
+
debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
|
|
529
|
+
return update_sz - sz, 0
|
|
530
|
+
|
|
531
|
+
lax.fori_loop(
|
|
532
|
+
0,
|
|
533
|
+
kv_p_end - kv_p_start,
|
|
534
|
+
loop_body,
|
|
535
|
+
(update_sz, ignore), # total transfer size
|
|
536
|
+
unroll=False,
|
|
537
|
+
)
|
|
561
538
|
|
|
562
539
|
def _fetch_bq(seq_idx, bq_idx, bq_sem_idx, *, wait=False):
|
|
563
540
|
sem = sems.at[1, bq_sem_idx]
|