tpu-inference 0.13.2.dev20260104__tar.gz → 0.13.2rc3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (278)
  1. {tpu_inference-0.13.2.dev20260104/tpu_inference.egg-info → tpu_inference-0.13.2rc3}/PKG-INFO +1 -1
  2. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/e2e/test_speculative_decoding.py +2 -2
  3. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/layers/vllm/test_awq.py +5 -6
  4. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/layers/vllm/test_compressed_tensors_moe.py +3 -0
  5. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +9 -32
  6. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +4 -6
  7. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/layers/vllm/test_mxfp4.py +5 -13
  8. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/layers/vllm/test_unquantized.py +16 -27
  9. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/lora/test_layers.py +3 -5
  10. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/executors/ray_distributed_executor.py +3 -3
  11. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/layers/common/quantization.py +2 -14
  12. tpu_inference-0.13.2.dev20260104/tpu_inference/layers/common/fused_moe_gmm.py → tpu_inference-0.13.2rc3/tpu_inference/layers/vllm/fused_moe.py +1 -1
  13. tpu_inference-0.13.2rc3/tpu_inference/layers/vllm/linear_common.py +221 -0
  14. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/layers/vllm/quantization/__init__.py +3 -3
  15. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/layers/vllm/quantization/awq.py +81 -81
  16. tpu_inference-0.13.2.dev20260104/tpu_inference/layers/vllm/quantization/configs.py → tpu_inference-0.13.2rc3/tpu_inference/layers/vllm/quantization/common.py +15 -12
  17. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +5 -5
  18. tpu_inference-0.13.2rc3/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
  19. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +91 -97
  20. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +43 -65
  21. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/layers/vllm/quantization/fp8.py +5 -6
  22. tpu_inference-0.13.2rc3/tpu_inference/layers/vllm/quantization/mxfp4.py +410 -0
  23. tpu_inference-0.13.2rc3/tpu_inference/layers/vllm/quantization/unquantized.py +428 -0
  24. tpu_inference-0.13.2.dev20260104/tpu_inference/layers/vllm/process_weights/cleanup_sharding.py → tpu_inference-0.13.2rc3/tpu_inference/layers/vllm/sharding.py +12 -4
  25. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/models/common/model_loader.py +1 -6
  26. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/models/vllm/vllm_model_wrapper.py +1 -2
  27. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/platforms/tpu_platform.py +7 -0
  28. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/runner/compilation_manager.py +4 -10
  29. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/runner/lora_utils.py +1 -2
  30. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/runner/multimodal_manager.py +1 -1
  31. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3/tpu_inference.egg-info}/PKG-INFO +1 -1
  32. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference.egg-info/SOURCES.txt +3 -8
  33. tpu_inference-0.13.2.dev20260104/tpu_inference/layers/common/utils.py +0 -94
  34. tpu_inference-0.13.2.dev20260104/tpu_inference/layers/vllm/fused_moe.py +0 -114
  35. tpu_inference-0.13.2.dev20260104/tpu_inference/layers/vllm/linear.py +0 -64
  36. tpu_inference-0.13.2.dev20260104/tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +0 -369
  37. tpu_inference-0.13.2.dev20260104/tpu_inference/layers/vllm/process_weights/linear_weights.py +0 -174
  38. tpu_inference-0.13.2.dev20260104/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +0 -199
  39. tpu_inference-0.13.2.dev20260104/tpu_inference/layers/vllm/quantization/mxfp4.py +0 -225
  40. tpu_inference-0.13.2.dev20260104/tpu_inference/layers/vllm/quantization/unquantized.py +0 -298
  41. tpu_inference-0.13.2.dev20260104/tpu_inference/worker/__init__.py +0 -13
  42. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/LICENSE +0 -0
  43. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/MANIFEST.in +0 -0
  44. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/README.md +0 -0
  45. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/pyproject.toml +0 -0
  46. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/requirements.txt +0 -0
  47. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/setup.cfg +0 -0
  48. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/setup.py +0 -0
  49. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/__init__.py +0 -0
  50. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/core/__init__.py +0 -0
  51. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/core/test_core_tpu.py +0 -0
  52. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/core/test_disagg_executor.py +0 -0
  53. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/core/test_disagg_utils.py +0 -0
  54. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/core/test_dp_scheduler.py +0 -0
  55. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/core/test_init.py +0 -0
  56. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/distributed/__init__.py +0 -0
  57. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/distributed/test_distributed_utils.py +0 -0
  58. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/distributed/test_tpu_connector.py +0 -0
  59. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/e2e/__init__.py +0 -0
  60. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/e2e/test_async_scheduler.py +0 -0
  61. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/e2e/test_data_parallel.py +0 -0
  62. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/e2e/test_hybrid_kvcache.py +0 -0
  63. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/e2e/test_local_disagg.py +0 -0
  64. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/e2e/test_model_loader.py +0 -0
  65. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/e2e/test_multi_modal_inference.py +0 -0
  66. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/e2e/test_pipeline_parallel.py +0 -0
  67. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/e2e/test_runai_model_streamer_loader.py +0 -0
  68. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/e2e/test_sampling_params.py +0 -0
  69. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/e2e/test_structured_decoding.py +0 -0
  70. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/executors/__init__.py +0 -0
  71. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/executors/test_ray_distributed_executor.py +0 -0
  72. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/experimental/__init__.py +0 -0
  73. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/experimental/test_llama3_jax_stashed.py +0 -0
  74. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/kernels/__init__.py +0 -0
  75. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/kernels/collectives/__init__.py +0 -0
  76. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/kernels/collectives/all_gather_matmul_kernel_test.py +0 -0
  77. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/kernels/fused_moe_v1_test.py +0 -0
  78. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/kernels/gmm_test.py +0 -0
  79. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/kernels/mla_v1_test.py +0 -0
  80. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/kernels/quantized_matmul_kernel_test.py +0 -0
  81. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/kernels/ragged_kv_cache_update_v2_test.py +0 -0
  82. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/kernels/ragged_paged_attention_kernel_v2_test.py +0 -0
  83. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +0 -0
  84. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/kernels/ragged_paged_attention_kernel_v3_test.py +0 -0
  85. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/layers/__init__.py +0 -0
  86. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/layers/common/__init__.py +0 -0
  87. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/layers/common/test_attention_interface.py +0 -0
  88. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/layers/common/test_quantization.py +0 -0
  89. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/layers/jax/__init__.py +0 -0
  90. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/layers/jax/attention/__init__.py +0 -0
  91. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/layers/jax/attention/test_common_attention.py +0 -0
  92. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/layers/jax/attention/test_deepseek_v3_attention.py +0 -0
  93. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/layers/jax/attention/test_llama4_attention.py +0 -0
  94. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/layers/jax/moe/__init__.py +0 -0
  95. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/layers/jax/moe/test_deepseek_moe.py +0 -0
  96. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/layers/jax/sample/__init__.py +0 -0
  97. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/layers/jax/sample/test_rejection_sampler.py +0 -0
  98. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/layers/jax/sample/test_sampling.py +0 -0
  99. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/layers/jax/sample/test_sampling_metadata.py +0 -0
  100. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/layers/jax/test_layers.py +0 -0
  101. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/layers/jax/test_qwix.py +0 -0
  102. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/layers/jax/test_rope.py +0 -0
  103. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/layers/jax/test_sharding.py +0 -0
  104. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/layers/jax/test_transformer_block.py +0 -0
  105. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/layers/vllm/__init__.py +0 -0
  106. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/layers/vllm/test_attention.py +0 -0
  107. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/layers/vllm/test_fp8.py +0 -0
  108. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/layers/vllm/utils.py +0 -0
  109. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/lora/__init__.py +0 -0
  110. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/lora/conftest.py +0 -0
  111. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/lora/test_bgmv.py +0 -0
  112. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/lora/test_lora.py +0 -0
  113. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/lora/test_lora_perf.py +0 -0
  114. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/lora/utils.py +0 -0
  115. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/models/__init__.py +0 -0
  116. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/models/common/__init__.py +0 -0
  117. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/models/common/test_model_loader.py +0 -0
  118. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/models/jax/__init__.py +0 -0
  119. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/models/jax/test_deepseek_v3.py +0 -0
  120. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/models/jax/test_llama3.py +0 -0
  121. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/models/jax/test_llama4.py +0 -0
  122. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/models/jax/test_llama_eagle3.py +0 -0
  123. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/models/jax/test_llama_guard_4.py +0 -0
  124. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/models/jax/test_qwen2.py +0 -0
  125. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/models/jax/test_qwen2_5_vl.py +0 -0
  126. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/models/jax/test_qwen3.py +0 -0
  127. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/models/jax/test_weight_loading.py +0 -0
  128. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/models/jax/utils/__init__.py +0 -0
  129. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/models/jax/utils/test_multi_modal_utils.py +0 -0
  130. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/platforms/__init__.py +0 -0
  131. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/platforms/test_tpu_platform.py +0 -0
  132. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/runner/__init__.py +0 -0
  133. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/runner/test_block_table.py +0 -0
  134. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/runner/test_input_batch.py +0 -0
  135. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/runner/test_kv_cache.py +0 -0
  136. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/runner/test_kv_cache_manager.py +0 -0
  137. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/runner/test_multimodal_manager.py +0 -0
  138. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/runner/test_persistent_batch_manager.py +0 -0
  139. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/runner/test_speculative_decoding_manager.py +0 -0
  140. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/runner/test_structured_decoding_manager.py +0 -0
  141. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/runner/test_tpu_runner.py +0 -0
  142. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/runner/test_tpu_runner_dp.py +0 -0
  143. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/runner/test_tpu_runner_mesh.py +0 -0
  144. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/runner/test_utils.py +0 -0
  145. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/spec_decode/__init__.py +0 -0
  146. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/spec_decode/test_eagle3.py +0 -0
  147. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/test_base.py +0 -0
  148. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/test_envs.py +0 -0
  149. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/test_tpu_info.py +0 -0
  150. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/test_utils.py +0 -0
  151. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/worker/__init__.py +0 -0
  152. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tests/worker/tpu_worker_test.py +0 -0
  153. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/__init__.py +0 -0
  154. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/core/__init__.py +0 -0
  155. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/core/core_tpu.py +0 -0
  156. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/core/disagg_executor.py +0 -0
  157. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/core/disagg_utils.py +0 -0
  158. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/core/sched/__init__.py +0 -0
  159. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/core/sched/dp_scheduler.py +0 -0
  160. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/distributed/__init__.py +0 -0
  161. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/distributed/jax_parallel_state.py +0 -0
  162. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/distributed/tpu_connector.py +0 -0
  163. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/distributed/utils.py +0 -0
  164. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/env_override.py +0 -0
  165. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/envs.py +0 -0
  166. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/executors/__init__.py +0 -0
  167. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/experimental/__init__.py +0 -0
  168. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/experimental/llama3_jax_stashed.py +0 -0
  169. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/kernels/__init__.py +0 -0
  170. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/kernels/collectives/__init__.py +0 -0
  171. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/kernels/collectives/all_gather_matmul.py +0 -0
  172. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +0 -0
  173. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/kernels/collectives/util.py +0 -0
  174. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/kernels/flash_attention/__init__.py +0 -0
  175. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/kernels/flash_attention/kernel.py +0 -0
  176. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/kernels/fused_moe/__init__.py +0 -0
  177. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
  178. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/kernels/fused_moe/v1/kernel.py +0 -0
  179. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/kernels/megablox/__init__.py +0 -0
  180. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/kernels/megablox/common.py +0 -0
  181. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/kernels/megablox/gmm.py +0 -0
  182. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/kernels/mla/__init__.py +0 -0
  183. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/kernels/mla/v1/__init__.py +0 -0
  184. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/kernels/mla/v1/kernel.py +0 -0
  185. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  186. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/kernels/quantized_matmul/kernel.py +0 -0
  187. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +0 -0
  188. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/kernels/quantized_matmul/util.py +0 -0
  189. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  190. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  191. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +0 -0
  192. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +0 -0
  193. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +0 -0
  194. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  195. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +0 -0
  196. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +0 -0
  197. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +0 -0
  198. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +0 -0
  199. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/kernels/ragged_paged_attention/v3/util.py +0 -0
  200. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/layers/__init__.py +0 -0
  201. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/layers/common/__init__.py +0 -0
  202. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/layers/common/attention_interface.py +0 -0
  203. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/layers/common/attention_metadata.py +0 -0
  204. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/layers/common/binary_search.py +0 -0
  205. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/layers/common/quant_methods.py +0 -0
  206. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/layers/common/sharding.py +0 -0
  207. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/layers/jax/__init__.py +0 -0
  208. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/layers/jax/attention/__init__.py +0 -0
  209. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/layers/jax/attention/attention.py +0 -0
  210. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/layers/jax/attention/deepseek_v3_attention.py +0 -0
  211. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/layers/jax/attention/gpt_oss_attention.py +0 -0
  212. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/layers/jax/attention/llama4_attention.py +0 -0
  213. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/layers/jax/base.py +0 -0
  214. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/layers/jax/constants.py +0 -0
  215. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/layers/jax/layers.py +0 -0
  216. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/layers/jax/misc.py +0 -0
  217. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/layers/jax/moe/__init__.py +0 -0
  218. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/layers/jax/moe/deepseek_v3_moe.py +0 -0
  219. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/layers/jax/moe/gpt_oss_moe.py +0 -0
  220. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/layers/jax/moe/moe.py +0 -0
  221. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/layers/jax/pp_utils.py +0 -0
  222. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/layers/jax/rope.py +0 -0
  223. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/layers/jax/rope_interface.py +0 -0
  224. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/layers/jax/sample/__init__.py +0 -0
  225. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/layers/jax/sample/rejection_sampler.py +0 -0
  226. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/layers/jax/sample/sampling.py +0 -0
  227. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/layers/jax/sample/sampling_metadata.py +0 -0
  228. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/layers/jax/transformer_block.py +0 -0
  229. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/layers/vllm/__init__.py +0 -0
  230. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/layers/vllm/attention.py +0 -0
  231. {tpu_inference-0.13.2.dev20260104/tpu_inference/layers/vllm/process_weights → tpu_inference-0.13.2rc3/tpu_inference/layers/vllm/quantization/compressed_tensors}/__init__.py +0 -0
  232. {tpu_inference-0.13.2.dev20260104/tpu_inference/layers/vllm/quantization/compressed_tensors → tpu_inference-0.13.2rc3/tpu_inference/layers/vllm/quantization/compressed_tensors/schemes}/__init__.py +0 -0
  233. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/logger.py +0 -0
  234. {tpu_inference-0.13.2.dev20260104/tpu_inference/layers/vllm/quantization/compressed_tensors/schemes → tpu_inference-0.13.2rc3/tpu_inference/lora}/__init__.py +0 -0
  235. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/lora/torch_lora_ops.py +0 -0
  236. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/lora/torch_punica_tpu.py +0 -0
  237. {tpu_inference-0.13.2.dev20260104/tpu_inference/lora → tpu_inference-0.13.2rc3/tpu_inference/models}/__init__.py +0 -0
  238. {tpu_inference-0.13.2.dev20260104/tpu_inference/models → tpu_inference-0.13.2rc3/tpu_inference/models/common}/__init__.py +0 -0
  239. {tpu_inference-0.13.2.dev20260104/tpu_inference/models/common → tpu_inference-0.13.2rc3/tpu_inference/models/jax}/__init__.py +0 -0
  240. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/models/jax/deepseek_v3.py +0 -0
  241. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/models/jax/gpt_oss.py +0 -0
  242. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/models/jax/jax_intermediate_tensor.py +0 -0
  243. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/models/jax/llama3.py +0 -0
  244. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/models/jax/llama4.py +0 -0
  245. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/models/jax/llama_eagle3.py +0 -0
  246. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/models/jax/llama_guard_4.py +0 -0
  247. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/models/jax/qwen2.py +0 -0
  248. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/models/jax/qwen2_5_vl.py +0 -0
  249. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/models/jax/qwen3.py +0 -0
  250. {tpu_inference-0.13.2.dev20260104/tpu_inference/models/jax → tpu_inference-0.13.2rc3/tpu_inference/models/jax/utils}/__init__.py +0 -0
  251. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/models/jax/utils/file_utils.py +0 -0
  252. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/models/jax/utils/multi_modal_utils.py +0 -0
  253. {tpu_inference-0.13.2.dev20260104/tpu_inference/models/jax/utils → tpu_inference-0.13.2rc3/tpu_inference/models/jax/utils/qwix}/__init__.py +0 -0
  254. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/models/jax/utils/qwix/qwix_utils.py +0 -0
  255. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/models/jax/utils/weight_utils.py +0 -0
  256. {tpu_inference-0.13.2.dev20260104/tpu_inference/models/jax/utils/qwix → tpu_inference-0.13.2rc3/tpu_inference/models/vllm}/__init__.py +0 -0
  257. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/models/vllm/vllm_model_wrapper_context.py +0 -0
  258. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/platforms/__init__.py +0 -0
  259. {tpu_inference-0.13.2.dev20260104/tpu_inference/models/vllm → tpu_inference-0.13.2rc3/tpu_inference/runner}/__init__.py +0 -0
  260. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/runner/block_table.py +0 -0
  261. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/runner/input_batch.py +0 -0
  262. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/runner/kv_cache.py +0 -0
  263. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/runner/kv_cache_manager.py +0 -0
  264. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/runner/persistent_batch_manager.py +0 -0
  265. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/runner/speculative_decoding_manager.py +0 -0
  266. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/runner/structured_decoding_manager.py +0 -0
  267. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/runner/tpu_runner.py +0 -0
  268. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/runner/utils.py +0 -0
  269. {tpu_inference-0.13.2.dev20260104/tpu_inference/runner → tpu_inference-0.13.2rc3/tpu_inference/spec_decode}/__init__.py +0 -0
  270. {tpu_inference-0.13.2.dev20260104/tpu_inference/spec_decode → tpu_inference-0.13.2rc3/tpu_inference/spec_decode/jax}/__init__.py +0 -0
  271. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/spec_decode/jax/eagle3.py +0 -0
  272. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/tpu_info.py +0 -0
  273. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/utils.py +0 -0
  274. {tpu_inference-0.13.2.dev20260104/tpu_inference/spec_decode/jax → tpu_inference-0.13.2rc3/tpu_inference/worker}/__init__.py +0 -0
  275. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference/worker/tpu_worker.py +0 -0
  276. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference.egg-info/dependency_links.txt +0 -0
  277. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference.egg-info/requires.txt +0 -0
  278. {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc3}/tpu_inference.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: tpu_inference
- Version: 0.13.2.dev20260104
+ Version: 0.13.2rc3
  Author: tpu_inference Contributors
  Classifier: Development Status :: 3 - Alpha
  Classifier: Intended Audience :: Developers
@@ -271,7 +271,7 @@ def test_ngram_performance_random(
  "prompt_lookup_max": 2,
  "prompt_lookup_min": 2,
  "num_speculative_tokens": 4,
- }, 1.2 if _is_v7x() else 3.0)
+ }, 1.5 if _is_v7x() else 3.0)


  def test_eagle3_correctness(
@@ -308,4 +308,4 @@ def test_eagle3_performance(
  "model": "unkmaster/EAGLE3-LLaMA3.1-Instruct-8B",
  "num_speculative_tokens": 2,
  "draft_tensor_parallel_size": 1
- }, 0.6 if _is_v7x() else 1.8)
+ }, 1.2 if _is_v7x() else 1.8)
@@ -39,8 +39,7 @@ from vllm.scalar_type import scalar_types
  from tpu_inference.layers.vllm.quantization import get_tpu_quantization_config
  from tpu_inference.layers.vllm.quantization.awq import (VllmAWQConfig,
  VllmAWQLinearMethod)
- from tpu_inference.layers.vllm.quantization.configs import \
- VllmQuantLinearConfig
+ from tpu_inference.layers.vllm.quantization.common import JaxCommonLinearConfig

  from . import utils as test_utils

@@ -104,8 +103,8 @@ def return_ref_and_layer_output(
  assert isinstance(quant_method, VllmAWQLinearMethod)
  quant_config = quant_method.quant_config
  assert isinstance(quant_config, VllmAWQConfig)
- jax_config = quant_method.linear_config
- assert isinstance(jax_config, VllmQuantLinearConfig)
+ jax_config = quant_method.jax_config
+ assert isinstance(jax_config, JaxCommonLinearConfig)

  input_tensor = torch.rand(
  batch_size, layer.input_size, dtype=torch.bfloat16) / 10
@@ -135,8 +134,8 @@ def initialize_and_return_layer_weights(layer: torch.nn.Module):
  assert isinstance(quant_method, VllmAWQLinearMethod)
  quant_config = quant_method.quant_config
  assert isinstance(quant_config, VllmAWQConfig)
- jax_config = quant_method.linear_config
- assert isinstance(jax_config, VllmQuantLinearConfig)
+ jax_config = quant_method.jax_config
+ assert isinstance(jax_config, JaxCommonLinearConfig)

  # torch.rand returns value in the range of [0, 1). We subtract by 0.2 to
  # simulate asymmetry
@@ -12,6 +12,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

+ import os
  import tempfile

  import jax.numpy as jnp
@@ -42,6 +43,8 @@ from . import utils as test_utils

  P = PartitionSpec

+ os.environ['VLLM_DISABLE_SHARED_EXPERTS_STREAM'] = '1'
+
  MODEL = 'BCCard/Qwen3-30B-A3B-FP8-Dynamic'


@@ -16,7 +16,6 @@ import tempfile
  from typing import Optional

  import jax
- import jax.numpy as jnp
  import pytest
  import torch
  import torchax
@@ -37,15 +36,12 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
  CompressedTensorsLinearMethod
  from vllm.model_executor.model_loader import get_model as vllm_get_model

- from tpu_inference.layers.common.quantization import (dequantize_tensor,
- quantize_tensor)
  from tpu_inference.layers.vllm.quantization import get_tpu_quantization_config
+ from tpu_inference.layers.vllm.quantization.common import JaxCommonLinearConfig
  from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors import \
  VllmCompressedTensorsConfig
- from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_fp8 import \
- VllmCompressedTensorsW8A8Fp8
- from tpu_inference.layers.vllm.quantization.configs import \
- VllmQuantLinearConfig
+ from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_fp8 import (
+ VllmCompressedTensorsW8A8Fp8, requantize_with_max_scale)

  from . import utils as test_utils

@@ -102,8 +98,8 @@ def return_ref_and_layer_output(layer: torch.nn.Module, batch_size: int = 16):
  assert isinstance(layer, LinearBase)
  scheme = layer.scheme
  assert isinstance(scheme, VllmCompressedTensorsW8A8Fp8)
- quant_config = scheme.linear_config
- assert isinstance(quant_config, VllmQuantLinearConfig)
+ quant_config = scheme.jax_config
+ assert isinstance(quant_config, JaxCommonLinearConfig)
  quant_method = layer.quant_method
  assert isinstance(quant_method, CompressedTensorsLinearMethod)
  per_tensor = scheme.strategy == QuantizationStrategy.TENSOR
@@ -118,27 +114,8 @@ def return_ref_and_layer_output(layer: torch.nn.Module, batch_size: int = 16):
  # For per_tensor with merged layers, vLLM requantizes them so all merged
  # layers share the same scale values.
  if per_tensor:
- dtype = weight.dtype
-
- weight = t2j(weight)
- weight_scale = t2j(weight_scale)
- weights = []
- start = 0
- # Multiple weights may have been concatenated. Loop through
- # each weight and perform dequantization.
- for i, output_size in enumerate(quant_config.output_sizes):
- end = start + output_size
- weights.append(
- dequantize_tensor(weight[start:end], weight_scale[i]))
- start = end
- weight = jnp.concat(weights, axis=0)
- weight, weight_scale = quantize_tensor(
- jnp.float8_e4m3fn,
- weight,
- None,
- )
- weight = j2t(weight.astype(jnp.float32)).to(dtype)
- weight_scale = j2t(weight_scale)
+ weight_scale, weight = requantize_with_max_scale(
+ layer.weight, layer.weight_scale, quant_config.output_sizes)
  if input_scale is not None:
  input_scale = input_scale.max()

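The hunk above collapses the manual dequantize/concatenate/requantize loop into the requantize_with_max_scale helper now imported from the w8a8 fp8 scheme module. A minimal sketch of the shared-max-scale idea it relies on, assuming an fp8 weight laid out as concatenated merged slices with one scale per slice (names, shapes, and the JAX types here are illustrative, not the package's API):

import jax.numpy as jnp

def requantize_with_max_scale_sketch(weight_q, weight_scale, output_sizes):
    # weight_q:     (sum(output_sizes), in_features) fp8 values, quantized per slice
    # weight_scale: (num_slices,) per-slice scales
    max_scale = weight_scale.max()
    requantized = []
    start = 0
    for i, size in enumerate(output_sizes):
        end = start + size
        # Dequantize the slice with its own scale, then requantize with the
        # max scale so every merged slice shares a single per-tensor scale.
        dequant = weight_q[start:end].astype(jnp.float32) * weight_scale[i]
        requantized.append((dequant / max_scale).astype(jnp.float8_e4m3fn))
        start = end
    return max_scale, jnp.concatenate(requantized, axis=0)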
@@ -174,8 +151,8 @@ def initialize_layer_weights(layer: torch.nn.Module):
  assert isinstance(layer, LinearBase)
  scheme = layer.scheme
  assert isinstance(scheme, VllmCompressedTensorsW8A8Fp8)
- quant_config = scheme.linear_config
- assert isinstance(quant_config, VllmQuantLinearConfig)
+ quant_config = scheme.jax_config
+ assert isinstance(quant_config, JaxCommonLinearConfig)
  per_tensor = scheme.strategy == QuantizationStrategy.TENSOR

  weight_list = []
@@ -185,7 +185,7 @@ def test_row_parallel_linear(model, bias, num_devices, enable_sp,
  if bias:
  jax_row_linear.bias.data = bias_data

- input_tensor = torch.rand(10, jax_row_linear.input_size, dtype=dtype) / 10
+ input_tensor = torch.rand(10, 4096, dtype=dtype) / 10
  input_tensor = input_tensor.to('cpu')

  jax_input_tensor = torch_view(t2j(input_tensor, use_dlpack=False))
@@ -259,8 +259,7 @@ def test_column_parallel_linear(model, bias, num_devices, enable_sp,
  if bias:
  jax_column_linear.bias.data = bias_data

- input_tensor = torch.rand(10, jax_column_linear.input_size,
- dtype=dtype) / 10
+ input_tensor = torch.rand(10, 4096, dtype=dtype) / 10
  input_tensor = input_tensor.to('cpu')

  jax_input_tensor = torch_view(t2j(input_tensor, use_dlpack=False))
@@ -339,7 +338,7 @@ def test_qkv_parallel_linear(model, bias, num_devices, enable_sp, fuse_matmuls,
  if bias:
  jax_qkv_linear.bias.data = bias_data

- input_tensor = torch.rand(10, jax_qkv_linear.input_size, dtype=dtype) / 10
+ input_tensor = torch.rand(10, 4096, dtype=dtype) / 10
  input_tensor = input_tensor.to('cpu')

  jax_input_tensor = torch_view(t2j(input_tensor, use_dlpack=False))
@@ -415,8 +414,7 @@ def test_merged_column_parallel_linear(model, bias, num_devices, fuse_matmuls,
  if bias:
  jax_merged_column_linear.bias.data = bias_data

- input_tensor = torch.rand(
- 10, jax_merged_column_linear.input_size, dtype=dtype) / 10
+ input_tensor = torch.rand(10, 4096, dtype=dtype) / 10
  input_tensor = input_tensor.to('cpu')

  jax_input_tensor = torch_view(t2j(input_tensor, use_dlpack=False))
@@ -13,7 +13,6 @@
  # limitations under the License.

  import tempfile
- from unittest import mock

  import jax
  import jax.numpy as jnp
@@ -30,7 +29,6 @@ from vllm.engine.arg_utils import EngineArgs
  from vllm.forward_context import set_forward_context
  from vllm.model_executor.layers.fused_moe.layer import FusedMoE

- from tpu_inference.layers.vllm.fused_moe import FusedMoEBackend
  from tpu_inference.layers.vllm.quantization import get_tpu_quantization_config
  from tpu_inference.layers.vllm.quantization.mxfp4 import (VllmMxfp4Config,
  VllmMxfp4MoEMethod)
@@ -162,8 +160,6 @@ def test_mxfp4_fused_moe(num_devices, num_tokens, intermediate_size,
  )
  vllm_config = engine_args.create_engine_config()
  vllm_config.model_config.dtype = dtype
- vllm_config.parallel_config = ParallelConfig(
- tensor_parallel_size=mesh.devices.size, enable_expert_parallel=use_ep)

  quant_config = get_tpu_quantization_config(vllm_config, mesh)
  with set_current_vllm_config(vllm_config):
@@ -194,16 +190,13 @@ def test_mxfp4_fused_moe(num_devices, num_tokens, intermediate_size,

  with torchax.default_env(), set_forward_context(None, vllm_config):
  assert isinstance(vllm_fused_moe.quant_method, VllmMxfp4MoEMethod)
- if use_ep:
- assert vllm_fused_moe.quant_method.moe_backend == FusedMoEBackend.GMM_EP
- else:
- assert vllm_fused_moe.quant_method.moe_backend == FusedMoEBackend.GMM_TP

  jax_a = a.to('jax')
  score = score.to('jax')

  vllm_fused_moe.quant_method.process_weights_after_loading(
  vllm_fused_moe)
+
  actual = vllm_fused_moe(jax_a, score)

  torch.testing.assert_close(expected,
@@ -220,7 +213,6 @@ def test_mxfp4_fused_moe(num_devices, num_tokens, intermediate_size,
  @pytest.mark.parametrize("num_experts", [8])
  @pytest.mark.parametrize("topk", [2])
  @pytest.mark.parametrize("enable_attn_dp", [False, True])
- @mock.patch("os.environ", {"USE_MOE_EP_KERNEL": "1"})
  def test_mxfp4_fused_moe_use_kernel(num_devices, num_tokens, intermediate_size,
  hidden_size, num_experts, topk,
  enable_attn_dp):
@@ -261,7 +253,7 @@ def test_mxfp4_fused_moe_use_kernel(num_devices, num_tokens, intermediate_size,
  vllm_config = engine_args.create_engine_config()
  vllm_config.model_config.dtype = dtype
  vllm_config.parallel_config = ParallelConfig(
- tensor_parallel_size=mesh.devices.size, enable_expert_parallel=True)
+ tensor_parallel_size=mesh.devices.size)

  quant_config = get_tpu_quantization_config(vllm_config, mesh)
  with set_current_vllm_config(vllm_config):
@@ -293,14 +285,14 @@ def test_mxfp4_fused_moe_use_kernel(num_devices, num_tokens, intermediate_size,

  with torchax.default_env(), set_forward_context(None, vllm_config):
  assert isinstance(vllm_fused_moe.quant_method, VllmMxfp4MoEMethod)
- assert vllm_fused_moe.quant_method.moe_backend == FusedMoEBackend.FUSED_MOE

  jax_a = a.to('jax')
  score = score.to('jax')

+ vllm_fused_moe.quant_method.use_kernel = True
  vllm_fused_moe.quant_method.process_weights_after_loading(
  vllm_fused_moe)
- vllm_fused_moe.quant_method.extra_backend_kwargs.update({
+ vllm_fused_moe.quant_method.block_size = {
  "bt": 32,
  "bf": 512,
  "bd1": 1024,
@@ -309,7 +301,7 @@ def test_mxfp4_fused_moe_use_kernel(num_devices, num_tokens, intermediate_size,
  "bfc": 512,
  "bd1c": 1024,
  "bd2c": 1024,
- })
+ }

  actual = vllm_fused_moe(jax_a, score)

@@ -13,7 +13,6 @@
  # limitations under the License.

  import tempfile
- from unittest import mock

  import jax
  import pytest
@@ -36,7 +35,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
  RowParallelLinear)
  from vllm.model_executor.model_loader import get_model as vllm_get_model

- from tpu_inference.layers.vllm.fused_moe import FusedMoEBackend
  from tpu_inference.layers.vllm.quantization import get_tpu_quantization_config
  from tpu_inference.layers.vllm.quantization.unquantized import (
  VllmUnquantizedConfig, VllmUnquantizedFusedMoEMethod,
@@ -141,6 +139,9 @@ def test_row_parallel_linear(model, bias, num_devices, enable_sp,
  vllm_config = engine_args.create_engine_config()
  vllm_config.compilation_config.pass_config.enable_sp = enable_sp

+ input_tensor = torch.rand(10, 4096, dtype=dtype) / 10
+ input_tensor = input_tensor.to('cpu')
+
  with set_current_vllm_config(vllm_config):
  row_linear = RowParallelLinear(
  input_size=4096,
@@ -150,9 +151,6 @@ def test_row_parallel_linear(model, bias, num_devices, enable_sp,
  return_bias=False,
  )

- input_tensor = torch.rand(10, row_linear.input_size, dtype=dtype) / 10
- input_tensor = input_tensor.to('cpu')
-
  weight_data = torch.rand_like(row_linear.weight.data) / 10
  if bias:
  bias_data = torch.rand_like(row_linear.bias.data)
@@ -218,6 +216,9 @@ def test_column_parallel_linear(model, bias, num_devices, enable_sp,
  vllm_config = engine_args.create_engine_config()
  vllm_config.compilation_config.pass_config.enable_sp = enable_sp

+ input_tensor = torch.rand(10, 4096, dtype=dtype) / 10
+ input_tensor = input_tensor.to('cpu')
+
  with set_current_vllm_config(vllm_config):
  column_linear = ColumnParallelLinear(
  input_size=4096,
@@ -227,9 +228,6 @@ def test_column_parallel_linear(model, bias, num_devices, enable_sp,
  return_bias=False,
  )

- input_tensor = torch.rand(10, column_linear.input_size, dtype=dtype) / 10
- input_tensor = input_tensor.to('cpu')
-
  weight_data = torch.rand_like(column_linear.weight.data) / 10
  if bias:
  bias_data = torch.rand_like(column_linear.bias.data)
@@ -295,6 +293,9 @@ def test_qkv_parallel_linear(model, bias, num_devices, enable_sp, fuse_matmuls,
  vllm_config = engine_args.create_engine_config()
  vllm_config.compilation_config.pass_config.enable_sp = enable_sp

+ input_tensor = torch.rand(10, 4096, dtype=dtype) / 10
+ input_tensor = input_tensor.to('cpu')
+
  with set_current_vllm_config(vllm_config):
  qkv_linear = QKVParallelLinear(
  hidden_size=4096,
@@ -306,9 +307,6 @@ def test_qkv_parallel_linear(model, bias, num_devices, enable_sp, fuse_matmuls,
  return_bias=False,
  )

- input_tensor = torch.rand(10, qkv_linear.input_size, dtype=dtype) / 10
- input_tensor = input_tensor.to('cpu')
-
  weight_data = torch.rand_like(qkv_linear.weight.data) / 10
  if bias:
  bias_data = torch.rand_like(qkv_linear.bias.data)
@@ -377,6 +375,9 @@ def test_merged_column_parallel_linear(model, bias, num_devices, fuse_matmuls,
  vllm_config = engine_args.create_engine_config()
  vllm_config.compilation_config.pass_config.enable_sp = enable_sp

+ input_tensor = torch.rand(10, 4096, dtype=dtype) / 10
+ input_tensor = input_tensor.to('cpu')
+
  # Call vLLM code
  with set_current_vllm_config(vllm_config):
  merged_column_linear = MergedColumnParallelLinear(
@@ -387,10 +388,6 @@ def test_merged_column_parallel_linear(model, bias, num_devices, fuse_matmuls,
  return_bias=False,
  )

- input_tensor = torch.rand(10, merged_column_linear.input_size,
- dtype=dtype) / 10
- input_tensor = input_tensor.to('cpu')
-
  weight_data = torch.rand_like(merged_column_linear.weight.data) / 10
  if bias:
  bias_data = torch.rand_like(merged_column_linear.bias.data)
@@ -478,8 +475,6 @@ def test_fused_moe(use_ep, num_devices, num_tokens, intermediate_size,
  )
  vllm_config = engine_args.create_engine_config()
  vllm_config.model_config.dtype = dtype
- vllm_config.parallel_config = ParallelConfig(
- tensor_parallel_size=mesh.devices.size, enable_expert_parallel=use_ep)

  quant_config = get_tpu_quantization_config(vllm_config, mesh)
  with set_current_vllm_config(vllm_config):
@@ -511,10 +506,6 @@ def test_fused_moe(use_ep, num_devices, num_tokens, intermediate_size,
  with torchax.default_env(), set_forward_context(None, vllm_config):
  assert isinstance(vllm_fused_moe.quant_method,
  VllmUnquantizedFusedMoEMethod)
- if use_ep:
- assert vllm_fused_moe.quant_method.moe_backend == FusedMoEBackend.GMM_EP
- else:
- assert vllm_fused_moe.quant_method.moe_backend == FusedMoEBackend.GMM_TP

  jax_a = a.to('jax')
  score = score.to('jax')
@@ -538,7 +529,6 @@ def test_fused_moe(use_ep, num_devices, num_tokens, intermediate_size,
  @pytest.mark.parametrize("topk", [8])
  @pytest.mark.parametrize("has_bias", [False, True])
  @pytest.mark.parametrize("enable_attn_dp", [False, True])
- @mock.patch("os.environ", {"USE_MOE_EP_KERNEL": "1"})
  def test_fused_moe_use_kernel(num_devices, num_tokens, intermediate_size,
  hidden_size, num_experts, topk, has_bias,
  enable_attn_dp):
@@ -602,7 +592,7 @@ def test_fused_moe_use_kernel(num_devices, num_tokens, intermediate_size,
  vllm_config = engine_args.create_engine_config()
  vllm_config.model_config.dtype = dtype
  vllm_config.parallel_config = ParallelConfig(
- tensor_parallel_size=mesh.devices.size, enable_expert_parallel=True)
+ tensor_parallel_size=mesh.devices.size)

  quant_config = get_tpu_quantization_config(vllm_config, mesh)
  with set_current_vllm_config(vllm_config):
@@ -619,6 +609,7 @@ def test_fused_moe_use_kernel(num_devices, num_tokens, intermediate_size,
  has_bias=has_bias,
  )
  vllm_fused_moe.moe_parallel_config.use_ep = True
+ vllm_fused_moe.quant_method.use_kernel = True

  vllm_fused_moe.w13_weight.data = w1
  vllm_fused_moe.w2_weight.data = w2
@@ -634,14 +625,12 @@ def test_fused_moe_use_kernel(num_devices, num_tokens, intermediate_size,
  with torchax.default_env(), set_forward_context(None, vllm_config):
  assert isinstance(vllm_fused_moe.quant_method,
  VllmUnquantizedFusedMoEMethod)
- assert vllm_fused_moe.quant_method.moe_backend == FusedMoEBackend.FUSED_MOE
-
  jax_a = a.to('jax')
  score = score.to('jax')

  vllm_fused_moe.quant_method.process_weights_after_loading(
  vllm_fused_moe)
- vllm_fused_moe.quant_method.extra_backend_kwargs.update({
+ vllm_fused_moe.quant_method.block_size = {
  "bt": 32,
  "bf": 512,
  "bd1": 512,
@@ -650,7 +639,7 @@ def test_fused_moe_use_kernel(num_devices, num_tokens, intermediate_size,
  "bfc": 256,
  "bd1c": 256,
  "bd2c": 256,
- })
+ }
  actual = vllm_fused_moe(jax_a, score)

  torch.testing.assert_close(
@@ -42,12 +42,10 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
  from vllm.model_executor.utils import set_random_seed
  from vllm.platforms import current_platform

- from tpu_inference.layers.vllm.process_weights.cleanup_sharding import \
- _shard_module_to_tpu
- from tpu_inference.layers.vllm.quantization.configs import \
- VllmQuantLinearConfig
+ from tpu_inference.layers.vllm.quantization.common import JaxCommonLinearConfig
  from tpu_inference.layers.vllm.quantization.unquantized import \
  VllmUnquantizedLinearMethod
+ from tpu_inference.layers.vllm.sharding import _shard_module_to_tpu

  from .utils import DummyLoRAManager

@@ -631,7 +629,7 @@ def _create_lora_wrapper(linear,
  mesh,
  repeats=1):
  base_linear.weight.data = linear.weight.data
- jax_config = VllmQuantLinearConfig(vllm_config, mesh, base_linear)
+ jax_config = JaxCommonLinearConfig(vllm_config, mesh, base_linear)
  linear_method = VllmUnquantizedLinearMethod(jax_config)
  base_linear.quant_method = linear_method
  linear_method.process_weights_after_loading(
@@ -20,7 +20,7 @@ import ray
  import vllm.envs as envs
  from ray.util.placement_group import PlacementGroup
  from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
- from vllm.multimodal.inputs import MultiModalKwargsItem
+ from vllm.multimodal.inputs import MultiModalKwargs
  from vllm.platforms import current_platform
  from vllm.ray.ray_env import get_env_vars_to_copy
  from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE
@@ -53,7 +53,7 @@ logger = init_logger(__name__)


  def _encode_hook(obj: Any) -> Any:
- """Custom msgspec enc hook that supports array types and MultiModalKwargsItem.
+ """Custom msgspec enc hook that supports array types and MultiModalKwargs.

  See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder
  """
@@ -62,7 +62,7 @@ def _encode_hook(obj: Any) -> Any:
  f"vLLM array type should use '{VLLM_TOKEN_ID_ARRAY_TYPE}' type. "
  f"Given array has a type code of {obj.typecode}.")
  return obj.tobytes()
- if isinstance(obj, MultiModalKwargsItem):
+ if isinstance(obj, MultiModalKwargs):
  return dict(obj)

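For context on the hook edited above: msgspec only serializes types it knows about, and enc_hook is its documented escape hatch for everything else (see the msgspec link in the docstring). A minimal standalone sketch of the same pattern, independent of vLLM's types:

import array
import msgspec

def enc_hook(obj):
    # Called by msgspec for any object it cannot encode natively.
    if isinstance(obj, array.array):
        return obj.tobytes()  # token-id arrays travel as raw bytes
    raise NotImplementedError(f"Objects of type {type(obj)} are not supported")

encoder = msgspec.msgpack.Encoder(enc_hook=enc_hook)
payload = encoder.encode({"token_ids": array.array("l", [1, 2, 3])})
print(len(payload))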
@@ -52,7 +52,7 @@ def quantize_tensor_to_mxfp4_packed(


  def u8_unpack_e2m1(u8_packed_e2m1: jax.Array) -> jax.Array:
- """Unpack e2m1 tensor that was packed into u8."""
+ """Unpack e2m1 tensor packed into u8."""
  assert u8_packed_e2m1.dtype == jnp.uint8
  e2m1 = jax.lax.bitcast_convert_type(u8_packed_e2m1, jnp.float4_e2m1fn)
  # bitcast creates one more dimension that splits 8 bits into two e2m1.
@@ -61,7 +61,7 @@ def u8_unpack_e2m1(u8_packed_e2m1: jax.Array) -> jax.Array:


  def e8m0_to_fp32(u8: jax.Array) -> jax.Array:
- """Convert e8m0 (that was bitcasted to u8) into fp32."""
+ """Convert e8m0 (that was bitcasted to u8) into fp32"""
  assert u8.dtype == jnp.uint8

  e8_finfo = jnp.finfo(jnp.float8_e8m0fnu)
70
70
  return jnp.ldexp(ones, exponents)
71
71
 
72
72
 
73
- def awq_u32_unpack_u4(awq_u32_packed: jax.Array) -> jax.Array:
74
- """Unpack u4 tensor that was packed into u32 in awq ordering."""
75
-
76
- awq_u4 = jax.lax.bitcast_convert_type(awq_u32_packed, jnp.uint4)
77
-
78
- # AWQ packs 8 uint4 into 32-bits in this order: (0, 2, 4, 6, 1, 3, 5, 7).
79
- # Following list maps the order used by AWQ into an ascending order.
80
- reverse_awq_order = (0, 4, 1, 5, 2, 6, 3, 7)
81
- u4 = awq_u4[..., reverse_awq_order]
82
- return jnp.reshape(u4, u4.shape[:-2] + (-1, ))
83
-
84
-
85
73
  def dequantize_tensor(
86
74
  tensor_q: jax.Array,
87
75
  scale: jax.Array,
@@ -21,7 +21,7 @@ from jax.sharding import PartitionSpec as P

  from tpu_inference.kernels.megablox.gmm import gmm
  from tpu_inference.layers.common.sharding import ShardingAxisName
- from tpu_inference.layers.common.utils import \
+ from tpu_inference.layers.vllm.linear_common import \
  slice_sharded_tensor_for_concatenation
  from tpu_inference.utils import get_mesh_shape_product