tpu-inference 0.0.1rc1__py3-none-any.whl → 0.11.1.dev202511180814__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic. Click here for more details.

Files changed (56)
  1. tests/kernels/fused_moe_v1_test.py +34 -303
  2. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +2 -2
  3. tests/lora/test_layers.py +6 -0
  4. tests/lora/utils.py +8 -0
  5. tests/test_envs.py +11 -32
  6. tests/test_utils.py +2 -1
  7. tpu_inference/__init__.py +3 -22
  8. tpu_inference/core/disagg_utils.py +8 -6
  9. tpu_inference/distributed/tpu_connector.py +4 -3
  10. tpu_inference/distributed/utils.py +2 -3
  11. tpu_inference/envs.py +8 -61
  12. tpu_inference/executors/ray_distributed_executor.py +2 -9
  13. tpu_inference/kernels/fused_moe/v1/kernel.py +110 -641
  14. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +54 -77
  15. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +145 -266
  16. tpu_inference/layers/common/attention_interface.py +1 -7
  17. tpu_inference/layers/common/sharding.py +5 -5
  18. tpu_inference/layers/vllm/fused_moe.py +208 -170
  19. tpu_inference/layers/vllm/quantization/common.py +1 -6
  20. tpu_inference/layers/vllm/quantization/mxfp4.py +73 -138
  21. tpu_inference/layers/vllm/quantization/unquantized.py +64 -58
  22. tpu_inference/layers/vllm/sharding.py +2 -2
  23. tpu_inference/lora/torch_punica_tpu.py +2 -1
  24. tpu_inference/mock/__init__.py +0 -0
  25. tpu_inference/mock/vllm_config_utils.py +28 -0
  26. tpu_inference/mock/vllm_envs.py +1219 -0
  27. tpu_inference/mock/vllm_logger.py +212 -0
  28. tpu_inference/mock/vllm_logging_utils.py +15 -0
  29. tpu_inference/models/common/model_loader.py +10 -43
  30. tpu_inference/models/jax/llama3.py +1 -2
  31. tpu_inference/models/jax/llama_eagle3.py +5 -8
  32. tpu_inference/models/jax/phi3.py +376 -0
  33. tpu_inference/models/jax/qwen2.py +1 -2
  34. tpu_inference/models/jax/qwen2_5_vl.py +48 -163
  35. tpu_inference/models/jax/qwen3.py +1 -2
  36. tpu_inference/models/jax/utils/quantization/quantization_utils.py +6 -3
  37. tpu_inference/models/jax/utils/weight_utils.py +143 -198
  38. tpu_inference/models/vllm/vllm_model_wrapper.py +8 -14
  39. tpu_inference/platforms/tpu_platform.py +31 -37
  40. tpu_inference/runner/compilation_manager.py +58 -141
  41. tpu_inference/runner/kv_cache.py +1 -1
  42. tpu_inference/runner/kv_cache_manager.py +18 -17
  43. tpu_inference/runner/persistent_batch_manager.py +2 -40
  44. tpu_inference/runner/structured_decoding_manager.py +3 -2
  45. tpu_inference/runner/tpu_runner.py +147 -271
  46. tpu_inference/runner/utils.py +2 -2
  47. tpu_inference/spec_decode/jax/eagle3.py +21 -71
  48. tpu_inference/tpu_info.py +3 -4
  49. tpu_inference/utils.py +13 -36
  50. tpu_inference/worker/tpu_worker.py +25 -162
  51. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/METADATA +3 -4
  52. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/RECORD +55 -50
  53. tpu_inference/models/jax/llama_guard_4.py +0 -361
  54. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/WHEEL +0 -0
  55. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/licenses/LICENSE +0 -0
  56. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/top_level.txt +0 -0
@@ -440,54 +440,42 @@ def _ragged_paged_attention_kernel(
440
440
  debug_print("[RPA debug] bkv_sz_frm_new={}", bkv_sz_frm_new)
441
441
  debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)
442
442
 
