tpu-inference 0.12.0.dev20251219__tar.gz → 0.12.0rc1__tar.gz

This diff shows the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in the registry.

Potentially problematic release.

Files changed (193)
  1. {tpu_inference-0.12.0.dev20251219/tpu_inference.egg-info → tpu_inference-0.12.0rc1}/PKG-INFO +8 -6
  2. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/README.md +6 -4
  3. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/requirements.txt +1 -1
  4. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tests/kernels/mla_v1_test.py +41 -129
  5. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tests/kernels/quantized_matmul_kernel_test.py +34 -2
  6. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +1 -3
  7. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tests/kernels/ragged_paged_attention_kernel_v3_test.py +1 -3
  8. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tests/lora/test_layers.py +3 -7
  9. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tests/lora/test_lora.py +1 -1
  10. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tests/test_envs.py +1 -78
  11. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tests/test_quantization.py +0 -3
  12. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/distributed/tpu_connector.py +3 -3
  13. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/envs.py +7 -38
  14. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/executors/ray_distributed_executor.py +0 -3
  15. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/kernels/collectives/all_gather_matmul.py +6 -12
  16. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +2 -7
  17. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/kernels/fused_moe/v1/kernel.py +324 -357
  18. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/kernels/mla/v1/kernel.py +120 -98
  19. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/kernels/quantized_matmul/kernel.py +8 -69
  20. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +1 -2
  21. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +1 -2
  22. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +101 -181
  23. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +78 -82
  24. tpu_inference-0.12.0rc1/tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
  25. tpu_inference-0.12.0rc1/tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
  26. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/kernels/ragged_paged_attention/v3/util.py +1 -2
  27. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/common/attention_interface.py +7 -1
  28. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/common/quant_methods.py +0 -1
  29. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/common/sharding.py +2 -6
  30. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/jax/attention/deepseek_v3_attention.py +64 -232
  31. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/jax/attention/gpt_oss_attention.py +5 -5
  32. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/vllm/fused_moe.py +247 -180
  33. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/vllm/linear_common.py +21 -43
  34. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/vllm/quantization/__init__.py +0 -2
  35. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/vllm/quantization/awq.py +1 -1
  36. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/vllm/quantization/common.py +5 -5
  37. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +3 -4
  38. tpu_inference-0.12.0rc1/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
  39. tpu_inference-0.12.0rc1/tpu_inference/layers/vllm/quantization/mxfp4.py +341 -0
  40. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/vllm/quantization/unquantized.py +81 -105
  41. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/lora/torch_lora_ops.py +13 -8
  42. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/models/common/model_loader.py +20 -48
  43. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/models/jax/deepseek_v3.py +64 -185
  44. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/models/jax/gpt_oss.py +3 -3
  45. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/models/jax/llama3.py +33 -79
  46. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/models/jax/utils/quantization/quantization_utils.py +2 -4
  47. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/models/jax/utils/weight_utils.py +2 -26
  48. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/models/vllm/vllm_model_wrapper.py +1 -1
  49. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/platforms/tpu_platform.py +37 -15
  50. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/runner/compilation_manager.py +2 -3
  51. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/runner/kv_cache.py +20 -40
  52. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/runner/kv_cache_manager.py +15 -31
  53. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/runner/tpu_runner.py +7 -14
  54. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/utils.py +6 -11
  55. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/worker/tpu_worker.py +44 -44
  56. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1/tpu_inference.egg-info}/PKG-INFO +8 -6
  57. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference.egg-info/SOURCES.txt +0 -7
  58. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference.egg-info/requires.txt +1 -1
  59. tpu_inference-0.12.0.dev20251219/tests/kernels/gmm_test.py +0 -191
  60. tpu_inference-0.12.0.dev20251219/tests/lora/test_lora_perf.py +0 -53
  61. tpu_inference-0.12.0.dev20251219/tpu_inference/kernels/megablox/common.py +0 -41
  62. tpu_inference-0.12.0.dev20251219/tpu_inference/kernels/megablox/gmm.py +0 -633
  63. tpu_inference-0.12.0.dev20251219/tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +0 -4447
  64. tpu_inference-0.12.0.dev20251219/tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +0 -535
  65. tpu_inference-0.12.0.dev20251219/tpu_inference/layers/jax/pp_utils.py +0 -39
  66. tpu_inference-0.12.0.dev20251219/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +0 -252
  67. tpu_inference-0.12.0.dev20251219/tpu_inference/layers/vllm/quantization/fp8.py +0 -104
  68. tpu_inference-0.12.0.dev20251219/tpu_inference/layers/vllm/quantization/mxfp4.py +0 -448
  69. tpu_inference-0.12.0.dev20251219/tpu_inference/worker/__init__.py +0 -0
  70. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/LICENSE +0 -0
  71. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/MANIFEST.in +0 -0
  72. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/pyproject.toml +0 -0
  73. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/setup.cfg +0 -0
  74. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/setup.py +0 -0
  75. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tests/__init__.py +0 -0
  76. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tests/core/__init__.py +0 -0
  77. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tests/core/test_core_tpu.py +0 -0
  78. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tests/core/test_disagg_executor.py +0 -0
  79. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tests/core/test_disagg_utils.py +0 -0
  80. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tests/core/test_dp_scheduler.py +0 -0
  81. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tests/core/test_init.py +0 -0
  82. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tests/kernels/__init__.py +0 -0
  83. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tests/kernels/fused_moe_v1_test.py +0 -0
  84. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tests/kernels/ragged_kv_cache_update_v2_test.py +0 -0
  85. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tests/kernels/ragged_paged_attention_kernel_v2_test.py +0 -0
  86. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tests/lora/__init__.py +0 -0
  87. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tests/lora/conftest.py +0 -0
  88. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tests/lora/test_bgmv.py +0 -0
  89. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tests/lora/utils.py +0 -0
  90. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tests/test_base.py +0 -0
  91. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tests/test_tpu_info.py +0 -0
  92. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tests/test_utils.py +0 -0
  93. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/__init__.py +0 -0
  94. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/core/__init__.py +0 -0
  95. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/core/core_tpu.py +0 -0
  96. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/core/disagg_executor.py +0 -0
  97. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/core/disagg_utils.py +0 -0
  98. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/core/sched/__init__.py +0 -0
  99. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/core/sched/dp_scheduler.py +0 -0
  100. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/distributed/__init__.py +0 -0
  101. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/distributed/jax_parallel_state.py +0 -0
  102. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/distributed/utils.py +0 -0
  103. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/env_override.py +0 -0
  104. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/executors/__init__.py +0 -0
  105. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/experimental/__init__.py +0 -0
  106. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/experimental/llama3_jax_stashed.py +0 -0
  107. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/kernels/__init__.py +0 -0
  108. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/kernels/collectives/__init__.py +0 -0
  109. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/kernels/collectives/util.py +0 -0
  110. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/kernels/flash_attention/__init__.py +0 -0
  111. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/kernels/flash_attention/kernel.py +0 -0
  112. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/kernels/fused_moe/__init__.py +0 -0
  113. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
  114. {tpu_inference-0.12.0.dev20251219/tpu_inference/kernels/megablox → tpu_inference-0.12.0rc1/tpu_inference/kernels/mla}/__init__.py +0 -0
  115. {tpu_inference-0.12.0.dev20251219/tpu_inference/kernels/mla → tpu_inference-0.12.0rc1/tpu_inference/kernels/mla/v1}/__init__.py +0 -0
  116. {tpu_inference-0.12.0.dev20251219/tpu_inference/kernels/mla/v1 → tpu_inference-0.12.0rc1/tpu_inference/kernels/quantized_matmul}/__init__.py +0 -0
  117. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +0 -0
  118. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/kernels/quantized_matmul/util.py +0 -0
  119. {tpu_inference-0.12.0.dev20251219/tpu_inference/kernels/quantized_matmul → tpu_inference-0.12.0rc1/tpu_inference/kernels/ragged_paged_attention}/__init__.py +0 -0
  120. {tpu_inference-0.12.0.dev20251219/tpu_inference/kernels/ragged_paged_attention → tpu_inference-0.12.0rc1/tpu_inference/kernels/ragged_paged_attention/v2}/__init__.py +0 -0
  121. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +0 -0
  122. {tpu_inference-0.12.0.dev20251219/tpu_inference/kernels/ragged_paged_attention/v2 → tpu_inference-0.12.0rc1/tpu_inference/kernels/ragged_paged_attention/v3}/__init__.py +0 -0
  123. {tpu_inference-0.12.0.dev20251219/tpu_inference/kernels/ragged_paged_attention/v3 → tpu_inference-0.12.0rc1/tpu_inference/layers}/__init__.py +0 -0
  124. {tpu_inference-0.12.0.dev20251219/tpu_inference/layers → tpu_inference-0.12.0rc1/tpu_inference/layers/common}/__init__.py +0 -0
  125. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/common/attention_metadata.py +0 -0
  126. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/common/binary_search.py +0 -0
  127. {tpu_inference-0.12.0.dev20251219/tpu_inference/layers/common → tpu_inference-0.12.0rc1/tpu_inference/layers/jax}/__init__.py +0 -0
  128. {tpu_inference-0.12.0.dev20251219/tpu_inference/layers/jax → tpu_inference-0.12.0rc1/tpu_inference/layers/jax/attention}/__init__.py +0 -0
  129. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/jax/attention/attention.py +0 -0
  130. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/jax/attention/llama4_attention.py +0 -0
  131. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/jax/base.py +0 -0
  132. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/jax/constants.py +0 -0
  133. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/jax/layers.py +0 -0
  134. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/jax/misc.py +0 -0
  135. {tpu_inference-0.12.0.dev20251219/tpu_inference/layers/jax/attention → tpu_inference-0.12.0rc1/tpu_inference/layers/jax/moe}/__init__.py +0 -0
  136. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/jax/moe/deepseek_v3_moe.py +0 -0
  137. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/jax/moe/gpt_oss_moe.py +0 -0
  138. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/jax/moe/moe.py +0 -0
  139. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/jax/rope.py +0 -0
  140. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/jax/rope_interface.py +0 -0
  141. {tpu_inference-0.12.0.dev20251219/tpu_inference/layers/jax/moe → tpu_inference-0.12.0rc1/tpu_inference/layers/jax/sample}/__init__.py +0 -0
  142. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/jax/sample/rejection_sampler.py +0 -0
  143. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/jax/sample/sampling.py +0 -0
  144. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/jax/sample/sampling_metadata.py +0 -0
  145. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/jax/transformer_block.py +0 -0
  146. {tpu_inference-0.12.0.dev20251219/tpu_inference/layers/jax/sample → tpu_inference-0.12.0rc1/tpu_inference/layers/vllm}/__init__.py +0 -0
  147. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/vllm/attention.py +0 -0
  148. {tpu_inference-0.12.0.dev20251219/tpu_inference/layers/vllm → tpu_inference-0.12.0rc1/tpu_inference/layers/vllm/quantization/compressed_tensors}/__init__.py +0 -0
  149. {tpu_inference-0.12.0.dev20251219/tpu_inference/layers/vllm/quantization/compressed_tensors → tpu_inference-0.12.0rc1/tpu_inference/layers/vllm/quantization/compressed_tensors/schemes}/__init__.py +0 -0
  150. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +0 -0
  151. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +0 -0
  152. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/vllm/sharding.py +0 -0
  153. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/logger.py +0 -0
  154. {tpu_inference-0.12.0.dev20251219/tpu_inference/layers/vllm/quantization/compressed_tensors/schemes → tpu_inference-0.12.0rc1/tpu_inference/lora}/__init__.py +0 -0
  155. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/lora/torch_punica_tpu.py +0 -0
  156. {tpu_inference-0.12.0.dev20251219/tpu_inference/lora → tpu_inference-0.12.0rc1/tpu_inference/models}/__init__.py +0 -0
  157. {tpu_inference-0.12.0.dev20251219/tpu_inference/models → tpu_inference-0.12.0rc1/tpu_inference/models/common}/__init__.py +0 -0
  158. {tpu_inference-0.12.0.dev20251219/tpu_inference/models/common → tpu_inference-0.12.0rc1/tpu_inference/models/jax}/__init__.py +0 -0
  159. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/models/jax/jax_intermediate_tensor.py +0 -0
  160. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/models/jax/llama4.py +0 -0
  161. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/models/jax/llama_eagle3.py +0 -0
  162. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/models/jax/llama_guard_4.py +0 -0
  163. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/models/jax/qwen2.py +0 -0
  164. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/models/jax/qwen2_5_vl.py +0 -0
  165. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/models/jax/qwen3.py +0 -0
  166. {tpu_inference-0.12.0.dev20251219/tpu_inference/models/jax → tpu_inference-0.12.0rc1/tpu_inference/models/jax/utils}/__init__.py +0 -0
  167. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/models/jax/utils/file_utils.py +0 -0
  168. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/models/jax/utils/multi_modal_utils.py +0 -0
  169. {tpu_inference-0.12.0.dev20251219/tpu_inference/models/jax/utils → tpu_inference-0.12.0rc1/tpu_inference/models/jax/utils/quantization}/__init__.py +0 -0
  170. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -0
  171. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -0
  172. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -0
  173. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -0
  174. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -0
  175. {tpu_inference-0.12.0.dev20251219/tpu_inference/models/jax/utils/quantization → tpu_inference-0.12.0rc1/tpu_inference/models/vllm}/__init__.py +0 -0
  176. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/models/vllm/vllm_model_wrapper_context.py +0 -0
  177. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/platforms/__init__.py +0 -0
  178. {tpu_inference-0.12.0.dev20251219/tpu_inference/models/vllm → tpu_inference-0.12.0rc1/tpu_inference/runner}/__init__.py +0 -0
  179. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/runner/block_table.py +0 -0
  180. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/runner/input_batch.py +0 -0
  181. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/runner/lora_utils.py +0 -0
  182. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/runner/multimodal_manager.py +0 -0
  183. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/runner/persistent_batch_manager.py +0 -0
  184. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/runner/speculative_decoding_manager.py +0 -0
  185. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/runner/structured_decoding_manager.py +0 -0
  186. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/runner/utils.py +0 -0
  187. {tpu_inference-0.12.0.dev20251219/tpu_inference/runner → tpu_inference-0.12.0rc1/tpu_inference/spec_decode}/__init__.py +0 -0
  188. {tpu_inference-0.12.0.dev20251219/tpu_inference/spec_decode → tpu_inference-0.12.0rc1/tpu_inference/spec_decode/jax}/__init__.py +0 -0
  189. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/spec_decode/jax/eagle3.py +0 -0
  190. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/tpu_info.py +0 -0
  191. {tpu_inference-0.12.0.dev20251219/tpu_inference/spec_decode/jax → tpu_inference-0.12.0rc1/tpu_inference/worker}/__init__.py +0 -0
  192. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference.egg-info/dependency_links.txt +0 -0
  193. {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference.egg-info/top_level.txt +0 -0

PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: tpu_inference
-Version: 0.12.0.dev20251219
+Version: 0.12.0rc1
 Author: tpu_inference Contributors
 Classifier: Development Status :: 3 - Alpha
 Classifier: Intended Audience :: Developers
@@ -25,7 +25,7 @@ Requires-Dist: jax[tpu]==0.8.0
 Requires-Dist: jaxlib==0.8.0
 Requires-Dist: jaxtyping
 Requires-Dist: flax==0.11.1
-Requires-Dist: torchax==0.0.10
+Requires-Dist: torchax==0.0.7
 Requires-Dist: qwix==0.1.1
 Requires-Dist: torchvision==0.24.0
 Requires-Dist: pathwaysutils
@@ -53,11 +53,13 @@ Dynamic: requires-python
 
 ---
 
-_Latest News_ 🔥
+_Upcoming Events_ 🔥
+
+- Join us at the [PyTorch Conference, October 22-23](https://events.linuxfoundation.org/pytorch-conference/) in San Francisco!
+- Join us at [Ray Summit, November 3-5](https://www.anyscale.com/ray-summit/2025) in San Francisco!
+- Join us at [JAX DevLab on November 18th](https://rsvp.withgoogle.com/events/devlab-fall-2025) in Sunnyvale!
 
-- [Pytorch Conference](https://pytorchconference.sched.com/event/27QCh/sponsored-session-everything-everywhere-all-at-once-vllm-hardware-optionality-with-spotify-and-google-brittany-rockwell-google-shireen-kheradpey-spotify) Learn how Spotify uses vLLM with both GPUs and TPUs to drive down costs and improve user experience.
-- Check back soon for a recording of our session at [Ray Summit, November 3-5](https://www.anyscale.com/ray-summit/2025) in San Francisco!
-- Check back soon for a recording of our session at [JAX DevLab on November 18th](https://rsvp.withgoogle.com/events/devlab-fall-2025) in Sunnyvale!
+_Latest News_ 🔥
 
 - [2025/10] [vLLM TPU: A New Unified Backend Supporting PyTorch and JAX on TPU](https://blog.vllm.ai/2025/10/16/vllm-tpu.html)
 

README.md

@@ -11,11 +11,13 @@
 
 ---
 
-_Latest News_ 🔥
+_Upcoming Events_ 🔥
+
+- Join us at the [PyTorch Conference, October 22-23](https://events.linuxfoundation.org/pytorch-conference/) in San Francisco!
+- Join us at [Ray Summit, November 3-5](https://www.anyscale.com/ray-summit/2025) in San Francisco!
+- Join us at [JAX DevLab on November 18th](https://rsvp.withgoogle.com/events/devlab-fall-2025) in Sunnyvale!
 
-- [Pytorch Conference](https://pytorchconference.sched.com/event/27QCh/sponsored-session-everything-everywhere-all-at-once-vllm-hardware-optionality-with-spotify-and-google-brittany-rockwell-google-shireen-kheradpey-spotify) Learn how Spotify uses vLLM with both GPUs and TPUs to drive down costs and improve user experience.
-- Check back soon for a recording of our session at [Ray Summit, November 3-5](https://www.anyscale.com/ray-summit/2025) in San Francisco!
-- Check back soon for a recording of our session at [JAX DevLab on November 18th](https://rsvp.withgoogle.com/events/devlab-fall-2025) in Sunnyvale!
+_Latest News_ 🔥
 
 - [2025/10] [vLLM TPU: A New Unified Backend Supporting PyTorch and JAX on TPU](https://blog.vllm.ai/2025/10/16/vllm-tpu.html)
 

requirements.txt

@@ -9,7 +9,7 @@ jax[tpu]==0.8.0
 jaxlib==0.8.0
 jaxtyping
 flax==0.11.1
-torchax==0.0.10
+torchax==0.0.7
 qwix==0.1.1
 torchvision==0.24.0
 pathwaysutils

tests/kernels/mla_v1_test.py

@@ -42,7 +42,6 @@ class MlaRaggedPagedAttentionKernelTest(jtu.JaxTestCase):
 
         padded_r_dim = align_to(r_dim, 128)
         padded_lkv_dim = align_to(lkv_dim, 128)
-        padded_kv_dim = padded_lkv_dim + padded_r_dim
         packing = get_dtype_packing(kv_dtype)
         q_lens = [s[0] for s in seq_lens]
         kv_lens_list = [s[1] for s in seq_lens]
@@ -70,10 +69,13 @@ class MlaRaggedPagedAttentionKernelTest(jtu.JaxTestCase):
         new_kv_c = gen_random((total_q_len, lkv_dim), kv_dtype)
         new_k_pe = gen_random((total_q_len, r_dim), kv_dtype)
 
-        cache_kv = gen_random(
-            (total_num_pages, page_size // packing, packing, padded_kv_dim),
+        cache_kv_c = gen_random(
+            (total_num_pages, page_size // packing, packing, padded_lkv_dim),
             kv_dtype,
         )
+        cache_k_pe = gen_random(
+            (total_num_pages, page_size // packing, packing, padded_r_dim),
+            kv_dtype)
         kv_lens = jnp.array(kv_lens_list, dtype=jnp.int32)
         page_indices = jnp.array(page_indices_list, dtype=jnp.int32)
         cu_q_lens = jnp.array(cu_q_lens_list, dtype=jnp.int32)
@@ -82,13 +84,14 @@ class MlaRaggedPagedAttentionKernelTest(jtu.JaxTestCase):
         ql_nope_for_kernel = ql_nope.copy()
         q_pe_for_kernel = q_pe.copy()
 
-        expected_out, expected_updated_kv = (
+        expected_out, expected_updated_kv_c, expeceted_updated_k_pe = (
             mla.ref_mla_ragged_paged_attention(
                 ql_nope,
                 q_pe,
                 new_kv_c,
                 new_k_pe,
-                cache_kv.copy(),
+                cache_kv_c.copy(),
+                cache_k_pe.copy(),
                 kv_lens,
                 page_indices,
                 cu_q_lens,
@@ -98,140 +101,49 @@ class MlaRaggedPagedAttentionKernelTest(jtu.JaxTestCase):
                 soft_cap=soft_cap,
             ))
 
-        kernel_out, kernel_updated_kv = (mla.mla_ragged_paged_attention(
-            ql_nope_for_kernel,
-            q_pe_for_kernel,
-            new_kv_c,
-            new_k_pe,
-            cache_kv.copy(),
-            kv_lens,
-            page_indices,
-            cu_q_lens,
-            distribution,
-            sm_scale=sm_scale,
-            sliding_window=sliding_window,
-            soft_cap=soft_cap,
-            num_kv_pages_per_block=num_kv_pages_per_block,
-            num_queries_per_block=num_queries_per_block,
-            vmem_limit_bytes=vmem_limit_bytes,
-        ))
+        kernel_out, kernel_updated_kv_c, kernel_updated_k_pe = (
+            mla.mla_ragged_paged_attention(
+                ql_nope_for_kernel,
+                q_pe_for_kernel,
+                new_kv_c,
+                new_k_pe,
+                cache_kv_c.copy(),
+                cache_k_pe.copy(),
+                kv_lens,
+                page_indices,
+                cu_q_lens,
+                distribution,
+                sm_scale=sm_scale,
+                sliding_window=sliding_window,
+                soft_cap=soft_cap,
+                num_kv_pages_per_block=num_kv_pages_per_block,
+                num_queries_per_block=num_queries_per_block,
+                vmem_limit_bytes=vmem_limit_bytes,
+            ))
 
         self.assertEqual(expected_out.shape,
                          (total_q_len, num_heads, padded_lkv_dim))
         self.assertEqual(
-            expected_updated_kv.shape,
-            (total_num_pages, page_size // packing, packing, padded_kv_dim),
+            expected_updated_kv_c.shape,
+            (total_num_pages, page_size // packing, packing, padded_lkv_dim),
+        )
+        self.assertEqual(
+            expeceted_updated_k_pe.shape,
+            (total_num_pages, page_size // packing, packing, padded_r_dim),
         )
         self.assertEqual(expected_out.dtype, kv_dtype)
-        self.assertEqual(expected_updated_kv.dtype, kv_dtype)
+        self.assertEqual(expected_updated_kv_c.dtype, kv_dtype)
+        self.assertEqual(expeceted_updated_k_pe.dtype, kv_dtype)
 
         self.assertAllClose(expected_out, kernel_out, atol=0.2, rtol=0.2)
-        self.assertAllClose(expected_updated_kv,
-                            kernel_updated_kv,
+        self.assertAllClose(expected_updated_kv_c,
+                            kernel_updated_kv_c,
+                            atol=0.2,
+                            rtol=0.2)
+        self.assertAllClose(expeceted_updated_k_pe,
+                            kernel_updated_k_pe,
                             atol=0.2,
                             rtol=0.2)
-
-    def test_update_kv_cache(self):
-        lkv_dim = 4
-        r_dim = 4
-        padded_lkv_dim = align_to(lkv_dim, 128)
-        padded_r_dim = align_to(r_dim, 128)
-        kv_dtype = jnp.bfloat16
-        new_kv_c = jnp.arange(16, dtype=kv_dtype).reshape((4, lkv_dim))
-        new_k_pe = (jnp.arange(16, dtype=kv_dtype).reshape((4, r_dim)) + 100)
-        total_num_pages = 2
-        page_size = 4
-        cache_kv_shape = mla.get_kv_cache_shape(
-            total_num_pages,
-            page_size,
-            padded_lkv_dim + padded_r_dim,
-            kv_dtype,
-        )
-        cache_kv = jnp.zeros(cache_kv_shape, dtype=kv_dtype)
-
-        # two sequences, first with 3 tokens, second with 1 token
-        kv_lens = jnp.array([3, 1], dtype=jnp.int32)
-        # first seq uses page 0, second uses page 1
-        page_indices = jnp.array([0, -1, 1, -1], dtype=jnp.int32)
-        # three tokens for first seq, one for second
-        cu_q_lens = jnp.array([0, 3, 4], dtype=jnp.int32)
-        distribution = jnp.array([0, 0, 2], dtype=jnp.int32)
-
-        # manually compute the expected cache
-        padded_new_kv_c = jnp.pad(new_kv_c,
-                                  ((0, 0), (0, padded_lkv_dim - lkv_dim)),
-                                  constant_values=0)
-        padded_new_k_pe = jnp.pad(new_k_pe,
-                                  ((0, 0), (0, padded_r_dim - r_dim)),
-                                  constant_values=0)
-
-        expected_cache = cache_kv
-        # First sequence
-        # token 0
-        page_idx, row, col = 0, 0, 0
-        expected_cache = expected_cache.at[page_idx, row,
-                                           col, :padded_lkv_dim].set(
-                                               padded_new_kv_c[0])
-        expected_cache = expected_cache.at[page_idx, row, col,
-                                           padded_lkv_dim:padded_lkv_dim +
-                                           padded_r_dim].set(
-                                               padded_new_k_pe[0])
-        # token 1
-        page_idx, row, col = 0, 0, 1
-        expected_cache = expected_cache.at[page_idx, row,
-                                           col, :padded_lkv_dim].set(
-                                               padded_new_kv_c[1])
-        expected_cache = expected_cache.at[page_idx, row, col,
-                                           padded_lkv_dim:padded_lkv_dim +
-                                           padded_r_dim].set(
-                                               padded_new_k_pe[1])
-        # token 2
-        page_idx, row, col = 0, 1, 0
-        expected_cache = expected_cache.at[page_idx, row,
-                                           col, :padded_lkv_dim].set(
-                                               padded_new_kv_c[2])
-        expected_cache = expected_cache.at[page_idx, row, col,
-                                           padded_lkv_dim:padded_lkv_dim +
-                                           padded_r_dim].set(
-                                               padded_new_k_pe[2])
-
-        # Second sequence
-        # token 0
-        page_idx, row, col = 1, 0, 0
-        expected_cache = expected_cache.at[page_idx, row,
-                                           col, :padded_lkv_dim].set(
-                                               padded_new_kv_c[3])
-        expected_cache = expected_cache.at[page_idx, row, col,
-                                           padded_lkv_dim:padded_lkv_dim +
-                                           padded_r_dim].set(
-                                               padded_new_k_pe[3])
-
-        updated_cache = mla.update_kv_cache(
-            new_kv_c,
-            new_k_pe,
-            cache_kv,
-            kv_lens,
-            page_indices,
-            cu_q_lens,
-            distribution,
-        )
-
-        self.assertAllClose(updated_cache, expected_cache)
-
-    def test_get_kv_cache_shape(self):
-        total_num_pages = 10
-        page_size = 16
-        lkv_dim = 128
-        kv_dtype = jnp.bfloat16
-        # The calculation for the expected shape is as follows:
-        # kv_packing is determined by the dtype, which is 2 for bfloat16.
-        # The second dimension is page_size / kv_packing = 16 / 2 = 8
-        # The third dimension is kv_packing = 2
-        # The fourth dimension is lkv_dim aligned to 128, which is 128
-        expected_shape = (10, 8, 2, 128)
-        self.assertEqual(
-            mla.get_kv_cache_shape(total_num_pages, page_size, lkv_dim,
-                                   kv_dtype), expected_shape)
 
     def test_ragged_paged_attention_basic(self):
         dtype = jnp.bfloat16
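
These hunks track a layout change in the MLA kernel: the single combined KV cache (latent kv_c and rope k_pe concatenated along one padded last axis) is split into two separate caches, each padded to a multiple of 128 independently. Below is a minimal sketch of the implied shape arithmetic, using a hypothetical `split_cache_shapes` helper that is not part of the package, and assuming the packing rule documented in the deleted `test_get_kv_cache_shape` (packing = 32 / dtype bit width, e.g. 2 for bfloat16):

import jax.numpy as jnp

def align_to(x: int, multiple: int) -> int:
    # Round x up to the next multiple, e.g. align_to(64, 128) == 128.
    return (x + multiple - 1) // multiple * multiple

def split_cache_shapes(total_num_pages, page_size, lkv_dim, r_dim, kv_dtype):
    # Assumed packing rule: 32 bits per lane / dtype bit width (2 for bfloat16).
    packing = 32 // (jnp.dtype(kv_dtype).itemsize * 8)
    kv_c_shape = (total_num_pages, page_size // packing, packing,
                  align_to(lkv_dim, 128))
    k_pe_shape = (total_num_pages, page_size // packing, packing,
                  align_to(r_dim, 128))
    return kv_c_shape, k_pe_shape

# 10 pages of 16 tokens each, lkv_dim=512, r_dim=64, bfloat16:
# ((10, 8, 2, 512), (10, 8, 2, 128))
print(split_cache_shapes(10, 16, 512, 64, jnp.bfloat16))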

tests/kernels/quantized_matmul_kernel_test.py

@@ -1,5 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
+import functools
+
 import jax
 import jax.numpy as jnp
 from absl.testing import absltest, parameterized
@@ -8,7 +10,6 @@ from jax._src import test_util as jtu
 from tpu_inference.kernels.quantized_matmul import (kernel, tuned_block_sizes,
                                                     util)
 
-xla_quantized_matmul = kernel.xla_quantized_matmul
 quantized_matmul_kernel = kernel.quantized_matmul_kernel
 quantize_tensor = util.quantize_tensor
 get_tuned_block_sizes = tuned_block_sizes.get_tuned_block_sizes
@@ -16,6 +17,37 @@ get_tuned_block_sizes = tuned_block_sizes.get_tuned_block_sizes
 jax.config.parse_flags_with_absl()
 
 
+@functools.partial(jax.jit, static_argnames=["quantize_activation"])
+def reference_quantized_matmul(
+    x: jax.Array,
+    w_q: jax.Array,
+    w_scale: jax.Array,
+    quantize_activation=True,
+):
+    if quantize_activation:
+        acc_dtype = jnp.float32
+        if quantize_activation and jnp.issubdtype(w_q.dtype, jnp.integer):
+            acc_dtype = jnp.int32
+
+        x_q, x_scale = quantize_tensor(x, w_q.dtype)
+        out = jax.lax.dot_general(
+            x_q,
+            w_q,
+            dimension_numbers=(((1, ), (1, )), ((), ())),
+            preferred_element_type=acc_dtype,
+        ).astype(jnp.float32)
+        out *= x_scale
+    else:
+        out = jax.lax.dot_general(
+            x,
+            w_q,
+            dimension_numbers=(((1, ), (1, )), ((), ())),
+            preferred_element_type=jnp.float32,
+        )
+    out *= jnp.expand_dims(w_scale, 0)
+    return out.astype(x.dtype)
+
+
 @jtu.with_config(jax_numpy_dtype_promotion="standard")
 class QuantizedMatmulKernelTest(jtu.JaxTestCase):
 
@@ -62,7 +94,7 @@ class QuantizedMatmulKernelTest(jtu.JaxTestCase):
             x_q_dtype=x_q_dtype,
             tuned_value=tuned_value,
         )
-        expected = xla_quantized_matmul(
+        expected = reference_quantized_matmul(
            x, w_q, w_scale, quantize_activation=quantize_activation)
 
         self.assertAllClose(output,
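
With `xla_quantized_matmul` gone from the kernel module, the jitted `reference_quantized_matmul` defined in the test now serves as the reference implementation. As a companion illustration only (an assumption for readers, not the package's `quantize_tensor`), symmetric per-row absmax int8 quantization of the kind such a reference relies on can be sketched as:

import jax
import jax.numpy as jnp

def quantize_int8(x: jax.Array):
    # Per-row absmax scale so values map into the int8 range [-127, 127].
    scale = jnp.max(jnp.abs(x), axis=-1, keepdims=True) / 127.0
    x_q = jnp.round(x / scale).astype(jnp.int8)
    return x_q, scale

x = jax.random.normal(jax.random.PRNGKey(0), (4, 8))
x_q, x_scale = quantize_int8(x)
x_approx = x_q.astype(jnp.float32) * x_scale  # dequantized approximation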

tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py

@@ -176,9 +176,7 @@ class RaggedPagedAttentionHeadDim64KernelTest(jtu.JaxTestCase):
         )
         output = output[:cu_q_lens[distribution[-1]]]
 
-        dtype_bits = (dtypes.bit_width(jnp.dtype(kv_dtype)) if hasattr(
-            dtypes, "bit_width") else dtypes.itemsize_bits(
-                jnp.dtype(kv_dtype)))
+        dtype_bits = dtypes.bit_width(jnp.dtype(kv_dtype))
         tols = {
             32: 0.15,
             16: 0.2,

tests/kernels/ragged_paged_attention_kernel_v3_test.py

@@ -162,9 +162,7 @@ class RaggedPagedAttentionKernelTest(jtu.JaxTestCase):
         )
         output = output[:cu_q_lens[distribution[-1]]]
 
-        dtype_bits = (dtypes.bit_width(jnp.dtype(kv_dtype)) if hasattr(
-            dtypes, "bit_width") else dtypes.itemsize_bits(
-                jnp.dtype(kv_dtype)))
+        dtype_bits = dtypes.bit_width(jnp.dtype(kv_dtype))
         tols = {
             32: 0.15,
             16: 0.2,

tests/lora/test_layers.py

@@ -18,7 +18,7 @@ from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
                               ReplicatedLinearWithLoRA,
                               RowParallelLinearWithLoRA)
 # yapf: enable
-from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
+from vllm.lora.models import LoRALayerWeights, PackedLoRALayerWeights
 from vllm.lora.punica_wrapper import get_punica_wrapper
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
@@ -499,13 +499,9 @@ def _create_random_linear_parallel_layer(layer_type, vllm_config, mesh):
     return linear, lora_linear
 
 
-def _get_devices():
-    return jax.devices()
-
-
 def _create_mesh():
     axis_names = ("data", "model")
-    devices = _get_devices()
+    devices = jax.devices()
     mesh_shape = (1, len(devices))
     mesh = jax.make_mesh(mesh_shape, axis_names, devices=devices)
     return mesh
@@ -517,7 +513,7 @@ def _verify_lora_linear_layer(linear, lora_linear):
     # BaseLinearLayerWithLoRA.weight property guarantees this.
     # if len(devices) != 1, `reorder_concatenated_tensor_for_sharding` function may reorder the out_features dimension of the weight matrix.
     # So the below check will fail.
-    if len(_get_devices()) == 1:
+    if len(jax.devices()) == 1:
         assert torch.equal(linear.weight.data,
                            lora_linear.weight.to('cpu'))
 
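
The first hunk above tracks a module move on the vLLM side: `LoRALayerWeights` and `PackedLoRALayerWeights` are imported from `vllm.lora.models` rather than `vllm.lora.lora_weights`. A hedged compatibility sketch (an assumption for illustration, not part of this package) for code that must import these classes across both vLLM layouts:

try:
    # Layout used by the 0.12.0.dev20251219 build of this package.
    from vllm.lora.lora_weights import (LoRALayerWeights,
                                        PackedLoRALayerWeights)
except ImportError:
    # Layout used by 0.12.0rc1.
    from vllm.lora.models import LoRALayerWeights, PackedLoRALayerWeights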

tests/lora/test_lora.py

@@ -29,7 +29,7 @@ def setup_vllm(num_loras: int, tp: int = 1) -> vllm.LLM:
 
 
 # For multi-chip test, we only use TP=2 because the base model Qwen/Qwen2.5-3B-Instruct has 2 kv heads and the current attention kernel requires it to be divisible by tp_size.
-TP = [2] if os.environ.get("TEST_LORA_TP", False) else [1]
+TP = [2] if os.environ.get("USE_V6E8_QUEUE", False) else [1]
 
 
 @pytest.mark.parametrize("tp", TP)

tests/test_envs.py

@@ -60,7 +60,6 @@ def test_boolean_env_vars(monkeypatch: pytest.MonkeyPatch):
     monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "0")
     monkeypatch.setenv("VLLM_XLA_CHECK_RECOMPILATION", "0")
     monkeypatch.setenv("NEW_MODEL_DESIGN", "0")
-    monkeypatch.setenv("ENABLE_QUANTIZED_MATMUL_KERNEL", "0")
     monkeypatch.setenv("USE_MOE_EP_KERNEL", "0")
 
     # Test SKIP_JAX_PRECOMPILE (default False)
@@ -87,82 +86,6 @@ def test_boolean_env_vars(monkeypatch: pytest.MonkeyPatch):
     monkeypatch.setenv("USE_MOE_EP_KERNEL", "1")
     assert envs.USE_MOE_EP_KERNEL is True
 
-    # Test ENABLE_QUANTIZED_MATMUL_KERNEL (default False)
-    assert envs.ENABLE_QUANTIZED_MATMUL_KERNEL is False
-    monkeypatch.setenv("ENABLE_QUANTIZED_MATMUL_KERNEL", "1")
-    assert envs.ENABLE_QUANTIZED_MATMUL_KERNEL is True
-
-
-def test_boolean_env_vars_string_values(monkeypatch: pytest.MonkeyPatch):
-    """Test that boolean env vars accept string values like 'True' and 'False'"""
-
-    # Test NEW_MODEL_DESIGN with string "True"
-    monkeypatch.setenv("NEW_MODEL_DESIGN", "True")
-    assert envs.NEW_MODEL_DESIGN is True
-
-    monkeypatch.setenv("NEW_MODEL_DESIGN", "true")
-    assert envs.NEW_MODEL_DESIGN is True
-
-    monkeypatch.setenv("NEW_MODEL_DESIGN", "False")
-    assert envs.NEW_MODEL_DESIGN is False
-
-    monkeypatch.setenv("NEW_MODEL_DESIGN", "false")
-    assert envs.NEW_MODEL_DESIGN is False
-
-    # Test SKIP_JAX_PRECOMPILE with string values
-    monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "True")
-    assert envs.SKIP_JAX_PRECOMPILE is True
-
-    monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "false")
-    assert envs.SKIP_JAX_PRECOMPILE is False
-
-    # Test VLLM_XLA_CHECK_RECOMPILATION with string values
-    monkeypatch.setenv("VLLM_XLA_CHECK_RECOMPILATION", "TRUE")
-    assert envs.VLLM_XLA_CHECK_RECOMPILATION is True
-
-    monkeypatch.setenv("VLLM_XLA_CHECK_RECOMPILATION", "FALSE")
-    assert envs.VLLM_XLA_CHECK_RECOMPILATION is False
-
-    # Test USE_MOE_EP_KERNEL with string values
-    monkeypatch.setenv("USE_MOE_EP_KERNEL", "true")
-    assert envs.USE_MOE_EP_KERNEL is True
-
-    monkeypatch.setenv("USE_MOE_EP_KERNEL", "False")
-    assert envs.USE_MOE_EP_KERNEL is False
-
-
-def test_boolean_env_vars_invalid_values(monkeypatch: pytest.MonkeyPatch):
-    """Test that boolean env vars raise errors for invalid values"""
-
-    # Test invalid value for NEW_MODEL_DESIGN
-    monkeypatch.setenv("NEW_MODEL_DESIGN", "yes")
-    with pytest.raises(
-            ValueError,
-            match="Invalid boolean value 'yes' for NEW_MODEL_DESIGN"):
-        _ = envs.NEW_MODEL_DESIGN
-
-    monkeypatch.setenv("NEW_MODEL_DESIGN", "2")
-    with pytest.raises(ValueError,
-                       match="Invalid boolean value '2' for NEW_MODEL_DESIGN"):
-        _ = envs.NEW_MODEL_DESIGN
-
-    # Test invalid value for SKIP_JAX_PRECOMPILE
-    monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "invalid")
-    with pytest.raises(
-            ValueError,
-            match="Invalid boolean value 'invalid' for SKIP_JAX_PRECOMPILE"):
-        _ = envs.SKIP_JAX_PRECOMPILE
-
-
-def test_boolean_env_vars_empty_string(monkeypatch: pytest.MonkeyPatch):
-    """Test that empty string returns default value"""
-
-    monkeypatch.setenv("NEW_MODEL_DESIGN", "")
-    assert envs.NEW_MODEL_DESIGN is False  # Should return default
-
-    monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "")
-    assert envs.SKIP_JAX_PRECOMPILE is False  # Should return default
-
 
 def test_integer_env_vars(monkeypatch: pytest.MonkeyPatch):
     # Ensure clean environment for integer vars by setting to defaults
@@ -256,7 +179,7 @@ def test_disaggregated_serving_env_vars(monkeypatch: pytest.MonkeyPatch):
 
 def test_model_impl_type_default(monkeypatch: pytest.MonkeyPatch):
     monkeypatch.delenv("MODEL_IMPL_TYPE", raising=False)
-    assert envs.MODEL_IMPL_TYPE == "auto"
+    assert envs.MODEL_IMPL_TYPE == "flax_nnx"
 
 
 def test_cache_preserves_values_across_env_changes(

tests/test_quantization.py

@@ -112,8 +112,6 @@ class TestQwixQuantizeNnxModel(unittest.TestCase):
         self.mesh = Mesh(jax.devices(), ('model', ))
         self.rng = jax.random.PRNGKey(0)
         self.model = SimpleModel(rngs=nnx.Rngs(0))
-        self.model.vllm_config = MagicMock()
-        self.model.vllm_config.model_config.use_mla = False
 
         self.qwix_config = [
             {
@@ -133,7 +131,6 @@ class TestQwixQuantizeNnxModel(unittest.TestCase):
         """Test that qwix.quantize_model is called with the correct arguments."""
         quantized_model_mock = MagicMock(spec=nnx.Module)
         mock_quantize_model.return_value = quantized_model_mock
-        self.model.vllm_config.sharding_config.total_dp_size = 1
 
         with patch(
                 "tpu_inference.models.jax.utils.quantization.quantization_utils.init_logger",

tpu_inference/distributed/tpu_connector.py

@@ -694,9 +694,9 @@ class TPUConnectorWorker:
 
 def get_uuid() -> int:
     int128 = uuid4().int
-    # Must be less than 64-bit int, otherwise vllm output encoder would raise error.
-    # use 50 bit to avoid GO trunk the int when doing JSon serialization
-    return int128 >> 78
+    # Must be 64-bit int, otherwise vllm output encoder would raise error.
+    int64 = int128 >> 64
+    return int64
 
 
 @jax.jit
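
The revised `get_uuid` keeps the top 64 bits of the 128-bit UUID instead of the top 50 (the old code shifted by 78 so that Go JSON serialization would not truncate the value). A quick standalone check of the bit-width arithmetic:

from uuid import uuid4

int128 = uuid4().int   # 128-bit value
old = int128 >> 78     # previous build: at most 50 bits survive
new = int128 >> 64     # this release: at most 64 bits survive
assert old < 2**50 and new < 2**64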

tpu_inference/envs.py

@@ -16,7 +16,7 @@ if TYPE_CHECKING:
     DECODE_SLICES: str = ""
     SKIP_JAX_PRECOMPILE: bool = False
     VLLM_XLA_CHECK_RECOMPILATION: bool = False
-    MODEL_IMPL_TYPE: str = "auto"
+    MODEL_IMPL_TYPE: str = "flax_nnx"
     NEW_MODEL_DESIGN: bool = False
     PHASED_PROFILING_DIR: str = ""
     PYTHON_TRACER_LEVEL: int = 1
@@ -24,7 +24,6 @@ if TYPE_CHECKING:
     NUM_SLICES: int = 1
     RAY_USAGE_STATS_ENABLED: str = "0"
     VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: str = "shm"
-    ENABLE_QUANTIZED_MATMUL_KERNEL: bool = False
 
 
 def env_with_choices(
@@ -70,34 +69,6 @@
     return _get_validated_env
 
 
-def env_bool(env_name: str, default: bool = False) -> Callable[[], bool]:
-    """
-    Accepts both numeric strings ("0", "1") and boolean strings
-    ("true", "false", "True", "False").
-
-    Args:
-        env_name: Name of the environment variable
-        default: Default boolean value if not set
-    """
-
-    def _get_bool_env() -> bool:
-        value = os.getenv(env_name)
-        if value is None or value == "":
-            return default
-
-        value_lower = value.lower()
-        if value_lower in ("true", "1"):
-            return True
-        elif value_lower in ("false", "0"):
-            return False
-        else:
-            raise ValueError(
-                f"Invalid boolean value '{value}' for {env_name}. "
-                f"Valid options: '0', '1', 'true', 'false', 'True', 'False'.")
-
-    return _get_bool_env
-
-
 environment_variables: dict[str, Callable[[], Any]] = {
     # JAX platform selection (e.g., "tpu", "cpu", "proxy")
     "JAX_PLATFORMS":
@@ -122,17 +93,17 @@ environment_variables: dict[str, Callable[[], Any]] = {
     lambda: os.getenv("DECODE_SLICES", ""),
     # Skip JAX precompilation step during initialization
     "SKIP_JAX_PRECOMPILE":
-    env_bool("SKIP_JAX_PRECOMPILE", default=False),
+    lambda: bool(int(os.getenv("SKIP_JAX_PRECOMPILE") or "0")),
     # Check for XLA recompilation during execution
     "VLLM_XLA_CHECK_RECOMPILATION":
-    env_bool("VLLM_XLA_CHECK_RECOMPILATION", default=False),
+    lambda: bool(int(os.getenv("VLLM_XLA_CHECK_RECOMPILATION") or "0")),
     # Model implementation type (e.g., "flax_nnx")
     "MODEL_IMPL_TYPE":
-    env_with_choices("MODEL_IMPL_TYPE", "auto",
-                     ["auto", "vllm", "flax_nnx", "jetpack"]),
+    env_with_choices("MODEL_IMPL_TYPE", "flax_nnx",
+                     ["vllm", "flax_nnx", "jetpack"]),
     # Enable new experimental model design
     "NEW_MODEL_DESIGN":
-    env_bool("NEW_MODEL_DESIGN", default=False),
+    lambda: bool(int(os.getenv("NEW_MODEL_DESIGN") or "0")),
     # Directory to store phased profiling output
     "PHASED_PROFILING_DIR":
     lambda: os.getenv("PHASED_PROFILING_DIR", ""),
@@ -141,7 +112,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
     lambda: int(os.getenv("PYTHON_TRACER_LEVEL") or "1"),
     # Use custom expert-parallel kernel for MoE (Mixture of Experts)
     "USE_MOE_EP_KERNEL":
-    env_bool("USE_MOE_EP_KERNEL", default=False),
+    lambda: bool(int(os.getenv("USE_MOE_EP_KERNEL") or "0")),
     # Number of TPU slices for multi-slice mesh
     "NUM_SLICES":
     lambda: int(os.getenv("NUM_SLICES") or "1"),
@@ -151,8 +122,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
     # Ray compiled DAG channel type for TPU
     "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE":
     env_with_choices("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "shm", ["shm"]),
-    "ENABLE_QUANTIZED_MATMUL_KERNEL":
-    lambda: bool(int(os.getenv("ENABLE_QUANTIZED_MATMUL_KERNEL") or "0")),
 }
 
 
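Note the behavior change in the boolean variables: the removed `env_bool` helper accepted "true"/"false" in any case as well as "0"/"1" and raised a descriptive ValueError on anything else, while the restored `bool(int(...))` form accepts only numeric strings. A small standalone sketch mirroring the new lambdas:

import os

def parse(name: str) -> bool:
    # Same parsing as the new envs.py entries above.
    return bool(int(os.getenv(name) or "0"))

os.environ["SKIP_JAX_PRECOMPILE"] = "1"
assert parse("SKIP_JAX_PRECOMPILE") is True

os.environ["SKIP_JAX_PRECOMPILE"] = ""   # unset/empty falls back to "0"
assert parse("SKIP_JAX_PRECOMPILE") is False

# "true" was accepted by the old env_bool, but int("true") now raises ValueError.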

tpu_inference/executors/ray_distributed_executor.py

@@ -145,9 +145,6 @@ class RayDistributedExecutor(RayDistributedExecutorV1):
                 device_str: node['Resources'][device_str]
             } for node in ray_nodes]
         else:
-            assert pp_size == len(
-                ray_nodes
-            ), f"Cannot use PP across hosts, please set --pipeline-parallel-size to 1 or {len(ray_nodes)}"
             num_devices_per_pp_rank = self.vllm_config.sharding_config.total_devices
             placement_group_specs = [{
                 device_str: num_devices_per_pp_rank