tpu-inference 0.13.2.dev20260104.tar.gz → 0.13.2rc1.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/MANIFEST.in +1 -0
- {tpu_inference-0.13.2.dev20260104/tpu_inference.egg-info → tpu_inference-0.13.2rc1}/PKG-INFO +1 -1
- tpu_inference-0.13.2rc1/requirements_v7x.txt +25 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/setup.py +19 -5
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/e2e/test_speculative_decoding.py +2 -2
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/layers/jax/test_qwix.py +1 -1
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/layers/vllm/test_awq.py +5 -6
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/layers/vllm/test_compressed_tensors_moe.py +3 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +9 -32
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +4 -6
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/layers/vllm/test_mxfp4.py +5 -13
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/layers/vllm/test_unquantized.py +16 -27
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/lora/test_layers.py +3 -5
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/executors/ray_distributed_executor.py +3 -3
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/layers/common/quantization.py +2 -14
- tpu_inference-0.13.2.dev20260104/tpu_inference/layers/common/fused_moe_gmm.py → tpu_inference-0.13.2rc1/tpu_inference/layers/vllm/fused_moe.py +1 -1
- tpu_inference-0.13.2rc1/tpu_inference/layers/vllm/linear_common.py +221 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/layers/vllm/quantization/__init__.py +3 -3
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/layers/vllm/quantization/awq.py +81 -81
- tpu_inference-0.13.2.dev20260104/tpu_inference/layers/vllm/quantization/configs.py → tpu_inference-0.13.2rc1/tpu_inference/layers/vllm/quantization/common.py +15 -12
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +5 -5
- tpu_inference-0.13.2rc1/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +91 -97
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +43 -65
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/layers/vllm/quantization/fp8.py +5 -6
- tpu_inference-0.13.2rc1/tpu_inference/layers/vllm/quantization/mxfp4.py +410 -0
- tpu_inference-0.13.2rc1/tpu_inference/layers/vllm/quantization/unquantized.py +428 -0
- tpu_inference-0.13.2.dev20260104/tpu_inference/layers/vllm/process_weights/cleanup_sharding.py → tpu_inference-0.13.2rc1/tpu_inference/layers/vllm/sharding.py +12 -4
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/models/common/model_loader.py +1 -6
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/models/jax/utils/qwix/qwix_utils.py +3 -3
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/models/vllm/vllm_model_wrapper.py +1 -2
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/platforms/tpu_platform.py +7 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/runner/compilation_manager.py +4 -10
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/runner/kv_cache_manager.py +2 -1
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/runner/lora_utils.py +1 -2
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/runner/multimodal_manager.py +1 -1
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/runner/tpu_runner.py +1 -3
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1/tpu_inference.egg-info}/PKG-INFO +1 -1
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference.egg-info/SOURCES.txt +4 -8
- tpu_inference-0.13.2.dev20260104/tpu_inference/layers/common/utils.py +0 -94
- tpu_inference-0.13.2.dev20260104/tpu_inference/layers/vllm/fused_moe.py +0 -114
- tpu_inference-0.13.2.dev20260104/tpu_inference/layers/vllm/linear.py +0 -64
- tpu_inference-0.13.2.dev20260104/tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +0 -369
- tpu_inference-0.13.2.dev20260104/tpu_inference/layers/vllm/process_weights/linear_weights.py +0 -174
- tpu_inference-0.13.2.dev20260104/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +0 -199
- tpu_inference-0.13.2.dev20260104/tpu_inference/layers/vllm/quantization/mxfp4.py +0 -225
- tpu_inference-0.13.2.dev20260104/tpu_inference/layers/vllm/quantization/unquantized.py +0 -298
- tpu_inference-0.13.2.dev20260104/tpu_inference/worker/__init__.py +0 -13
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/LICENSE +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/README.md +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/pyproject.toml +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/requirements.txt +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/setup.cfg +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/__init__.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/core/__init__.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/core/test_core_tpu.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/core/test_disagg_executor.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/core/test_disagg_utils.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/core/test_dp_scheduler.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/core/test_init.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/distributed/__init__.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/distributed/test_distributed_utils.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/distributed/test_tpu_connector.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/e2e/__init__.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/e2e/test_async_scheduler.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/e2e/test_data_parallel.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/e2e/test_hybrid_kvcache.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/e2e/test_local_disagg.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/e2e/test_model_loader.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/e2e/test_multi_modal_inference.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/e2e/test_pipeline_parallel.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/e2e/test_runai_model_streamer_loader.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/e2e/test_sampling_params.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/e2e/test_structured_decoding.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/executors/__init__.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/executors/test_ray_distributed_executor.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/experimental/__init__.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/experimental/test_llama3_jax_stashed.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/kernels/__init__.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/kernels/collectives/__init__.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/kernels/collectives/all_gather_matmul_kernel_test.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/kernels/fused_moe_v1_test.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/kernels/gmm_test.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/kernels/mla_v1_test.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/kernels/quantized_matmul_kernel_test.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/kernels/ragged_kv_cache_update_v2_test.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/kernels/ragged_paged_attention_kernel_v2_test.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/kernels/ragged_paged_attention_kernel_v3_test.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/layers/__init__.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/layers/common/__init__.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/layers/common/test_attention_interface.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/layers/common/test_quantization.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/layers/jax/__init__.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/layers/jax/attention/__init__.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/layers/jax/attention/test_common_attention.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/layers/jax/attention/test_deepseek_v3_attention.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/layers/jax/attention/test_llama4_attention.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/layers/jax/moe/__init__.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/layers/jax/moe/test_deepseek_moe.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/layers/jax/sample/__init__.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/layers/jax/sample/test_rejection_sampler.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/layers/jax/sample/test_sampling.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/layers/jax/sample/test_sampling_metadata.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/layers/jax/test_layers.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/layers/jax/test_rope.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/layers/jax/test_sharding.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/layers/jax/test_transformer_block.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/layers/vllm/__init__.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/layers/vllm/test_attention.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/layers/vllm/test_fp8.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/layers/vllm/utils.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/lora/__init__.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/lora/conftest.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/lora/test_bgmv.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/lora/test_lora.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/lora/test_lora_perf.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/lora/utils.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/models/__init__.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/models/common/__init__.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/models/common/test_model_loader.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/models/jax/__init__.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/models/jax/test_deepseek_v3.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/models/jax/test_llama3.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/models/jax/test_llama4.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/models/jax/test_llama_eagle3.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/models/jax/test_llama_guard_4.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/models/jax/test_qwen2.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/models/jax/test_qwen2_5_vl.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/models/jax/test_qwen3.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/models/jax/test_weight_loading.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/models/jax/utils/__init__.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/models/jax/utils/test_multi_modal_utils.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/platforms/__init__.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/platforms/test_tpu_platform.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/runner/__init__.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/runner/test_block_table.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/runner/test_input_batch.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/runner/test_kv_cache.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/runner/test_kv_cache_manager.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/runner/test_multimodal_manager.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/runner/test_persistent_batch_manager.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/runner/test_speculative_decoding_manager.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/runner/test_structured_decoding_manager.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/runner/test_tpu_runner.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/runner/test_tpu_runner_dp.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/runner/test_tpu_runner_mesh.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/runner/test_utils.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/spec_decode/__init__.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/spec_decode/test_eagle3.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/test_base.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/test_envs.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/test_tpu_info.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/test_utils.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/worker/__init__.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/worker/tpu_worker_test.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/__init__.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/core/__init__.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/core/core_tpu.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/core/disagg_executor.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/core/disagg_utils.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/core/sched/__init__.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/core/sched/dp_scheduler.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/distributed/__init__.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/distributed/jax_parallel_state.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/distributed/tpu_connector.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/distributed/utils.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/env_override.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/envs.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/executors/__init__.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/experimental/__init__.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/experimental/llama3_jax_stashed.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/kernels/__init__.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/kernels/collectives/__init__.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/kernels/collectives/all_gather_matmul.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/kernels/collectives/util.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/kernels/flash_attention/__init__.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/kernels/flash_attention/kernel.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/kernels/fused_moe/__init__.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/kernels/fused_moe/v1/kernel.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/kernels/megablox/__init__.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/kernels/megablox/common.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/kernels/megablox/gmm.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/kernels/mla/__init__.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/kernels/mla/v1/__init__.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/kernels/mla/v1/kernel.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/kernels/quantized_matmul/kernel.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/kernels/quantized_matmul/util.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/kernels/ragged_paged_attention/v3/util.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/layers/__init__.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/layers/common/__init__.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/layers/common/attention_interface.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/layers/common/attention_metadata.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/layers/common/binary_search.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/layers/common/quant_methods.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/layers/common/sharding.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/layers/jax/__init__.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/layers/jax/attention/__init__.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/layers/jax/attention/attention.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/layers/jax/attention/deepseek_v3_attention.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/layers/jax/attention/gpt_oss_attention.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/layers/jax/attention/llama4_attention.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/layers/jax/base.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/layers/jax/constants.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/layers/jax/layers.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/layers/jax/misc.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/layers/jax/moe/__init__.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/layers/jax/moe/deepseek_v3_moe.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/layers/jax/moe/gpt_oss_moe.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/layers/jax/moe/moe.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/layers/jax/pp_utils.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/layers/jax/rope.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/layers/jax/rope_interface.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/layers/jax/sample/__init__.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/layers/jax/sample/rejection_sampler.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/layers/jax/sample/sampling.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/layers/jax/sample/sampling_metadata.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/layers/jax/transformer_block.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/layers/vllm/__init__.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/layers/vllm/attention.py +0 -0
- {tpu_inference-0.13.2.dev20260104/tpu_inference/layers/vllm/process_weights → tpu_inference-0.13.2rc1/tpu_inference/layers/vllm/quantization/compressed_tensors}/__init__.py +0 -0
- {tpu_inference-0.13.2.dev20260104/tpu_inference/layers/vllm/quantization/compressed_tensors → tpu_inference-0.13.2rc1/tpu_inference/layers/vllm/quantization/compressed_tensors/schemes}/__init__.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/logger.py +0 -0
- {tpu_inference-0.13.2.dev20260104/tpu_inference/layers/vllm/quantization/compressed_tensors/schemes → tpu_inference-0.13.2rc1/tpu_inference/lora}/__init__.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/lora/torch_lora_ops.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/lora/torch_punica_tpu.py +0 -0
- {tpu_inference-0.13.2.dev20260104/tpu_inference/lora → tpu_inference-0.13.2rc1/tpu_inference/models}/__init__.py +0 -0
- {tpu_inference-0.13.2.dev20260104/tpu_inference/models → tpu_inference-0.13.2rc1/tpu_inference/models/common}/__init__.py +0 -0
- {tpu_inference-0.13.2.dev20260104/tpu_inference/models/common → tpu_inference-0.13.2rc1/tpu_inference/models/jax}/__init__.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/models/jax/deepseek_v3.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/models/jax/gpt_oss.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/models/jax/jax_intermediate_tensor.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/models/jax/llama3.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/models/jax/llama4.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/models/jax/llama_eagle3.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/models/jax/llama_guard_4.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/models/jax/qwen2.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/models/jax/qwen2_5_vl.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/models/jax/qwen3.py +0 -0
- {tpu_inference-0.13.2.dev20260104/tpu_inference/models/jax → tpu_inference-0.13.2rc1/tpu_inference/models/jax/utils}/__init__.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/models/jax/utils/file_utils.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/models/jax/utils/multi_modal_utils.py +0 -0
- {tpu_inference-0.13.2.dev20260104/tpu_inference/models/jax/utils → tpu_inference-0.13.2rc1/tpu_inference/models/jax/utils/qwix}/__init__.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/models/jax/utils/weight_utils.py +0 -0
- {tpu_inference-0.13.2.dev20260104/tpu_inference/models/jax/utils/qwix → tpu_inference-0.13.2rc1/tpu_inference/models/vllm}/__init__.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/models/vllm/vllm_model_wrapper_context.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/platforms/__init__.py +0 -0
- {tpu_inference-0.13.2.dev20260104/tpu_inference/models/vllm → tpu_inference-0.13.2rc1/tpu_inference/runner}/__init__.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/runner/block_table.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/runner/input_batch.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/runner/kv_cache.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/runner/persistent_batch_manager.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/runner/speculative_decoding_manager.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/runner/structured_decoding_manager.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/runner/utils.py +0 -0
- {tpu_inference-0.13.2.dev20260104/tpu_inference/runner → tpu_inference-0.13.2rc1/tpu_inference/spec_decode}/__init__.py +0 -0
- {tpu_inference-0.13.2.dev20260104/tpu_inference/spec_decode → tpu_inference-0.13.2rc1/tpu_inference/spec_decode/jax}/__init__.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/spec_decode/jax/eagle3.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/tpu_info.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/utils.py +0 -0
- {tpu_inference-0.13.2.dev20260104/tpu_inference/spec_decode/jax → tpu_inference-0.13.2rc1/tpu_inference/worker}/__init__.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/worker/tpu_worker.py +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference.egg-info/dependency_links.txt +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference.egg-info/requires.txt +0 -0
- {tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference.egg-info/top_level.txt +0 -0
tpu_inference-0.13.2rc1/requirements_v7x.txt
ADDED
@@ -0,0 +1,25 @@
+# This file contains additional dependencies needed for TPU v7x support.
+# It is expected to be used in conjunction with the main requirements.txt file.
+--pre
+-i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
+-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
+jax==0.8.1
+jaxlib==0.8.1
+jaxtyping==0.3.2
+libtpu==0.0.31
+
+tpu-info==0.7.1
+yapf==0.43.0
+pytest
+pytest-mock
+absl-py
+numpy
+google-cloud-storage
+flax==0.11.1
+torchax==0.0.10
+qwix==0.1.1
+torchvision==0.24.0
+pathwaysutils
+parameterized
+numba==0.62.1
+runai-model-streamer[s3,gcs]==0.15.0
{tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/setup.py
RENAMED
@@ -20,26 +20,40 @@ def get_requirements() -> List[str]:
         requirements = f.read().strip().split("\n")
         resolved_requirements = []
         for line in requirements:
+            if not line or line.startswith("#"):
+                continue
             if line.startswith("-r "):
                 resolved_requirements += _read_requirements(line.split()[1])
-            elif line.startswith("--"):
+            elif line.startswith(("-", "--")):
                 continue
             else:
                 resolved_requirements.append(line)
         return resolved_requirements

     try:
-        requirements = _read_requirements("
+        #requirements = _read_requirements("requirements_v7x.txt")
+
+        # For TPU v7x build
+        if os.getenv("IS_FOR_V7X", "false").lower() == "true":
+            print("Using requirements_v7x.txt")
+            requirements = _read_requirements("requirements_v7x.txt")
+            #requirements.extend(v7x_requirements)
+        else:
+            #For TPU v6e build
+            print("Using requirements.txt")
+            requirements = _read_requirements("requirements.txt")
+
     except ValueError:
         print("Failed to read requirements.txt in vllm_tpu.")
     return requirements


 def get_version():
-
-
-
+    version = os.getenv("VLLM_VERSION_OVERRIDE", "0.0.0").strip()
+    if os.getenv("IS_FOR_V7X", "false").lower() == "true":
+        version = f"{version}.post7"

+    return version

 setup(
     name="tpu_inference",
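For orientation, here is a minimal, self-contained sketch of what the IS_FOR_V7X switch in the new setup.py does; the environment values below are illustrative placeholders, not part of the package:

import os

# Illustrative environment for a v7x build; real builds export these externally.
os.environ.setdefault("VLLM_VERSION_OVERRIDE", "0.13.2rc1")
os.environ.setdefault("IS_FOR_V7X", "true")

is_v7x = os.getenv("IS_FOR_V7X", "false").lower() == "true"

# Mirrors the new get_version(): the v7x build appends a .post7 suffix.
version = os.getenv("VLLM_VERSION_OVERRIDE", "0.0.0").strip()
if is_v7x:
    version = f"{version}.post7"

# Mirrors get_requirements(): the same flag selects the requirements file.
requirements_file = "requirements_v7x.txt" if is_v7x else "requirements.txt"

print(version)            # 0.13.2rc1.post7
print(requirements_file)  # requirements_v7x.txt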
{tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/e2e/test_speculative_decoding.py
RENAMED
@@ -271,7 +271,7 @@ def test_ngram_performance_random(
         "prompt_lookup_max": 2,
         "prompt_lookup_min": 2,
         "num_speculative_tokens": 4,
-    }, 1.
+    }, 1.5 if _is_v7x() else 3.0)


 def test_eagle3_correctness(
@@ -308,4 +308,4 @@ def test_eagle3_performance(
         "model": "unkmaster/EAGLE3-LLaMA3.1-Instruct-8B",
         "num_speculative_tokens": 2,
         "draft_tensor_parallel_size": 1
-    },
+    }, 1.2 if _is_v7x() else 1.8)
{tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/layers/jax/test_qwix.py
RENAMED
@@ -832,7 +832,7 @@ class TestGetDefaultQwixQuantizationConfig(unittest.TestCase):
         # Patch the constants in the module where the function resides
         self.patchers = [
             patch(
-                "tpu_inference.models.jax.utils.qwix.qwix_utils.
+                "tpu_inference.models.jax.utils.qwix.qwix_utils.DEFAULT_DEEPSEEK_FP8_CONFIG",
                 self.mock_deepseek_config),
             patch(
                 "tpu_inference.models.jax.utils.qwix.qwix_utils.DEFAULT_LLAMA4_FP8_CONFIG",
{tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/layers/vllm/test_awq.py
RENAMED
@@ -39,8 +39,7 @@ from vllm.scalar_type import scalar_types
 from tpu_inference.layers.vllm.quantization import get_tpu_quantization_config
 from tpu_inference.layers.vllm.quantization.awq import (VllmAWQConfig,
                                                          VllmAWQLinearMethod)
-from tpu_inference.layers.vllm.quantization.
-    VllmQuantLinearConfig
+from tpu_inference.layers.vllm.quantization.common import JaxCommonLinearConfig

 from . import utils as test_utils

@@ -104,8 +103,8 @@ def return_ref_and_layer_output(
     assert isinstance(quant_method, VllmAWQLinearMethod)
     quant_config = quant_method.quant_config
     assert isinstance(quant_config, VllmAWQConfig)
-    jax_config = quant_method.
-    assert isinstance(jax_config,
+    jax_config = quant_method.jax_config
+    assert isinstance(jax_config, JaxCommonLinearConfig)

     input_tensor = torch.rand(
         batch_size, layer.input_size, dtype=torch.bfloat16) / 10
@@ -135,8 +134,8 @@ def initialize_and_return_layer_weights(layer: torch.nn.Module):
     assert isinstance(quant_method, VllmAWQLinearMethod)
     quant_config = quant_method.quant_config
     assert isinstance(quant_config, VllmAWQConfig)
-    jax_config = quant_method.
-    assert isinstance(jax_config,
+    jax_config = quant_method.jax_config
+    assert isinstance(jax_config, JaxCommonLinearConfig)

     # torch.rand returns value in the range of [0, 1). We subtract by 0.2 to
     # simulate asymmetry
{tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/layers/vllm/test_compressed_tensors_moe.py
RENAMED
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import os
 import tempfile

 import jax.numpy as jnp
@@ -42,6 +43,8 @@ from . import utils as test_utils

 P = PartitionSpec

+os.environ['VLLM_DISABLE_SHARED_EXPERTS_STREAM'] = '1'
+
 MODEL = 'BCCard/Qwen3-30B-A3B-FP8-Dynamic'

{tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py
RENAMED
@@ -16,7 +16,6 @@ import tempfile
 from typing import Optional

 import jax
-import jax.numpy as jnp
 import pytest
 import torch
 import torchax
@@ -37,15 +36,12 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
     CompressedTensorsLinearMethod
 from vllm.model_executor.model_loader import get_model as vllm_get_model

-from tpu_inference.layers.common.quantization import (dequantize_tensor,
-                                                       quantize_tensor)
 from tpu_inference.layers.vllm.quantization import get_tpu_quantization_config
+from tpu_inference.layers.vllm.quantization.common import JaxCommonLinearConfig
 from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors import \
     VllmCompressedTensorsConfig
-from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_fp8 import
-    VllmCompressedTensorsW8A8Fp8
-from tpu_inference.layers.vllm.quantization.configs import \
-    VllmQuantLinearConfig
+from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_fp8 import (
+    VllmCompressedTensorsW8A8Fp8, requantize_with_max_scale)

 from . import utils as test_utils

@@ -102,8 +98,8 @@ def return_ref_and_layer_output(layer: torch.nn.Module, batch_size: int = 16):
     assert isinstance(layer, LinearBase)
     scheme = layer.scheme
     assert isinstance(scheme, VllmCompressedTensorsW8A8Fp8)
-    quant_config = scheme.
-    assert isinstance(quant_config,
+    quant_config = scheme.jax_config
+    assert isinstance(quant_config, JaxCommonLinearConfig)
     quant_method = layer.quant_method
     assert isinstance(quant_method, CompressedTensorsLinearMethod)
     per_tensor = scheme.strategy == QuantizationStrategy.TENSOR
@@ -118,27 +114,8 @@ def return_ref_and_layer_output(layer: torch.nn.Module, batch_size: int = 16):
     # For per_tensor with merged layers, vLLM requenzites them so all merged
     # layers shared the same scale values.
     if per_tensor:
-
-
-        weight = t2j(weight)
-        weight_scale = t2j(weight_scale)
-        weights = []
-        start = 0
-        # Multiple weights may have been concatenated. Loop through
-        # each weight and perform dequantization.
-        for i, output_size in enumerate(quant_config.output_sizes):
-            end = start + output_size
-            weights.append(
-                dequantize_tensor(weight[start:end], weight_scale[i]))
-            start = end
-        weight = jnp.concat(weights, axis=0)
-        weight, weight_scale = quantize_tensor(
-            jnp.float8_e4m3fn,
-            weight,
-            None,
-        )
-        weight = j2t(weight.astype(jnp.float32)).to(dtype)
-        weight_scale = j2t(weight_scale)
+        weight_scale, weight = requantize_with_max_scale(
+            layer.weight, layer.weight_scale, quant_config.output_sizes)
     if input_scale is not None:
         input_scale = input_scale.max()

@@ -174,8 +151,8 @@ def initialize_layer_weights(layer: torch.nn.Module):
     assert isinstance(layer, LinearBase)
     scheme = layer.scheme
     assert isinstance(scheme, VllmCompressedTensorsW8A8Fp8)
-    quant_config = scheme.
-    assert isinstance(quant_config,
+    quant_config = scheme.jax_config
+    assert isinstance(quant_config, JaxCommonLinearConfig)
     per_tensor = scheme.strategy == QuantizationStrategy.TENSOR

     weight_list = []
{tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/layers/vllm/test_compressed_tensors_w8a8_int8.py
RENAMED
@@ -185,7 +185,7 @@ def test_row_parallel_linear(model, bias, num_devices, enable_sp,
         if bias:
             jax_row_linear.bias.data = bias_data

-        input_tensor = torch.rand(10,
+        input_tensor = torch.rand(10, 4096, dtype=dtype) / 10
         input_tensor = input_tensor.to('cpu')

         jax_input_tensor = torch_view(t2j(input_tensor, use_dlpack=False))
@@ -259,8 +259,7 @@ def test_column_parallel_linear(model, bias, num_devices, enable_sp,
         if bias:
             jax_column_linear.bias.data = bias_data

-        input_tensor = torch.rand(10,
-                                  dtype=dtype) / 10
+        input_tensor = torch.rand(10, 4096, dtype=dtype) / 10
         input_tensor = input_tensor.to('cpu')

         jax_input_tensor = torch_view(t2j(input_tensor, use_dlpack=False))
@@ -339,7 +338,7 @@ def test_qkv_parallel_linear(model, bias, num_devices, enable_sp, fuse_matmuls,
         if bias:
             jax_qkv_linear.bias.data = bias_data

-        input_tensor = torch.rand(10,
+        input_tensor = torch.rand(10, 4096, dtype=dtype) / 10
         input_tensor = input_tensor.to('cpu')

         jax_input_tensor = torch_view(t2j(input_tensor, use_dlpack=False))
@@ -415,8 +414,7 @@ def test_merged_column_parallel_linear(model, bias, num_devices, fuse_matmuls,
         if bias:
             jax_merged_column_linear.bias.data = bias_data

-        input_tensor = torch.rand(
-            10, jax_merged_column_linear.input_size, dtype=dtype) / 10
+        input_tensor = torch.rand(10, 4096, dtype=dtype) / 10
         input_tensor = input_tensor.to('cpu')

         jax_input_tensor = torch_view(t2j(input_tensor, use_dlpack=False))
{tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/layers/vllm/test_mxfp4.py
RENAMED
@@ -13,7 +13,6 @@
 # limitations under the License.

 import tempfile
-from unittest import mock

 import jax
 import jax.numpy as jnp
@@ -30,7 +29,6 @@ from vllm.engine.arg_utils import EngineArgs
 from vllm.forward_context import set_forward_context
 from vllm.model_executor.layers.fused_moe.layer import FusedMoE

-from tpu_inference.layers.vllm.fused_moe import FusedMoEBackend
 from tpu_inference.layers.vllm.quantization import get_tpu_quantization_config
 from tpu_inference.layers.vllm.quantization.mxfp4 import (VllmMxfp4Config,
                                                            VllmMxfp4MoEMethod)
@@ -162,8 +160,6 @@ def test_mxfp4_fused_moe(num_devices, num_tokens, intermediate_size,
     )
     vllm_config = engine_args.create_engine_config()
     vllm_config.model_config.dtype = dtype
-    vllm_config.parallel_config = ParallelConfig(
-        tensor_parallel_size=mesh.devices.size, enable_expert_parallel=use_ep)

     quant_config = get_tpu_quantization_config(vllm_config, mesh)
     with set_current_vllm_config(vllm_config):
@@ -194,16 +190,13 @@ def test_mxfp4_fused_moe(num_devices, num_tokens, intermediate_size,

     with torchax.default_env(), set_forward_context(None, vllm_config):
         assert isinstance(vllm_fused_moe.quant_method, VllmMxfp4MoEMethod)
-        if use_ep:
-            assert vllm_fused_moe.quant_method.moe_backend == FusedMoEBackend.GMM_EP
-        else:
-            assert vllm_fused_moe.quant_method.moe_backend == FusedMoEBackend.GMM_TP

         jax_a = a.to('jax')
         score = score.to('jax')

         vllm_fused_moe.quant_method.process_weights_after_loading(
             vllm_fused_moe)
+
         actual = vllm_fused_moe(jax_a, score)

         torch.testing.assert_close(expected,
@@ -220,7 +213,6 @@ def test_mxfp4_fused_moe(num_devices, num_tokens, intermediate_size,
 @pytest.mark.parametrize("num_experts", [8])
 @pytest.mark.parametrize("topk", [2])
 @pytest.mark.parametrize("enable_attn_dp", [False, True])
-@mock.patch("os.environ", {"USE_MOE_EP_KERNEL": "1"})
 def test_mxfp4_fused_moe_use_kernel(num_devices, num_tokens, intermediate_size,
                                     hidden_size, num_experts, topk,
                                     enable_attn_dp):
@@ -261,7 +253,7 @@ def test_mxfp4_fused_moe_use_kernel(num_devices, num_tokens, intermediate_size,
     vllm_config = engine_args.create_engine_config()
     vllm_config.model_config.dtype = dtype
     vllm_config.parallel_config = ParallelConfig(
-        tensor_parallel_size=mesh.devices.size
+        tensor_parallel_size=mesh.devices.size)

     quant_config = get_tpu_quantization_config(vllm_config, mesh)
     with set_current_vllm_config(vllm_config):
@@ -293,14 +285,14 @@ def test_mxfp4_fused_moe_use_kernel(num_devices, num_tokens, intermediate_size,

     with torchax.default_env(), set_forward_context(None, vllm_config):
         assert isinstance(vllm_fused_moe.quant_method, VllmMxfp4MoEMethod)
-        assert vllm_fused_moe.quant_method.moe_backend == FusedMoEBackend.FUSED_MOE

         jax_a = a.to('jax')
         score = score.to('jax')

+        vllm_fused_moe.quant_method.use_kernel = True
         vllm_fused_moe.quant_method.process_weights_after_loading(
             vllm_fused_moe)
-        vllm_fused_moe.quant_method.
+        vllm_fused_moe.quant_method.block_size = {
             "bt": 32,
             "bf": 512,
             "bd1": 1024,
@@ -309,7 +301,7 @@ def test_mxfp4_fused_moe_use_kernel(num_devices, num_tokens, intermediate_size,
             "bfc": 512,
             "bd1c": 1024,
             "bd2c": 1024,
-        }
+        }

         actual = vllm_fused_moe(jax_a, score)

{tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/layers/vllm/test_unquantized.py
RENAMED
@@ -13,7 +13,6 @@
 # limitations under the License.

 import tempfile
-from unittest import mock

 import jax
 import pytest
@@ -36,7 +35,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                 RowParallelLinear)
 from vllm.model_executor.model_loader import get_model as vllm_get_model

-from tpu_inference.layers.vllm.fused_moe import FusedMoEBackend
 from tpu_inference.layers.vllm.quantization import get_tpu_quantization_config
 from tpu_inference.layers.vllm.quantization.unquantized import (
     VllmUnquantizedConfig, VllmUnquantizedFusedMoEMethod,
@@ -141,6 +139,9 @@ def test_row_parallel_linear(model, bias, num_devices, enable_sp,
     vllm_config = engine_args.create_engine_config()
     vllm_config.compilation_config.pass_config.enable_sp = enable_sp

+    input_tensor = torch.rand(10, 4096, dtype=dtype) / 10
+    input_tensor = input_tensor.to('cpu')
+
     with set_current_vllm_config(vllm_config):
         row_linear = RowParallelLinear(
             input_size=4096,
@@ -150,9 +151,6 @@ def test_row_parallel_linear(model, bias, num_devices, enable_sp,
             return_bias=False,
         )

-        input_tensor = torch.rand(10, row_linear.input_size, dtype=dtype) / 10
-        input_tensor = input_tensor.to('cpu')
-
         weight_data = torch.rand_like(row_linear.weight.data) / 10
         if bias:
             bias_data = torch.rand_like(row_linear.bias.data)
@@ -218,6 +216,9 @@ def test_column_parallel_linear(model, bias, num_devices, enable_sp,
     vllm_config = engine_args.create_engine_config()
     vllm_config.compilation_config.pass_config.enable_sp = enable_sp

+    input_tensor = torch.rand(10, 4096, dtype=dtype) / 10
+    input_tensor = input_tensor.to('cpu')
+
     with set_current_vllm_config(vllm_config):
         column_linear = ColumnParallelLinear(
             input_size=4096,
@@ -227,9 +228,6 @@ def test_column_parallel_linear(model, bias, num_devices, enable_sp,
             return_bias=False,
         )

-        input_tensor = torch.rand(10, column_linear.input_size, dtype=dtype) / 10
-        input_tensor = input_tensor.to('cpu')
-
         weight_data = torch.rand_like(column_linear.weight.data) / 10
         if bias:
             bias_data = torch.rand_like(column_linear.bias.data)
@@ -295,6 +293,9 @@ def test_qkv_parallel_linear(model, bias, num_devices, enable_sp, fuse_matmuls,
     vllm_config = engine_args.create_engine_config()
     vllm_config.compilation_config.pass_config.enable_sp = enable_sp

+    input_tensor = torch.rand(10, 4096, dtype=dtype) / 10
+    input_tensor = input_tensor.to('cpu')
+
     with set_current_vllm_config(vllm_config):
         qkv_linear = QKVParallelLinear(
             hidden_size=4096,
@@ -306,9 +307,6 @@ def test_qkv_parallel_linear(model, bias, num_devices, enable_sp, fuse_matmuls,
             return_bias=False,
         )

-        input_tensor = torch.rand(10, qkv_linear.input_size, dtype=dtype) / 10
-        input_tensor = input_tensor.to('cpu')
-
         weight_data = torch.rand_like(qkv_linear.weight.data) / 10
         if bias:
             bias_data = torch.rand_like(qkv_linear.bias.data)
@@ -377,6 +375,9 @@ def test_merged_column_parallel_linear(model, bias, num_devices, fuse_matmuls,
     vllm_config = engine_args.create_engine_config()
     vllm_config.compilation_config.pass_config.enable_sp = enable_sp

+    input_tensor = torch.rand(10, 4096, dtype=dtype) / 10
+    input_tensor = input_tensor.to('cpu')
+
     # Call vLLM code
     with set_current_vllm_config(vllm_config):
         merged_column_linear = MergedColumnParallelLinear(
@@ -387,10 +388,6 @@ def test_merged_column_parallel_linear(model, bias, num_devices, fuse_matmuls,
             return_bias=False,
         )

-        input_tensor = torch.rand(10, merged_column_linear.input_size,
-                                  dtype=dtype) / 10
-        input_tensor = input_tensor.to('cpu')
-
         weight_data = torch.rand_like(merged_column_linear.weight.data) / 10
         if bias:
             bias_data = torch.rand_like(merged_column_linear.bias.data)
@@ -478,8 +475,6 @@ def test_fused_moe(use_ep, num_devices, num_tokens, intermediate_size,
     )
     vllm_config = engine_args.create_engine_config()
     vllm_config.model_config.dtype = dtype
-    vllm_config.parallel_config = ParallelConfig(
-        tensor_parallel_size=mesh.devices.size, enable_expert_parallel=use_ep)

     quant_config = get_tpu_quantization_config(vllm_config, mesh)
     with set_current_vllm_config(vllm_config):
@@ -511,10 +506,6 @@ def test_fused_moe(use_ep, num_devices, num_tokens, intermediate_size,
     with torchax.default_env(), set_forward_context(None, vllm_config):
         assert isinstance(vllm_fused_moe.quant_method,
                           VllmUnquantizedFusedMoEMethod)
-        if use_ep:
-            assert vllm_fused_moe.quant_method.moe_backend == FusedMoEBackend.GMM_EP
-        else:
-            assert vllm_fused_moe.quant_method.moe_backend == FusedMoEBackend.GMM_TP

         jax_a = a.to('jax')
         score = score.to('jax')
@@ -538,7 +529,6 @@ def test_fused_moe(use_ep, num_devices, num_tokens, intermediate_size,
 @pytest.mark.parametrize("topk", [8])
 @pytest.mark.parametrize("has_bias", [False, True])
 @pytest.mark.parametrize("enable_attn_dp", [False, True])
-@mock.patch("os.environ", {"USE_MOE_EP_KERNEL": "1"})
 def test_fused_moe_use_kernel(num_devices, num_tokens, intermediate_size,
                               hidden_size, num_experts, topk, has_bias,
                               enable_attn_dp):
@@ -602,7 +592,7 @@ def test_fused_moe_use_kernel(num_devices, num_tokens, intermediate_size,
     vllm_config = engine_args.create_engine_config()
     vllm_config.model_config.dtype = dtype
     vllm_config.parallel_config = ParallelConfig(
-        tensor_parallel_size=mesh.devices.size
+        tensor_parallel_size=mesh.devices.size)

     quant_config = get_tpu_quantization_config(vllm_config, mesh)
     with set_current_vllm_config(vllm_config):
@@ -619,6 +609,7 @@ def test_fused_moe_use_kernel(num_devices, num_tokens, intermediate_size,
             has_bias=has_bias,
         )
         vllm_fused_moe.moe_parallel_config.use_ep = True
+        vllm_fused_moe.quant_method.use_kernel = True

         vllm_fused_moe.w13_weight.data = w1
         vllm_fused_moe.w2_weight.data = w2
@@ -634,14 +625,12 @@ def test_fused_moe_use_kernel(num_devices, num_tokens, intermediate_size,
     with torchax.default_env(), set_forward_context(None, vllm_config):
         assert isinstance(vllm_fused_moe.quant_method,
                           VllmUnquantizedFusedMoEMethod)
-        assert vllm_fused_moe.quant_method.moe_backend == FusedMoEBackend.FUSED_MOE
-
         jax_a = a.to('jax')
         score = score.to('jax')

         vllm_fused_moe.quant_method.process_weights_after_loading(
             vllm_fused_moe)
-        vllm_fused_moe.quant_method.
+        vllm_fused_moe.quant_method.block_size = {
             "bt": 32,
             "bf": 512,
             "bd1": 512,
@@ -650,7 +639,7 @@ def test_fused_moe_use_kernel(num_devices, num_tokens, intermediate_size,
             "bfc": 256,
             "bd1c": 256,
             "bd2c": 256,
-        }
+        }
         actual = vllm_fused_moe(jax_a, score)

         torch.testing.assert_close(
{tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tests/lora/test_layers.py
RENAMED
@@ -42,12 +42,10 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
 from vllm.model_executor.utils import set_random_seed
 from vllm.platforms import current_platform

-from tpu_inference.layers.vllm.
-    _shard_module_to_tpu
-from tpu_inference.layers.vllm.quantization.configs import \
-    VllmQuantLinearConfig
+from tpu_inference.layers.vllm.quantization.common import JaxCommonLinearConfig
 from tpu_inference.layers.vllm.quantization.unquantized import \
     VllmUnquantizedLinearMethod
+from tpu_inference.layers.vllm.sharding import _shard_module_to_tpu

 from .utils import DummyLoRAManager

@@ -631,7 +629,7 @@ def _create_lora_wrapper(linear,
                          mesh,
                          repeats=1):
     base_linear.weight.data = linear.weight.data
-    jax_config =
+    jax_config = JaxCommonLinearConfig(vllm_config, mesh, base_linear)
     linear_method = VllmUnquantizedLinearMethod(jax_config)
     base_linear.quant_method = linear_method
     linear_method.process_weights_after_loading(
{tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/executors/ray_distributed_executor.py
RENAMED
@@ -20,7 +20,7 @@ import ray
 import vllm.envs as envs
 from ray.util.placement_group import PlacementGroup
 from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
-from vllm.multimodal.inputs import
+from vllm.multimodal.inputs import MultiModalKwargs
 from vllm.platforms import current_platform
 from vllm.ray.ray_env import get_env_vars_to_copy
 from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE
@@ -53,7 +53,7 @@ logger = init_logger(__name__)


 def _encode_hook(obj: Any) -> Any:
-    """Custom msgspec enc hook that supports array types and
+    """Custom msgspec enc hook that supports array types and MultiModalKwargs.

     See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder
     """
@@ -62,7 +62,7 @@ def _encode_hook(obj: Any) -> Any:
             f"vLLM array type should use '{VLLM_TOKEN_ID_ARRAY_TYPE}' type. "
             f"Given array has a type code of {obj.typecode}.")
         return obj.tobytes()
-    if isinstance(obj,
+    if isinstance(obj, MultiModalKwargs):
         return dict(obj)

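For background, _encode_hook follows msgspec's enc_hook extension pattern: the encoder calls the hook for any object it cannot serialize natively. A minimal, self-contained sketch of that pattern (array case only, simplified; the MultiModalKwargs handling and vLLM's type-code check are omitted):

import array
from typing import Any

import msgspec


def encode_hook_sketch(obj: Any) -> Any:
    # Serialize token-id arrays as raw bytes, mirroring _encode_hook above.
    if isinstance(obj, array.array):
        return obj.tobytes()
    raise TypeError(f"Unsupported type for msgpack encoding: {type(obj)}")


encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook_sketch)
payload = encoder.encode({"token_ids": array.array("l", [1, 2, 3])})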
{tpu_inference-0.13.2.dev20260104 → tpu_inference-0.13.2rc1}/tpu_inference/layers/common/quantization.py
RENAMED
@@ -52,7 +52,7 @@ def quantize_tensor_to_mxfp4_packed(


 def u8_unpack_e2m1(u8_packed_e2m1: jax.Array) -> jax.Array:
-    """Unpack e2m1 tensor
+    """Unpack e2m1 tensor packed into u8."""
     assert u8_packed_e2m1.dtype == jnp.uint8
     e2m1 = jax.lax.bitcast_convert_type(u8_packed_e2m1, jnp.float4_e2m1fn)
     # bitcast creates one more dimension that splits 8 bits into two e2m1.
@@ -61,7 +61,7 @@ def u8_unpack_e2m1(u8_packed_e2m1: jax.Array) -> jax.Array:


 def e8m0_to_fp32(u8: jax.Array) -> jax.Array:
-    """Convert e8m0 (that was bitcasted to u8) into fp32
+    """Convert e8m0 (that was bitcasted to u8) into fp32"""
     assert u8.dtype == jnp.uint8

     e8_finfo = jnp.finfo(jnp.float8_e8m0fnu)
@@ -70,18 +70,6 @@ def e8m0_to_fp32(u8: jax.Array) -> jax.Array:
     return jnp.ldexp(ones, exponents)


-def awq_u32_unpack_u4(awq_u32_packed: jax.Array) -> jax.Array:
-    """Unpack u4 tensor that was packed into u32 in awq ordering."""
-
-    awq_u4 = jax.lax.bitcast_convert_type(awq_u32_packed, jnp.uint4)
-
-    # AWQ packs 8 uint4 into 32-bits in this order: (0, 2, 4, 6, 1, 3, 5, 7).
-    # Following list maps the order used by AWQ into an ascending order.
-    reverse_awq_order = (0, 4, 1, 5, 2, 6, 3, 7)
-    u4 = awq_u4[..., reverse_awq_order]
-    return jnp.reshape(u4, u4.shape[:-2] + (-1, ))
-
-
 def dequantize_tensor(
     tensor_q: jax.Array,
     scale: jax.Array,
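As a rough illustration of the packing that u8_unpack_e2m1 handles (a sketch only, assuming a JAX build whose ml_dtypes exposes float4_e2m1fn, as the code above does): bitcasting uint8 to float4_e2m1fn adds a trailing axis of size two, which is then folded back so every packed byte yields two e2m1 values.

import jax
import jax.numpy as jnp

# Two rows of four packed bytes -> two rows of eight e2m1 values.
packed = jnp.arange(8, dtype=jnp.uint8).reshape(2, 4)

e2m1 = jax.lax.bitcast_convert_type(packed, jnp.float4_e2m1fn)
print(e2m1.shape)  # (2, 4, 2): the bitcast adds an axis holding the two nibbles

unpacked = jnp.reshape(e2m1, packed.shape[:-1] + (-1,))
print(unpacked.shape)  # (2, 8)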
tpu_inference-0.13.2.dev20260104/tpu_inference/layers/common/fused_moe_gmm.py → tpu_inference-0.13.2rc1/tpu_inference/layers/vllm/fused_moe.py
RENAMED
@@ -21,7 +21,7 @@ from jax.sharding import PartitionSpec as P

 from tpu_inference.kernels.megablox.gmm import gmm
 from tpu_inference.layers.common.sharding import ShardingAxisName
-from tpu_inference.layers.
+from tpu_inference.layers.vllm.linear_common import \
     slice_sharded_tensor_for_concatenation
 from tpu_inference.utils import get_mesh_shape_product