443
- if not wait:
444
- # Fetch effective kv from kv cache.
445
- def loop_body(i, offset):
446
- sz = jnp.minimum(page_size, kv_left_frm_cache - i * page_size)
447
- _async_copy(
448
- cache_hbm_ref.at[pl.ds(
449
- page_indices_ref[page_indices_offset + i] * page_size,
450
- sz)],
451
- vmem_ref.at[pl.ds(i * page_size, sz)],
452
- sem,
453
- wait=False,
454
- )
455
- debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
456
- return offset + sz
457
-
458
- offset = lax.fori_loop(
459
- 0,
460
- bkv_p_frm_cache,
461
- loop_body,
462
- 0, # offset
463
- unroll=False,
443
+ # Fetch effective kv from kv cache.
444
+ def loop_body(i, offset):
445
+ sz = jnp.minimum(page_size, kv_left_frm_cache - i * page_size)
446
+ _async_copy(
447
+ cache_hbm_ref.at[pl.ds(
448
+ page_indices_ref[page_indices_offset + i] * page_size,
449
+ sz)],
450
+ vmem_ref.at[pl.ds(i * page_size, sz)],
451
+ sem,
452
+ wait,
464
453
  )
454
+ debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
455
+ return offset + sz
456
+
457
+ offset = lax.fori_loop(
458
+ 0,
459
+ bkv_p_frm_cache,
460
+ loop_body,
461
+ 0, # offset
462
+ unroll=False,
463
+ )
465
464
 
466
- # Fetch kv directly from new kv.
467
- @pl.when(bkv_sz_frm_new > 0)
468
- def _fetch_bkv_from_new_kv():
469
- new_kv_len_start = q_end - kv_left_frm_new
470
- debug_print("[RPA debug] new_kv_len_start={}",
471
- new_kv_len_start)
472
- debug_print("[RPA debug] offset_in_bkv={}", offset)
473
- _async_copy(
474
- kv_hbm_ref.at[pl.ds(new_kv_len_start, bkv_sz_frm_new)],
475
- vmem_ref.at[pl.ds(offset, bkv_sz_frm_new)],
476
- sem,
477
- wait,
478
- )
479
-
480
- return kv_len_start + offset, bkv_sz_frm_new
481
- else:
482
- offset = jnp.minimum(kv_left_frm_cache, page_size * bkv_p)
483
- dst = vmem_ref.at[pl.ds(0, offset + bkv_sz_frm_new)]
465
+ # Fetch kv directly from new kv.
466
+ @pl.when(bkv_sz_frm_new > 0)
467
+ def _fetch_bkv_from_new_kv():
468
+ new_kv_len_start = q_end - kv_left_frm_new
469
+ debug_print("[RPA debug] new_kv_len_start={}", new_kv_len_start)
470
+ debug_print("[RPA debug] offset_in_bkv={}", offset)
484
471
  _async_copy(
485
- src=dst,
486
- dst=dst,
487
- sem=sem,
488
- wait=True,
472
+ kv_hbm_ref.at[pl.ds(new_kv_len_start, bkv_sz_frm_new)],
473
+ vmem_ref.at[pl.ds(offset, bkv_sz_frm_new)],
474
+ sem,
475
+ wait,
489
476
  )
490
- return kv_len_start + offset, bkv_sz_frm_new
477
+
478
+ return kv_len_start + offset, bkv_sz_frm_new
491
479
 
492
480
  def _update_kv_cache(seq_idx,
493
481
  bkv_sem_idx,
@@ -523,41 +511,30 @@ def _ragged_paged_attention_kernel(
523
511
  debug_print("[RPA debug] p_ignore={}", p_ignore)
524
512
  debug_print("[RPA debug] page_indices_offset={}", page_indices_offset)
525
513
 
526
- if not wait:
527
-
528
- def loop_body(i, states):
529
- update_sz, ignore = states
530
- sz = jnp.minimum(page_size - ignore, update_sz)
531
-
532
- _async_copy(
533
- vmem_ref.at[pl.ds((p_ignore + i) * page_size + ignore,
534
- sz)],
535
- cache_hbm_ref.at[pl.ds(
536
- page_indices_ref[page_indices_offset + i] * page_size +
537
- ignore,
538
- sz,
539
- )],
540
- sem,
541
- wait=False,
542
- )
543
- debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
544
- return update_sz - sz, 0
545
-
546
- lax.fori_loop(
547
- 0,
548
- kv_p_end - kv_p_start,
549
- loop_body,
550
- (update_sz, ignore), # total transfer size
551
- unroll=False,
552
- )
553
- else:
554
- dst = cache_hbm_ref.at[pl.ds(0, update_sz)]
514
+ def loop_body(i, states):
515
+ update_sz, ignore = states
516
+ sz = jnp.minimum(page_size - ignore, update_sz)
517
+
555
518
  _async_copy(
556
- src=dst,
557
- dst=dst,
558
- sem=sem,
559
- wait=True,
519
+ vmem_ref.at[pl.ds((p_ignore + i) * page_size + ignore, sz)],
520
+ cache_hbm_ref.at[pl.ds(
521
+ page_indices_ref[page_indices_offset + i] * page_size +
522
+ ignore,
523
+ sz,
524
+ )],
525
+ sem,
526
+ wait,
560
527
  )
528
+ debug_print("[RPA debug] loop_body i={}, sz={}", i, sz)
529
+ return update_sz - sz, 0
530
+
531
+ lax.fori_loop(
532
+ 0,
533
+ kv_p_end - kv_p_start,
534
+ loop_body,
535
+ (update_sz, ignore), # total transfer size
536
+ unroll=False,
537
+ )
561
538
 
562
539
  def _fetch_bq(seq_idx, bq_idx, bq_sem_idx, *, wait=False):
563
540
  sem = sems.at[1, bq_sem_idx]