tpu-inference 0.12.0.dev20251207__tar.gz → 0.12.0.dev20251219__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference has been flagged as possibly problematic.

Files changed (193)
  1. {tpu_inference-0.12.0.dev20251207/tpu_inference.egg-info → tpu_inference-0.12.0.dev20251219}/PKG-INFO +5 -7
  2. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/README.md +4 -6
  3. tpu_inference-0.12.0.dev20251219/tests/kernels/gmm_test.py +191 -0
  4. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tests/kernels/quantized_matmul_kernel_test.py +2 -34
  5. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +3 -1
  6. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tests/kernels/ragged_paged_attention_kernel_v3_test.py +3 -1
  7. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tests/lora/test_layers.py +7 -3
  8. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tests/lora/test_lora.py +1 -1
  9. tpu_inference-0.12.0.dev20251219/tests/lora/test_lora_perf.py +53 -0
  10. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tests/test_envs.py +78 -1
  11. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/distributed/tpu_connector.py +3 -3
  12. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/envs.py +38 -7
  13. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/executors/ray_distributed_executor.py +3 -0
  14. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  15. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  16. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/kernels/fused_moe/v1/kernel.py +357 -324
  17. tpu_inference-0.12.0.dev20251219/tpu_inference/kernels/megablox/common.py +41 -0
  18. tpu_inference-0.12.0.dev20251219/tpu_inference/kernels/megablox/gmm.py +633 -0
  19. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  20. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  21. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  22. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +181 -101
  23. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +82 -78
  24. tpu_inference-0.12.0.dev20251219/tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4447 -0
  25. tpu_inference-0.12.0.dev20251219/tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +535 -0
  26. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/kernels/ragged_paged_attention/v3/util.py +2 -1
  27. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/common/attention_interface.py +1 -7
  28. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/common/quant_methods.py +1 -0
  29. tpu_inference-0.12.0.dev20251219/tpu_inference/layers/jax/pp_utils.py +39 -0
  30. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/vllm/fused_moe.py +87 -67
  31. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/vllm/linear_common.py +43 -21
  32. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/vllm/quantization/__init__.py +2 -0
  33. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/vllm/quantization/awq.py +1 -1
  34. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/vllm/quantization/common.py +5 -5
  35. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -3
  36. tpu_inference-0.12.0.dev20251219/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +252 -0
  37. tpu_inference-0.12.0.dev20251219/tpu_inference/layers/vllm/quantization/fp8.py +104 -0
  38. tpu_inference-0.12.0.dev20251219/tpu_inference/layers/vllm/quantization/mxfp4.py +448 -0
  39. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/vllm/quantization/unquantized.py +83 -47
  40. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/lora/torch_lora_ops.py +8 -13
  41. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/models/common/model_loader.py +43 -18
  42. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/models/jax/llama3.py +79 -33
  43. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/models/jax/utils/weight_utils.py +19 -1
  44. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/models/vllm/vllm_model_wrapper.py +1 -1
  45. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/platforms/tpu_platform.py +8 -34
  46. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/runner/kv_cache.py +3 -1
  47. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/runner/tpu_runner.py +5 -5
  48. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/utils.py +2 -1
  49. tpu_inference-0.12.0.dev20251219/tpu_inference/worker/__init__.py +0 -0
  50. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/worker/tpu_worker.py +22 -36
  51. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219/tpu_inference.egg-info}/PKG-INFO +5 -7
  52. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference.egg-info/SOURCES.txt +7 -0
  53. tpu_inference-0.12.0.dev20251207/tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +0 -4147
  54. tpu_inference-0.12.0.dev20251207/tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +0 -367
  55. tpu_inference-0.12.0.dev20251207/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +0 -203
  56. tpu_inference-0.12.0.dev20251207/tpu_inference/layers/vllm/quantization/mxfp4.py +0 -331
  57. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/LICENSE +0 -0
  58. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/MANIFEST.in +0 -0
  59. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/pyproject.toml +0 -0
  60. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/requirements.txt +0 -0
  61. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/setup.cfg +0 -0
  62. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/setup.py +0 -0
  63. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tests/__init__.py +0 -0
  64. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tests/core/__init__.py +0 -0
  65. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tests/core/test_core_tpu.py +0 -0
  66. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tests/core/test_disagg_executor.py +0 -0
  67. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tests/core/test_disagg_utils.py +0 -0
  68. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tests/core/test_dp_scheduler.py +0 -0
  69. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tests/core/test_init.py +0 -0
  70. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tests/kernels/__init__.py +0 -0
  71. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tests/kernels/fused_moe_v1_test.py +0 -0
  72. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tests/kernels/mla_v1_test.py +0 -0
  73. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tests/kernels/ragged_kv_cache_update_v2_test.py +0 -0
  74. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tests/kernels/ragged_paged_attention_kernel_v2_test.py +0 -0
  75. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tests/lora/__init__.py +0 -0
  76. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tests/lora/conftest.py +0 -0
  77. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tests/lora/test_bgmv.py +0 -0
  78. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tests/lora/utils.py +0 -0
  79. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tests/test_base.py +0 -0
  80. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tests/test_quantization.py +0 -0
  81. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tests/test_tpu_info.py +0 -0
  82. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tests/test_utils.py +0 -0
  83. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/__init__.py +0 -0
  84. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/core/__init__.py +0 -0
  85. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/core/core_tpu.py +0 -0
  86. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/core/disagg_executor.py +0 -0
  87. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/core/disagg_utils.py +0 -0
  88. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/core/sched/__init__.py +0 -0
  89. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/core/sched/dp_scheduler.py +0 -0
  90. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/distributed/__init__.py +0 -0
  91. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/distributed/jax_parallel_state.py +0 -0
  92. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/distributed/utils.py +0 -0
  93. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/env_override.py +0 -0
  94. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/executors/__init__.py +0 -0
  95. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/experimental/__init__.py +0 -0
  96. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/experimental/llama3_jax_stashed.py +0 -0
  97. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/kernels/__init__.py +0 -0
  98. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/kernels/collectives/__init__.py +0 -0
  99. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/kernels/collectives/util.py +0 -0
  100. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/kernels/flash_attention/__init__.py +0 -0
  101. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/kernels/flash_attention/kernel.py +0 -0
  102. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/kernels/fused_moe/__init__.py +0 -0
  103. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
  104. {tpu_inference-0.12.0.dev20251207/tpu_inference/kernels/mla → tpu_inference-0.12.0.dev20251219/tpu_inference/kernels/megablox}/__init__.py +0 -0
  105. {tpu_inference-0.12.0.dev20251207/tpu_inference/kernels/mla/v1 → tpu_inference-0.12.0.dev20251219/tpu_inference/kernels/mla}/__init__.py +0 -0
  106. {tpu_inference-0.12.0.dev20251207/tpu_inference/kernels/quantized_matmul → tpu_inference-0.12.0.dev20251219/tpu_inference/kernels/mla/v1}/__init__.py +0 -0
  107. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/kernels/mla/v1/kernel.py +0 -0
  108. {tpu_inference-0.12.0.dev20251207/tpu_inference/kernels/ragged_paged_attention → tpu_inference-0.12.0.dev20251219/tpu_inference/kernels/quantized_matmul}/__init__.py +0 -0
  109. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +0 -0
  110. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/kernels/quantized_matmul/util.py +0 -0
  111. {tpu_inference-0.12.0.dev20251207/tpu_inference/kernels/ragged_paged_attention/v2 → tpu_inference-0.12.0.dev20251219/tpu_inference/kernels/ragged_paged_attention}/__init__.py +0 -0
  112. {tpu_inference-0.12.0.dev20251207/tpu_inference/kernels/ragged_paged_attention/v3 → tpu_inference-0.12.0.dev20251219/tpu_inference/kernels/ragged_paged_attention/v2}/__init__.py +0 -0
  113. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +0 -0
  114. {tpu_inference-0.12.0.dev20251207/tpu_inference/layers → tpu_inference-0.12.0.dev20251219/tpu_inference/kernels/ragged_paged_attention/v3}/__init__.py +0 -0
  115. {tpu_inference-0.12.0.dev20251207/tpu_inference/layers/common → tpu_inference-0.12.0.dev20251219/tpu_inference/layers}/__init__.py +0 -0
  116. {tpu_inference-0.12.0.dev20251207/tpu_inference/layers/jax → tpu_inference-0.12.0.dev20251219/tpu_inference/layers/common}/__init__.py +0 -0
  117. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/common/attention_metadata.py +0 -0
  118. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/common/binary_search.py +0 -0
  119. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/common/sharding.py +0 -0
  120. {tpu_inference-0.12.0.dev20251207/tpu_inference/layers/jax/attention → tpu_inference-0.12.0.dev20251219/tpu_inference/layers/jax}/__init__.py +0 -0
  121. {tpu_inference-0.12.0.dev20251207/tpu_inference/layers/jax/moe → tpu_inference-0.12.0.dev20251219/tpu_inference/layers/jax/attention}/__init__.py +0 -0
  122. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/jax/attention/attention.py +0 -0
  123. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/jax/attention/deepseek_v3_attention.py +0 -0
  124. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/jax/attention/gpt_oss_attention.py +0 -0
  125. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/jax/attention/llama4_attention.py +0 -0
  126. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/jax/base.py +0 -0
  127. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/jax/constants.py +0 -0
  128. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/jax/layers.py +0 -0
  129. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/jax/misc.py +0 -0
  130. {tpu_inference-0.12.0.dev20251207/tpu_inference/layers/jax/sample → tpu_inference-0.12.0.dev20251219/tpu_inference/layers/jax/moe}/__init__.py +0 -0
  131. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/jax/moe/deepseek_v3_moe.py +0 -0
  132. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/jax/moe/gpt_oss_moe.py +0 -0
  133. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/jax/moe/moe.py +0 -0
  134. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/jax/rope.py +0 -0
  135. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/jax/rope_interface.py +0 -0
  136. {tpu_inference-0.12.0.dev20251207/tpu_inference/layers/vllm → tpu_inference-0.12.0.dev20251219/tpu_inference/layers/jax/sample}/__init__.py +0 -0
  137. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/jax/sample/rejection_sampler.py +0 -0
  138. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/jax/sample/sampling.py +0 -0
  139. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/jax/sample/sampling_metadata.py +0 -0
  140. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/jax/transformer_block.py +0 -0
  141. {tpu_inference-0.12.0.dev20251207/tpu_inference/layers/vllm/quantization/compressed_tensors → tpu_inference-0.12.0.dev20251219/tpu_inference/layers/vllm}/__init__.py +0 -0
  142. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/vllm/attention.py +0 -0
  143. {tpu_inference-0.12.0.dev20251207/tpu_inference/layers/vllm/quantization/compressed_tensors/schemes → tpu_inference-0.12.0.dev20251219/tpu_inference/layers/vllm/quantization/compressed_tensors}/__init__.py +0 -0
  144. {tpu_inference-0.12.0.dev20251207/tpu_inference/lora → tpu_inference-0.12.0.dev20251219/tpu_inference/layers/vllm/quantization/compressed_tensors/schemes}/__init__.py +0 -0
  145. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +0 -0
  146. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +0 -0
  147. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/vllm/sharding.py +0 -0
  148. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/logger.py +0 -0
  149. {tpu_inference-0.12.0.dev20251207/tpu_inference/models → tpu_inference-0.12.0.dev20251219/tpu_inference/lora}/__init__.py +0 -0
  150. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/lora/torch_punica_tpu.py +0 -0
  151. {tpu_inference-0.12.0.dev20251207/tpu_inference/models/common → tpu_inference-0.12.0.dev20251219/tpu_inference/models}/__init__.py +0 -0
  152. {tpu_inference-0.12.0.dev20251207/tpu_inference/models/jax → tpu_inference-0.12.0.dev20251219/tpu_inference/models/common}/__init__.py +0 -0
  153. {tpu_inference-0.12.0.dev20251207/tpu_inference/models/jax/utils → tpu_inference-0.12.0.dev20251219/tpu_inference/models/jax}/__init__.py +0 -0
  154. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/models/jax/deepseek_v3.py +0 -0
  155. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/models/jax/gpt_oss.py +0 -0
  156. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/models/jax/jax_intermediate_tensor.py +0 -0
  157. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/models/jax/llama4.py +0 -0
  158. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/models/jax/llama_eagle3.py +0 -0
  159. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/models/jax/llama_guard_4.py +0 -0
  160. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/models/jax/qwen2.py +0 -0
  161. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/models/jax/qwen2_5_vl.py +0 -0
  162. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/models/jax/qwen3.py +0 -0
  163. {tpu_inference-0.12.0.dev20251207/tpu_inference/models/jax/utils/quantization → tpu_inference-0.12.0.dev20251219/tpu_inference/models/jax/utils}/__init__.py +0 -0
  164. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/models/jax/utils/file_utils.py +0 -0
  165. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/models/jax/utils/multi_modal_utils.py +0 -0
  166. {tpu_inference-0.12.0.dev20251207/tpu_inference/models/vllm → tpu_inference-0.12.0.dev20251219/tpu_inference/models/jax/utils/quantization}/__init__.py +0 -0
  167. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -0
  168. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -0
  169. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -0
  170. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -0
  171. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -0
  172. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/models/jax/utils/quantization/quantization_utils.py +0 -0
  173. {tpu_inference-0.12.0.dev20251207/tpu_inference/runner → tpu_inference-0.12.0.dev20251219/tpu_inference/models/vllm}/__init__.py +0 -0
  174. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/models/vllm/vllm_model_wrapper_context.py +0 -0
  175. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/platforms/__init__.py +0 -0
  176. {tpu_inference-0.12.0.dev20251207/tpu_inference/spec_decode → tpu_inference-0.12.0.dev20251219/tpu_inference/runner}/__init__.py +0 -0
  177. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/runner/block_table.py +0 -0
  178. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/runner/compilation_manager.py +0 -0
  179. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/runner/input_batch.py +0 -0
  180. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/runner/kv_cache_manager.py +0 -0
  181. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/runner/lora_utils.py +0 -0
  182. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/runner/multimodal_manager.py +0 -0
  183. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/runner/persistent_batch_manager.py +0 -0
  184. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/runner/speculative_decoding_manager.py +0 -0
  185. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/runner/structured_decoding_manager.py +0 -0
  186. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/runner/utils.py +0 -0
  187. {tpu_inference-0.12.0.dev20251207/tpu_inference/spec_decode/jax → tpu_inference-0.12.0.dev20251219/tpu_inference/spec_decode}/__init__.py +0 -0
  188. {tpu_inference-0.12.0.dev20251207/tpu_inference/worker → tpu_inference-0.12.0.dev20251219/tpu_inference/spec_decode/jax}/__init__.py +0 -0
  189. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/spec_decode/jax/eagle3.py +0 -0
  190. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/tpu_info.py +0 -0
  191. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference.egg-info/dependency_links.txt +0 -0
  192. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference.egg-info/requires.txt +0 -0
  193. {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: tpu_inference
- Version: 0.12.0.dev20251207
+ Version: 0.12.0.dev20251219
  Author: tpu_inference Contributors
  Classifier: Development Status :: 3 - Alpha
  Classifier: Intended Audience :: Developers
@@ -53,14 +53,12 @@ Dynamic: requires-python

  ---

- _Upcoming Events_ 🔥
-
- - Join us at the [PyTorch Conference, October 22-23](https://events.linuxfoundation.org/pytorch-conference/) in San Francisco!
- - Join us at [Ray Summit, November 3-5](https://www.anyscale.com/ray-summit/2025) in San Francisco!
- - Join us at [JAX DevLab on November 18th](https://rsvp.withgoogle.com/events/devlab-fall-2025) in Sunnyvale!
-
  _Latest News_ 🔥

+ - [Pytorch Conference](https://pytorchconference.sched.com/event/27QCh/sponsored-session-everything-everywhere-all-at-once-vllm-hardware-optionality-with-spotify-and-google-brittany-rockwell-google-shireen-kheradpey-spotify) Learn how Spotify uses vLLM with both GPUs and TPUs to drive down costs and improve user experience.
+ - Check back soon for a recording of our session at [Ray Summit, November 3-5](https://www.anyscale.com/ray-summit/2025) in San Francisco!
+ - Check back soon for a recording of our session at [JAX DevLab on November 18th](https://rsvp.withgoogle.com/events/devlab-fall-2025) in Sunnyvale!
+
  - [2025/10] [vLLM TPU: A New Unified Backend Supporting PyTorch and JAX on TPU](https://blog.vllm.ai/2025/10/16/vllm-tpu.html)

  <details>
@@ -11,14 +11,12 @@

  ---

- _Upcoming Events_ 🔥
-
- - Join us at the [PyTorch Conference, October 22-23](https://events.linuxfoundation.org/pytorch-conference/) in San Francisco!
- - Join us at [Ray Summit, November 3-5](https://www.anyscale.com/ray-summit/2025) in San Francisco!
- - Join us at [JAX DevLab on November 18th](https://rsvp.withgoogle.com/events/devlab-fall-2025) in Sunnyvale!
-
  _Latest News_ 🔥

+ - [Pytorch Conference](https://pytorchconference.sched.com/event/27QCh/sponsored-session-everything-everywhere-all-at-once-vllm-hardware-optionality-with-spotify-and-google-brittany-rockwell-google-shireen-kheradpey-spotify) Learn how Spotify uses vLLM with both GPUs and TPUs to drive down costs and improve user experience.
+ - Check back soon for a recording of our session at [Ray Summit, November 3-5](https://www.anyscale.com/ray-summit/2025) in San Francisco!
+ - Check back soon for a recording of our session at [JAX DevLab on November 18th](https://rsvp.withgoogle.com/events/devlab-fall-2025) in Sunnyvale!
+
  - [2025/10] [vLLM TPU: A New Unified Backend Supporting PyTorch and JAX on TPU](https://blog.vllm.ai/2025/10/16/vllm-tpu.html)

  <details>
@@ -0,0 +1,191 @@
+ import jax
+ import jax.numpy as jnp
+ from absl.testing import absltest, parameterized
+ from jax._src import test_util as jtu
+
+ from tpu_inference.kernels.megablox.gmm import gmm
+
+ jax.config.parse_flags_with_absl()
+
+
+ def quantize_tensor(x: jax.Array,
+                     dtype: jnp.dtype,
+                     axis: int = -1,
+                     block_size: int = 256):
+     if jnp.issubdtype(dtype, jnp.integer):
+         dtype_info = jnp.iinfo(dtype)
+         max_val = int(dtype_info.max)
+         min_val = int(dtype_info.min)
+     else:
+         dtype_info = jnp.finfo(dtype)
+         max_val = float(dtype_info.max)
+         min_val = float(dtype_info.min)
+
+     orig_shape = x.shape
+     blocked_shape = orig_shape[:axis] + (-1,
+                                          block_size) + orig_shape[axis + 1:]
+     x_blocked = x.reshape(blocked_shape)
+
+     x_blocked_abs_max = jnp.max(jnp.abs(x_blocked),
+                                 axis=axis + 1,
+                                 keepdims=True)
+     scale = x_blocked_abs_max / max_val
+     x_blocked_q = jnp.clip(x_blocked / scale, min_val, max_val).astype(dtype)
+
+     x_q = x_blocked_q.reshape(orig_shape)
+     scale = scale.squeeze(axis=axis + 1).astype(jnp.float32)
+     return x_q, scale
+
+
+ def reference_gmm(
+     lhs: jax.Array,
+     rhs: jax.Array,
+     group_sizes: jax.Array,
+     rhs_scale: jax.Array | None = None,
+     rhs_bias: jax.Array | None = None,
+     group_offset: jax.Array | None = None,
+ ):
+     num_groups, out_size, in_size = rhs.shape
+     assert lhs.shape[1] == in_size
+
+     if group_offset is None:
+         group_offset = jnp.array(0, dtype=jnp.int32)
+     start = group_sizes[:group_offset].sum()
+     group_sizes = group_sizes[group_offset:]
+     assert len(group_sizes) == num_groups
+
+     if rhs_scale is not None:
+         num_blocks = rhs_scale.shape[1]
+     else:
+         num_blocks = 1
+     block_size = in_size // num_blocks
+
+     gmm_out = [jnp.zeros((start, out_size), lhs.dtype)]
+     for group in range(num_groups):
+         end = start + group_sizes[group]
+
+         lhs_slice = lhs[start:end]
+         rhs_slice = rhs[group]
+
+         out = 0
+         for block in range(num_blocks):
+             block_start = block * block_size
+             block_end = block_start + block_size
+             lhs_block = lhs_slice[:, block_start:block_end].astype(jnp.float32)
+             rhs_block = rhs_slice[:, block_start:block_end].astype(jnp.float32)
+
+             acc = jnp.einsum("bd,hd->bh", lhs_block, rhs_block)
+             if rhs_scale is not None:
+                 acc *= rhs_scale[group][block]
+             out += acc
+         if rhs_bias is not None:
+             out = out + rhs_bias[group]
+
+         gmm_out.append(out.astype(lhs.dtype))
+         start = end
+
+     return jnp.concat(gmm_out, axis=0)
+
+
+ @jtu.with_config(jax_numpy_dtype_promotion="standard")
+ class GmmTest(jtu.JaxTestCase):
+
+     @parameterized.product(
+         batch_size=[128],
+         in_size=[1024],
+         out_size=[1024],
+         num_groups=[16, 32],
+         has_bias=[True, False],
+     )
+     def test_gmm(self, batch_size, in_size, out_size, num_groups, has_bias):
+         key = jax.random.key(0)
+
+         lhs = jax.random.normal(key, (batch_size, in_size), dtype=jnp.bfloat16)
+         rhs = jax.random.normal(key, (num_groups, out_size, in_size),
+                                 dtype=jnp.bfloat16)
+         rhs_bias = None
+         if has_bias:
+             rhs_bias = jax.random.normal(key, (num_groups, 1, out_size),
+                                          dtype=jnp.bfloat16)
+
+         group_sizes = jax.random.randint(key, (num_groups, ),
+                                          0,
+                                          batch_size,
+                                          dtype=jnp.int32)
+
+         expected = reference_gmm(lhs, rhs, group_sizes, rhs_bias=rhs_bias)
+         actual = gmm(
+             lhs,
+             rhs,
+             group_sizes,
+             rhs_bias=rhs_bias,
+             transpose_rhs=True,
+             preferred_element_type=jnp.bfloat16,
+         )
+
+         self.assertArraysAllClose(actual, expected)
+
+     @parameterized.product(
+         batch_size=[128],
+         in_size=[1024],
+         out_size=[1024],
+         num_groups=[16, 32],
+         has_bias=[True, False],
+         weight_dtype=[jnp.int8, jnp.float8_e5m2, jnp.float4_e2m1fn],
+         block_size=[256, 512],
+     )
+     def test_gmm_weight_quantized(
+         self,
+         batch_size,
+         in_size,
+         out_size,
+         num_groups,
+         has_bias,
+         weight_dtype,
+         block_size,
+     ):
+         if weight_dtype == jnp.float4_e2m1fn and not jtu.is_device_tpu_at_least(
+                 version=7):
+             self.skipTest("Expect TPUv7+")
+         key = jax.random.key(0)
+
+         lhs = jax.random.normal(key, (batch_size, in_size), dtype=jnp.bfloat16)
+         rhs = jax.random.normal(key, (num_groups, out_size, in_size),
+                                 dtype=jnp.bfloat16)
+         rhs_q, rhs_scale = quantize_tensor(rhs,
+                                            weight_dtype,
+                                            axis=2,
+                                            block_size=block_size)
+         rhs_scale = jnp.swapaxes(rhs_scale, 1, 2)
+         rhs_scale = jnp.expand_dims(rhs_scale, axis=2)
+
+         rhs_bias = None
+         if has_bias:
+             rhs_bias = jax.random.normal(key, (num_groups, 1, out_size),
+                                          dtype=jnp.bfloat16)
+
+         group_sizes = jax.random.randint(key, (num_groups, ),
+                                          0,
+                                          batch_size,
+                                          dtype=jnp.int32)
+
+         expected = reference_gmm(lhs,
+                                  rhs_q,
+                                  group_sizes,
+                                  rhs_scale=rhs_scale,
+                                  rhs_bias=rhs_bias)
+         actual = gmm(
+             lhs,
+             rhs_q,
+             group_sizes,
+             rhs_scale=rhs_scale,
+             rhs_bias=rhs_bias,
+             transpose_rhs=True,
+             preferred_element_type=jnp.bfloat16,
+         )
+
+         self.assertArraysAllClose(actual, expected, atol=3e-1, rtol=3e-1)
+
+
+ if __name__ == "__main__":
+     absltest.main(testLoader=jtu.JaxTestLoader())
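Note on the new test above: the weight-quantized case hinges on the layout of the per-block scales handed to the kernel. The following shape walk-through is a sketch derived from the test code (it reuses the test's own quantize_tensor helper; it is not an excerpt from the package):

    import jax
    import jax.numpy as jnp

    num_groups, out_size, in_size, block_size = 16, 1024, 1024, 256
    rhs = jax.random.normal(jax.random.key(0), (num_groups, out_size, in_size),
                            dtype=jnp.bfloat16)

    # quantize_tensor blocks the contraction axis (axis=2), so the scale comes
    # back as (num_groups, out_size, num_blocks), num_blocks = in_size // block_size.
    rhs_q, rhs_scale = quantize_tensor(rhs, jnp.int8, axis=2, block_size=block_size)
    assert rhs_scale.shape == (num_groups, out_size, in_size // block_size)

    # The test then moves the block axis ahead of out_size and adds a broadcast
    # axis, giving (num_groups, num_blocks, 1, out_size); reference_gmm indexes it
    # as rhs_scale[group][block], which broadcasts against a (rows, out_size) tile.
    rhs_scale = jnp.expand_dims(jnp.swapaxes(rhs_scale, 1, 2), axis=2)
    assert rhs_scale.shape == (num_groups, in_size // block_size, 1, out_size)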
@@ -1,7 +1,5 @@
  # SPDX-License-Identifier: Apache-2.0

- import functools
-
  import jax
  import jax.numpy as jnp
  from absl.testing import absltest, parameterized
@@ -10,6 +8,7 @@ from jax._src import test_util as jtu
  from tpu_inference.kernels.quantized_matmul import (kernel, tuned_block_sizes,
                                                      util)

+ xla_quantized_matmul = kernel.xla_quantized_matmul
  quantized_matmul_kernel = kernel.quantized_matmul_kernel
  quantize_tensor = util.quantize_tensor
  get_tuned_block_sizes = tuned_block_sizes.get_tuned_block_sizes
@@ -17,37 +16,6 @@ get_tuned_block_sizes = tuned_block_sizes.get_tuned_block_sizes
  jax.config.parse_flags_with_absl()


- @functools.partial(jax.jit, static_argnames=["quantize_activation"])
- def reference_quantized_matmul(
-     x: jax.Array,
-     w_q: jax.Array,
-     w_scale: jax.Array,
-     quantize_activation=True,
- ):
-     if quantize_activation:
-         acc_dtype = jnp.float32
-         if quantize_activation and jnp.issubdtype(w_q.dtype, jnp.integer):
-             acc_dtype = jnp.int32
-
-         x_q, x_scale = quantize_tensor(x, w_q.dtype)
-         out = jax.lax.dot_general(
-             x_q,
-             w_q,
-             dimension_numbers=(((1, ), (1, )), ((), ())),
-             preferred_element_type=acc_dtype,
-         ).astype(jnp.float32)
-         out *= x_scale
-     else:
-         out = jax.lax.dot_general(
-             x,
-             w_q,
-             dimension_numbers=(((1, ), (1, )), ((), ())),
-             preferred_element_type=jnp.float32,
-         )
-     out *= jnp.expand_dims(w_scale, 0)
-     return out.astype(x.dtype)
-
-
  @jtu.with_config(jax_numpy_dtype_promotion="standard")
  class QuantizedMatmulKernelTest(jtu.JaxTestCase):

@@ -94,7 +62,7 @@ class QuantizedMatmulKernelTest(jtu.JaxTestCase):
              x_q_dtype=x_q_dtype,
              tuned_value=tuned_value,
          )
-         expected = reference_quantized_matmul(
+         expected = xla_quantized_matmul(
              x, w_q, w_scale, quantize_activation=quantize_activation)

          self.assertAllClose(output,
@@ -176,7 +176,9 @@ class RaggedPagedAttentionHeadDim64KernelTest(jtu.JaxTestCase):
          )
          output = output[:cu_q_lens[distribution[-1]]]

-         dtype_bits = dtypes.bit_width(jnp.dtype(kv_dtype))
+         dtype_bits = (dtypes.bit_width(jnp.dtype(kv_dtype)) if hasattr(
+             dtypes, "bit_width") else dtypes.itemsize_bits(
+                 jnp.dtype(kv_dtype)))
          tols = {
              32: 0.15,
              16: 0.2,
@@ -162,7 +162,9 @@ class RaggedPagedAttentionKernelTest(jtu.JaxTestCase):
          )
          output = output[:cu_q_lens[distribution[-1]]]

-         dtype_bits = dtypes.bit_width(jnp.dtype(kv_dtype))
+         dtype_bits = (dtypes.bit_width(jnp.dtype(kv_dtype)) if hasattr(
+             dtypes, "bit_width") else dtypes.itemsize_bits(
+                 jnp.dtype(kv_dtype)))
          tols = {
              32: 0.15,
              16: 0.2,
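The `hasattr(dtypes, "bit_width")` fallback in the two attention-test hunks above (and again in the all_gather_matmul kernel further down) guards against JAX builds that expose `itemsize_bits` instead of `bit_width` on `jax._src.dtypes`. A standalone sketch of the same pattern, with the two attribute names taken from the diff itself rather than verified against any particular JAX release:

    import jax.numpy as jnp
    from jax._src import dtypes  # private JAX module, as used by these tests


    def dtype_bit_width(dtype) -> int:
        """Bits per element, tolerating either spelling of the JAX helper."""
        dt = jnp.dtype(dtype)
        if hasattr(dtypes, "bit_width"):
            return dtypes.bit_width(dt)
        return dtypes.itemsize_bits(dt)


    # e.g. dtype_bit_width(jnp.bfloat16) == 16, dtype_bit_width(jnp.float8_e5m2) == 8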
@@ -18,7 +18,7 @@ from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
                                ReplicatedLinearWithLoRA,
                                RowParallelLinearWithLoRA)
  # yapf: enable
- from vllm.lora.models import LoRALayerWeights, PackedLoRALayerWeights
+ from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
  from vllm.lora.punica_wrapper import get_punica_wrapper
  from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                 MergedColumnParallelLinear,
@@ -499,9 +499,13 @@ def _create_random_linear_parallel_layer(layer_type, vllm_config, mesh):
      return linear, lora_linear


+ def _get_devices():
+     return jax.devices()
+
+
  def _create_mesh():
      axis_names = ("data", "model")
-     devices = jax.devices()
+     devices = _get_devices()
      mesh_shape = (1, len(devices))
      mesh = jax.make_mesh(mesh_shape, axis_names, devices=devices)
      return mesh
@@ -513,7 +517,7 @@ def _verify_lora_linear_layer(linear, lora_linear):
      # BaseLinearLayerWithLoRA.weight property guarantees this.
      # if len(devices) != 1, `reorder_concatenated_tensor_for_sharding` function may reorder the out_features dimension of the weight matrix.
      # So the below check will fail.
-     if len(jax.devices()) == 1:
+     if len(_get_devices()) == 1:
          assert torch.equal(linear.weight.data,
                             lora_linear.weight.to('cpu'))

@@ -29,7 +29,7 @@ def setup_vllm(num_loras: int, tp: int = 1) -> vllm.LLM:


  # For multi-chip test, we only use TP=2 because the base model Qwen/Qwen2.5-3B-Instruct has 2 kv heads and the current attention kernel requires it to be divisible by tp_size.
- TP = [2] if os.environ.get("USE_V6E8_QUEUE", False) else [1]
+ TP = [2] if os.environ.get("TEST_LORA_TP", False) else [1]


  @pytest.mark.parametrize("tp", TP)
@@ -0,0 +1,53 @@
+ import os
+ import time
+
+ import pytest
+ import vllm
+ from vllm.lora.request import LoRARequest
+
+ TP = [2] if os.environ.get("USE_V6E8_QUEUE", False) else [1]
+
+
+ @pytest.mark.parametrize("tp", TP)
+ def test_lora_performance(tp):
+     prompt = "What is 1+1? \n"
+     llm_without_lora = vllm.LLM(
+         model="Qwen/Qwen2.5-3B-Instruct",
+         max_model_len=256,
+         max_num_batched_tokens=64,
+         max_num_seqs=8,
+         tensor_parallel_size=tp,
+     )
+     start_time = time.time()
+     llm_without_lora.generate(
+         prompt,
+         sampling_params=vllm.SamplingParams(max_tokens=16, temperature=0),
+     )[0].outputs[0].text
+     base_time = time.time() - start_time
+
+     del llm_without_lora
+     # Waiting for TPUs to be released
+     time.sleep(10)
+
+     llm_with_lora = vllm.LLM(model="Qwen/Qwen2.5-3B-Instruct",
+                              max_model_len=256,
+                              max_num_batched_tokens=64,
+                              max_num_seqs=8,
+                              tensor_parallel_size=tp,
+                              enable_lora=True,
+                              max_loras=1,
+                              max_lora_rank=8)
+     lora_request = LoRARequest(
+         "lora_adapter_2", 2,
+         "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_2_adapter")
+     start_time = time.time()
+     llm_with_lora.generate(prompt,
+                            sampling_params=vllm.SamplingParams(max_tokens=16,
+                                                                temperature=0),
+                            lora_request=lora_request)[0].outputs[0].text
+     lora_time = time.time() - start_time
+     print(f"Base time: {base_time}, LoRA time: {lora_time}")
+     assert (base_time /
+             lora_time) < 8, f"Base time: {base_time}, LoRA time: {lora_time}"
+
+     del llm_with_lora
@@ -60,6 +60,7 @@ def test_boolean_env_vars(monkeypatch: pytest.MonkeyPatch):
      monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "0")
      monkeypatch.setenv("VLLM_XLA_CHECK_RECOMPILATION", "0")
      monkeypatch.setenv("NEW_MODEL_DESIGN", "0")
+     monkeypatch.setenv("ENABLE_QUANTIZED_MATMUL_KERNEL", "0")
      monkeypatch.setenv("USE_MOE_EP_KERNEL", "0")

      # Test SKIP_JAX_PRECOMPILE (default False)
@@ -86,6 +87,82 @@ def test_boolean_env_vars(monkeypatch: pytest.MonkeyPatch):
      monkeypatch.setenv("USE_MOE_EP_KERNEL", "1")
      assert envs.USE_MOE_EP_KERNEL is True

+     # Test ENABLE_QUANTIZED_MATMUL_KERNEL (default False)
+     assert envs.ENABLE_QUANTIZED_MATMUL_KERNEL is False
+     monkeypatch.setenv("ENABLE_QUANTIZED_MATMUL_KERNEL", "1")
+     assert envs.ENABLE_QUANTIZED_MATMUL_KERNEL is True
+
+
+ def test_boolean_env_vars_string_values(monkeypatch: pytest.MonkeyPatch):
+     """Test that boolean env vars accept string values like 'True' and 'False'"""
+
+     # Test NEW_MODEL_DESIGN with string "True"
+     monkeypatch.setenv("NEW_MODEL_DESIGN", "True")
+     assert envs.NEW_MODEL_DESIGN is True
+
+     monkeypatch.setenv("NEW_MODEL_DESIGN", "true")
+     assert envs.NEW_MODEL_DESIGN is True
+
+     monkeypatch.setenv("NEW_MODEL_DESIGN", "False")
+     assert envs.NEW_MODEL_DESIGN is False
+
+     monkeypatch.setenv("NEW_MODEL_DESIGN", "false")
+     assert envs.NEW_MODEL_DESIGN is False
+
+     # Test SKIP_JAX_PRECOMPILE with string values
+     monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "True")
+     assert envs.SKIP_JAX_PRECOMPILE is True
+
+     monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "false")
+     assert envs.SKIP_JAX_PRECOMPILE is False
+
+     # Test VLLM_XLA_CHECK_RECOMPILATION with string values
+     monkeypatch.setenv("VLLM_XLA_CHECK_RECOMPILATION", "TRUE")
+     assert envs.VLLM_XLA_CHECK_RECOMPILATION is True
+
+     monkeypatch.setenv("VLLM_XLA_CHECK_RECOMPILATION", "FALSE")
+     assert envs.VLLM_XLA_CHECK_RECOMPILATION is False
+
+     # Test USE_MOE_EP_KERNEL with string values
+     monkeypatch.setenv("USE_MOE_EP_KERNEL", "true")
+     assert envs.USE_MOE_EP_KERNEL is True
+
+     monkeypatch.setenv("USE_MOE_EP_KERNEL", "False")
+     assert envs.USE_MOE_EP_KERNEL is False
+
+
+ def test_boolean_env_vars_invalid_values(monkeypatch: pytest.MonkeyPatch):
+     """Test that boolean env vars raise errors for invalid values"""
+
+     # Test invalid value for NEW_MODEL_DESIGN
+     monkeypatch.setenv("NEW_MODEL_DESIGN", "yes")
+     with pytest.raises(
+             ValueError,
+             match="Invalid boolean value 'yes' for NEW_MODEL_DESIGN"):
+         _ = envs.NEW_MODEL_DESIGN
+
+     monkeypatch.setenv("NEW_MODEL_DESIGN", "2")
+     with pytest.raises(ValueError,
+                        match="Invalid boolean value '2' for NEW_MODEL_DESIGN"):
+         _ = envs.NEW_MODEL_DESIGN
+
+     # Test invalid value for SKIP_JAX_PRECOMPILE
+     monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "invalid")
+     with pytest.raises(
+             ValueError,
+             match="Invalid boolean value 'invalid' for SKIP_JAX_PRECOMPILE"):
+         _ = envs.SKIP_JAX_PRECOMPILE
+
+
+ def test_boolean_env_vars_empty_string(monkeypatch: pytest.MonkeyPatch):
+     """Test that empty string returns default value"""
+
+     monkeypatch.setenv("NEW_MODEL_DESIGN", "")
+     assert envs.NEW_MODEL_DESIGN is False  # Should return default
+
+     monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "")
+     assert envs.SKIP_JAX_PRECOMPILE is False  # Should return default
+

  def test_integer_env_vars(monkeypatch: pytest.MonkeyPatch):
      # Ensure clean environment for integer vars by setting to defaults
@@ -179,7 +256,7 @@ def test_disaggregated_serving_env_vars(monkeypatch: pytest.MonkeyPatch):

  def test_model_impl_type_default(monkeypatch: pytest.MonkeyPatch):
      monkeypatch.delenv("MODEL_IMPL_TYPE", raising=False)
-     assert envs.MODEL_IMPL_TYPE == "flax_nnx"
+     assert envs.MODEL_IMPL_TYPE == "auto"


  def test_cache_preserves_values_across_env_changes(
@@ -694,9 +694,9 @@ class TPUConnectorWorker:

  def get_uuid() -> int:
      int128 = uuid4().int
-     # Must be 64-bit int, otherwise vllm output encoder would raise error.
-     int64 = int128 >> 64
-     return int64
+     # Must be less than 64-bit int, otherwise vllm output encoder would raise error.
+     # use 50 bit to avoid GO trunk the int when doing JSon serialization
+     return int128 >> 78


  @jax.jit
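A quick check of the shift in the get_uuid change above (illustrative arithmetic only, not package code): a UUID4 carries 128 bits, so shifting right by 78 keeps the top 50 bits, which stays below the 2**53 threshold where integers stop round-tripping exactly through an IEEE-754 double during JSON serialization.

    from uuid import uuid4

    int128 = uuid4().int        # up to 128 bits
    uid = int128 >> 78          # keep the top 128 - 78 = 50 bits
    assert uid < 2**50 < 2**53  # integers up to 2**53 survive a double round-trip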
@@ -16,7 +16,7 @@ if TYPE_CHECKING:
      DECODE_SLICES: str = ""
      SKIP_JAX_PRECOMPILE: bool = False
      VLLM_XLA_CHECK_RECOMPILATION: bool = False
-     MODEL_IMPL_TYPE: str = "flax_nnx"
+     MODEL_IMPL_TYPE: str = "auto"
      NEW_MODEL_DESIGN: bool = False
      PHASED_PROFILING_DIR: str = ""
      PYTHON_TRACER_LEVEL: int = 1
@@ -24,6 +24,7 @@ if TYPE_CHECKING:
      NUM_SLICES: int = 1
      RAY_USAGE_STATS_ENABLED: str = "0"
      VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: str = "shm"
+     ENABLE_QUANTIZED_MATMUL_KERNEL: bool = False


  def env_with_choices(
@@ -69,6 +70,34 @@ def env_with_choices(
      return _get_validated_env


+ def env_bool(env_name: str, default: bool = False) -> Callable[[], bool]:
+     """
+     Accepts both numeric strings ("0", "1") and boolean strings
+     ("true", "false", "True", "False").
+
+     Args:
+         env_name: Name of the environment variable
+         default: Default boolean value if not set
+     """
+
+     def _get_bool_env() -> bool:
+         value = os.getenv(env_name)
+         if value is None or value == "":
+             return default
+
+         value_lower = value.lower()
+         if value_lower in ("true", "1"):
+             return True
+         elif value_lower in ("false", "0"):
+             return False
+         else:
+             raise ValueError(
+                 f"Invalid boolean value '{value}' for {env_name}. "
+                 f"Valid options: '0', '1', 'true', 'false', 'True', 'False'.")
+
+     return _get_bool_env
+
+
  environment_variables: dict[str, Callable[[], Any]] = {
      # JAX platform selection (e.g., "tpu", "cpu", "proxy")
      "JAX_PLATFORMS":
@@ -93,17 +122,17 @@ environment_variables: dict[str, Callable[[], Any]] = {
      lambda: os.getenv("DECODE_SLICES", ""),
      # Skip JAX precompilation step during initialization
      "SKIP_JAX_PRECOMPILE":
-     lambda: bool(int(os.getenv("SKIP_JAX_PRECOMPILE") or "0")),
+     env_bool("SKIP_JAX_PRECOMPILE", default=False),
      # Check for XLA recompilation during execution
      "VLLM_XLA_CHECK_RECOMPILATION":
-     lambda: bool(int(os.getenv("VLLM_XLA_CHECK_RECOMPILATION") or "0")),
+     env_bool("VLLM_XLA_CHECK_RECOMPILATION", default=False),
      # Model implementation type (e.g., "flax_nnx")
      "MODEL_IMPL_TYPE":
-     env_with_choices("MODEL_IMPL_TYPE", "flax_nnx",
-                      ["vllm", "flax_nnx", "jetpack"]),
+     env_with_choices("MODEL_IMPL_TYPE", "auto",
+                      ["auto", "vllm", "flax_nnx", "jetpack"]),
      # Enable new experimental model design
      "NEW_MODEL_DESIGN":
-     lambda: bool(int(os.getenv("NEW_MODEL_DESIGN") or "0")),
+     env_bool("NEW_MODEL_DESIGN", default=False),
      # Directory to store phased profiling output
      "PHASED_PROFILING_DIR":
      lambda: os.getenv("PHASED_PROFILING_DIR", ""),
@@ -112,7 +141,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
      lambda: int(os.getenv("PYTHON_TRACER_LEVEL") or "1"),
      # Use custom expert-parallel kernel for MoE (Mixture of Experts)
      "USE_MOE_EP_KERNEL":
-     lambda: bool(int(os.getenv("USE_MOE_EP_KERNEL") or "0")),
+     env_bool("USE_MOE_EP_KERNEL", default=False),
      # Number of TPU slices for multi-slice mesh
      "NUM_SLICES":
      lambda: int(os.getenv("NUM_SLICES") or "1"),
@@ -122,6 +151,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
      # Ray compiled DAG channel type for TPU
      "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE":
      env_with_choices("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "shm", ["shm"]),
+     "ENABLE_QUANTIZED_MATMUL_KERNEL":
+     lambda: bool(int(os.getenv("ENABLE_QUANTIZED_MATMUL_KERNEL") or "0")),
  }

@@ -145,6 +145,9 @@ class RayDistributedExecutor(RayDistributedExecutorV1):
                  device_str: node['Resources'][device_str]
              } for node in ray_nodes]
          else:
+             assert pp_size == len(
+                 ray_nodes
+             ), f"Cannot use PP across hosts, please set --pipeline-parallel-size to 1 or {len(ray_nodes)}"
              num_devices_per_pp_rank = self.vllm_config.sharding_config.total_devices
              placement_group_specs = [{
                  device_str: num_devices_per_pp_rank
@@ -540,12 +540,16 @@ def get_vmem_estimate_bytes(
      """Returns the total vmem bytes used by the kernel."""
      m_per_device = m // tp_size
      n_per_device = n // tp_size
-     y_vmem_bytes = n_per_device * k * dtypes.bit_width(y_dtype) // 8
+     y_vmem_bytes = (n_per_device * k * (dtypes.bit_width(y_dtype) if hasattr(
+         dtypes, "bit_width") else dtypes.itemsize_bits(y_dtype)) // 8)
      total_bytes = (
-         2 * m_per_device * k * dtypes.bit_width(x_dtype) //
-         8  # x_vmem_scratch_ref
+         2 * m_per_device * k *
+         (dtypes.bit_width(x_dtype) if hasattr(dtypes, "bit_width") else
+          dtypes.itemsize_bits(x_dtype)) // 8  # x_vmem_scratch_ref
          + y_vmem_bytes  # y_vmem_scratch_ref
-         + 2 * m * bn * dtypes.bit_width(out_dtype) // 8  # o_vmem_scratch_ref
+         + 2 * m * bn *
+         (dtypes.bit_width(out_dtype) if hasattr(dtypes, "bit_width") else
+          dtypes.itemsize_bits(out_dtype)) // 8  # o_vmem_scratch_ref
          + acc_bytes  # acc_vmem_scratch_ref, jnp.float32
      )
      return total_bytes
@@ -639,8 +643,10 @@ def all_gather_matmul(
      # NOTE(chengjiyao): acc buffer is not used in the grid_k == 1 case.
      if grid_k == 1:
          acc_shape = (8, 128)
-         acc_bytes = acc_shape[0] * acc_shape[1] * dtypes.bit_width(
-             jnp.float32) // 8
+         acc_bytes = (
+             acc_shape[0] *
+             acc_shape[1] * (dtypes.bit_width(jnp.float32) if hasattr(
+                 dtypes, "bit_width") else dtypes.itemsize_bits(jnp.float32)) // 8)
      y_vmem_shape = (n_per_device, k) if rhs_transpose else (k, n_per_device)
      estimated_vmem_bytes = get_vmem_estimate_bytes(
          m,
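To put numbers on the vmem estimate above, a back-of-the-envelope with made-up shapes (illustrative only; the real m, k, n and block size bn come from the call site and the tuner):

    # Same arithmetic as get_vmem_estimate_bytes, with 16-bit (bfloat16) x, y and out.
    m, k, n, bn, tp_size = 1024, 4096, 8192, 512, 4
    bits = 16
    m_per_device, n_per_device = m // tp_size, n // tp_size

    acc_bytes = 8 * 128 * 32 // 8               # grid_k == 1: one (8, 128) float32 buffer
    x_bytes = 2 * m_per_device * k * bits // 8  # x_vmem_scratch_ref (leading 2 as in the formula)
    y_bytes = n_per_device * k * bits // 8      # y_vmem_scratch_ref
    o_bytes = 2 * m * bn * bits // 8            # o_vmem_scratch_ref

    total = x_bytes + y_bytes + o_bytes + acc_bytes
    print(total / 2**20)                        # ~22 MiB: 4 + 16 + 2 + ~0.004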