tpu-inference 0.12.0.dev20251219__tar.gz → 0.12.0rc1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {tpu_inference-0.12.0.dev20251219/tpu_inference.egg-info → tpu_inference-0.12.0rc1}/PKG-INFO +8 -6
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/README.md +6 -4
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/requirements.txt +1 -1
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tests/kernels/mla_v1_test.py +41 -129
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tests/kernels/quantized_matmul_kernel_test.py +34 -2
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +1 -3
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tests/kernels/ragged_paged_attention_kernel_v3_test.py +1 -3
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tests/lora/test_layers.py +3 -7
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tests/lora/test_lora.py +1 -1
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tests/test_envs.py +1 -78
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tests/test_quantization.py +0 -3
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/distributed/tpu_connector.py +3 -3
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/envs.py +7 -38
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/executors/ray_distributed_executor.py +0 -3
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/kernels/collectives/all_gather_matmul.py +6 -12
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +2 -7
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/kernels/fused_moe/v1/kernel.py +324 -357
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/kernels/mla/v1/kernel.py +120 -98
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/kernels/quantized_matmul/kernel.py +8 -69
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +1 -2
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +1 -2
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +101 -181
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +78 -82
- tpu_inference-0.12.0rc1/tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
- tpu_inference-0.12.0rc1/tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/kernels/ragged_paged_attention/v3/util.py +1 -2
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/common/attention_interface.py +7 -1
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/common/quant_methods.py +0 -1
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/common/sharding.py +2 -6
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/jax/attention/deepseek_v3_attention.py +64 -232
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/jax/attention/gpt_oss_attention.py +5 -5
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/vllm/fused_moe.py +247 -180
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/vllm/linear_common.py +21 -43
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/vllm/quantization/__init__.py +0 -2
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/vllm/quantization/awq.py +1 -1
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/vllm/quantization/common.py +5 -5
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +3 -4
- tpu_inference-0.12.0rc1/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
- tpu_inference-0.12.0rc1/tpu_inference/layers/vllm/quantization/mxfp4.py +341 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/vllm/quantization/unquantized.py +81 -105
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/lora/torch_lora_ops.py +13 -8
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/models/common/model_loader.py +20 -48
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/models/jax/deepseek_v3.py +64 -185
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/models/jax/gpt_oss.py +3 -3
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/models/jax/llama3.py +33 -79
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/models/jax/utils/quantization/quantization_utils.py +2 -4
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/models/jax/utils/weight_utils.py +2 -26
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/models/vllm/vllm_model_wrapper.py +1 -1
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/platforms/tpu_platform.py +37 -15
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/runner/compilation_manager.py +2 -3
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/runner/kv_cache.py +20 -40
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/runner/kv_cache_manager.py +15 -31
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/runner/tpu_runner.py +7 -14
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/utils.py +6 -11
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/worker/tpu_worker.py +44 -44
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1/tpu_inference.egg-info}/PKG-INFO +8 -6
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference.egg-info/SOURCES.txt +0 -7
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference.egg-info/requires.txt +1 -1
- tpu_inference-0.12.0.dev20251219/tests/kernels/gmm_test.py +0 -191
- tpu_inference-0.12.0.dev20251219/tests/lora/test_lora_perf.py +0 -53
- tpu_inference-0.12.0.dev20251219/tpu_inference/kernels/megablox/common.py +0 -41
- tpu_inference-0.12.0.dev20251219/tpu_inference/kernels/megablox/gmm.py +0 -633
- tpu_inference-0.12.0.dev20251219/tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +0 -4447
- tpu_inference-0.12.0.dev20251219/tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +0 -535
- tpu_inference-0.12.0.dev20251219/tpu_inference/layers/jax/pp_utils.py +0 -39
- tpu_inference-0.12.0.dev20251219/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +0 -252
- tpu_inference-0.12.0.dev20251219/tpu_inference/layers/vllm/quantization/fp8.py +0 -104
- tpu_inference-0.12.0.dev20251219/tpu_inference/layers/vllm/quantization/mxfp4.py +0 -448
- tpu_inference-0.12.0.dev20251219/tpu_inference/worker/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/LICENSE +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/MANIFEST.in +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/pyproject.toml +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/setup.cfg +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/setup.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tests/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tests/core/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tests/core/test_core_tpu.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tests/core/test_disagg_executor.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tests/core/test_disagg_utils.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tests/core/test_dp_scheduler.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tests/core/test_init.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tests/kernels/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tests/kernels/fused_moe_v1_test.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tests/kernels/ragged_kv_cache_update_v2_test.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tests/kernels/ragged_paged_attention_kernel_v2_test.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tests/lora/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tests/lora/conftest.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tests/lora/test_bgmv.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tests/lora/utils.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tests/test_base.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tests/test_tpu_info.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tests/test_utils.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/core/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/core/core_tpu.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/core/disagg_executor.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/core/disagg_utils.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/core/sched/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/core/sched/dp_scheduler.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/distributed/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/distributed/jax_parallel_state.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/distributed/utils.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/env_override.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/executors/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/experimental/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/experimental/llama3_jax_stashed.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/kernels/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/kernels/collectives/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/kernels/collectives/util.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/kernels/flash_attention/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/kernels/flash_attention/kernel.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/kernels/fused_moe/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251219/tpu_inference/kernels/megablox → tpu_inference-0.12.0rc1/tpu_inference/kernels/mla}/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251219/tpu_inference/kernels/mla → tpu_inference-0.12.0rc1/tpu_inference/kernels/mla/v1}/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251219/tpu_inference/kernels/mla/v1 → tpu_inference-0.12.0rc1/tpu_inference/kernels/quantized_matmul}/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/kernels/quantized_matmul/util.py +0 -0
- {tpu_inference-0.12.0.dev20251219/tpu_inference/kernels/quantized_matmul → tpu_inference-0.12.0rc1/tpu_inference/kernels/ragged_paged_attention}/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251219/tpu_inference/kernels/ragged_paged_attention → tpu_inference-0.12.0rc1/tpu_inference/kernels/ragged_paged_attention/v2}/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +0 -0
- {tpu_inference-0.12.0.dev20251219/tpu_inference/kernels/ragged_paged_attention/v2 → tpu_inference-0.12.0rc1/tpu_inference/kernels/ragged_paged_attention/v3}/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251219/tpu_inference/kernels/ragged_paged_attention/v3 → tpu_inference-0.12.0rc1/tpu_inference/layers}/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251219/tpu_inference/layers → tpu_inference-0.12.0rc1/tpu_inference/layers/common}/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/common/attention_metadata.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/common/binary_search.py +0 -0
- {tpu_inference-0.12.0.dev20251219/tpu_inference/layers/common → tpu_inference-0.12.0rc1/tpu_inference/layers/jax}/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251219/tpu_inference/layers/jax → tpu_inference-0.12.0rc1/tpu_inference/layers/jax/attention}/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/jax/attention/attention.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/jax/attention/llama4_attention.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/jax/base.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/jax/constants.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/jax/layers.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/jax/misc.py +0 -0
- {tpu_inference-0.12.0.dev20251219/tpu_inference/layers/jax/attention → tpu_inference-0.12.0rc1/tpu_inference/layers/jax/moe}/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/jax/moe/deepseek_v3_moe.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/jax/moe/gpt_oss_moe.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/jax/moe/moe.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/jax/rope.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/jax/rope_interface.py +0 -0
- {tpu_inference-0.12.0.dev20251219/tpu_inference/layers/jax/moe → tpu_inference-0.12.0rc1/tpu_inference/layers/jax/sample}/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/jax/sample/rejection_sampler.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/jax/sample/sampling.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/jax/sample/sampling_metadata.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/jax/transformer_block.py +0 -0
- {tpu_inference-0.12.0.dev20251219/tpu_inference/layers/jax/sample → tpu_inference-0.12.0rc1/tpu_inference/layers/vllm}/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/vllm/attention.py +0 -0
- {tpu_inference-0.12.0.dev20251219/tpu_inference/layers/vllm → tpu_inference-0.12.0rc1/tpu_inference/layers/vllm/quantization/compressed_tensors}/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251219/tpu_inference/layers/vllm/quantization/compressed_tensors → tpu_inference-0.12.0rc1/tpu_inference/layers/vllm/quantization/compressed_tensors/schemes}/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/layers/vllm/sharding.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/logger.py +0 -0
- {tpu_inference-0.12.0.dev20251219/tpu_inference/layers/vllm/quantization/compressed_tensors/schemes → tpu_inference-0.12.0rc1/tpu_inference/lora}/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/lora/torch_punica_tpu.py +0 -0
- {tpu_inference-0.12.0.dev20251219/tpu_inference/lora → tpu_inference-0.12.0rc1/tpu_inference/models}/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251219/tpu_inference/models → tpu_inference-0.12.0rc1/tpu_inference/models/common}/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251219/tpu_inference/models/common → tpu_inference-0.12.0rc1/tpu_inference/models/jax}/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/models/jax/jax_intermediate_tensor.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/models/jax/llama4.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/models/jax/llama_eagle3.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/models/jax/llama_guard_4.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/models/jax/qwen2.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/models/jax/qwen2_5_vl.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/models/jax/qwen3.py +0 -0
- {tpu_inference-0.12.0.dev20251219/tpu_inference/models/jax → tpu_inference-0.12.0rc1/tpu_inference/models/jax/utils}/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/models/jax/utils/file_utils.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/models/jax/utils/multi_modal_utils.py +0 -0
- {tpu_inference-0.12.0.dev20251219/tpu_inference/models/jax/utils → tpu_inference-0.12.0rc1/tpu_inference/models/jax/utils/quantization}/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -0
- {tpu_inference-0.12.0.dev20251219/tpu_inference/models/jax/utils/quantization → tpu_inference-0.12.0rc1/tpu_inference/models/vllm}/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/models/vllm/vllm_model_wrapper_context.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/platforms/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251219/tpu_inference/models/vllm → tpu_inference-0.12.0rc1/tpu_inference/runner}/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/runner/block_table.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/runner/input_batch.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/runner/lora_utils.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/runner/multimodal_manager.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/runner/persistent_batch_manager.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/runner/speculative_decoding_manager.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/runner/structured_decoding_manager.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/runner/utils.py +0 -0
- {tpu_inference-0.12.0.dev20251219/tpu_inference/runner → tpu_inference-0.12.0rc1/tpu_inference/spec_decode}/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251219/tpu_inference/spec_decode → tpu_inference-0.12.0rc1/tpu_inference/spec_decode/jax}/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/spec_decode/jax/eagle3.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/tpu_info.py +0 -0
- {tpu_inference-0.12.0.dev20251219/tpu_inference/spec_decode/jax → tpu_inference-0.12.0rc1/tpu_inference/worker}/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference.egg-info/dependency_links.txt +0 -0
- {tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference.egg-info/top_level.txt +0 -0
{tpu_inference-0.12.0.dev20251219/tpu_inference.egg-info → tpu_inference-0.12.0rc1}/PKG-INFO
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: tpu_inference
-Version: 0.12.0.dev20251219
+Version: 0.12.0rc1
 Author: tpu_inference Contributors
 Classifier: Development Status :: 3 - Alpha
 Classifier: Intended Audience :: Developers
@@ -25,7 +25,7 @@ Requires-Dist: jax[tpu]==0.8.0
 Requires-Dist: jaxlib==0.8.0
 Requires-Dist: jaxtyping
 Requires-Dist: flax==0.11.1
-Requires-Dist: torchax==0.0.
+Requires-Dist: torchax==0.0.7
 Requires-Dist: qwix==0.1.1
 Requires-Dist: torchvision==0.24.0
 Requires-Dist: pathwaysutils
@@ -53,11 +53,13 @@ Dynamic: requires-python
 
 ---
 
-
+_Upcoming Events_ 🔥
+
+- Join us at the [PyTorch Conference, October 22-23](https://events.linuxfoundation.org/pytorch-conference/) in San Francisco!
+- Join us at [Ray Summit, November 3-5](https://www.anyscale.com/ray-summit/2025) in San Francisco!
+- Join us at [JAX DevLab on November 18th](https://rsvp.withgoogle.com/events/devlab-fall-2025) in Sunnyvale!
 
-
-- Check back soon for a recording of our session at [Ray Summit, November 3-5](https://www.anyscale.com/ray-summit/2025) in San Francisco!
-- Check back soon for a recording of our session at [JAX DevLab on November 18th](https://rsvp.withgoogle.com/events/devlab-fall-2025) in Sunnyvale!
+_Latest News_ 🔥
 
 - [2025/10] [vLLM TPU: A New Unified Backend Supporting PyTorch and JAX on TPU](https://blog.vllm.ai/2025/10/16/vllm-tpu.html)
 
{tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/README.md
RENAMED
@@ -11,11 +11,13 @@
 
 ---
 
-
+_Upcoming Events_ 🔥
+
+- Join us at the [PyTorch Conference, October 22-23](https://events.linuxfoundation.org/pytorch-conference/) in San Francisco!
+- Join us at [Ray Summit, November 3-5](https://www.anyscale.com/ray-summit/2025) in San Francisco!
+- Join us at [JAX DevLab on November 18th](https://rsvp.withgoogle.com/events/devlab-fall-2025) in Sunnyvale!
 
-
-- Check back soon for a recording of our session at [Ray Summit, November 3-5](https://www.anyscale.com/ray-summit/2025) in San Francisco!
-- Check back soon for a recording of our session at [JAX DevLab on November 18th](https://rsvp.withgoogle.com/events/devlab-fall-2025) in Sunnyvale!
+_Latest News_ 🔥
 
 - [2025/10] [vLLM TPU: A New Unified Backend Supporting PyTorch and JAX on TPU](https://blog.vllm.ai/2025/10/16/vllm-tpu.html)
 
{tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tests/kernels/mla_v1_test.py
RENAMED
@@ -42,7 +42,6 @@ class MlaRaggedPagedAttentionKernelTest(jtu.JaxTestCase):
 
 padded_r_dim = align_to(r_dim, 128)
 padded_lkv_dim = align_to(lkv_dim, 128)
-padded_kv_dim = padded_lkv_dim + padded_r_dim
 packing = get_dtype_packing(kv_dtype)
 q_lens = [s[0] for s in seq_lens]
 kv_lens_list = [s[1] for s in seq_lens]
@@ -70,10 +69,13 @@
 new_kv_c = gen_random((total_q_len, lkv_dim), kv_dtype)
 new_k_pe = gen_random((total_q_len, r_dim), kv_dtype)
 
-
-(total_num_pages, page_size // packing, packing,
+cache_kv_c = gen_random(
+(total_num_pages, page_size // packing, packing, padded_lkv_dim),
 kv_dtype,
 )
+cache_k_pe = gen_random(
+(total_num_pages, page_size // packing, packing, padded_r_dim),
+kv_dtype)
 kv_lens = jnp.array(kv_lens_list, dtype=jnp.int32)
 page_indices = jnp.array(page_indices_list, dtype=jnp.int32)
 cu_q_lens = jnp.array(cu_q_lens_list, dtype=jnp.int32)
@@ -82,13 +84,14 @@
 ql_nope_for_kernel = ql_nope.copy()
 q_pe_for_kernel = q_pe.copy()
 
-expected_out,
+expected_out, expected_updated_kv_c, expeceted_updated_k_pe = (
 mla.ref_mla_ragged_paged_attention(
 ql_nope,
 q_pe,
 new_kv_c,
 new_k_pe,
-
+cache_kv_c.copy(),
+cache_k_pe.copy(),
 kv_lens,
 page_indices,
 cu_q_lens,
@@ -98,140 +101,49 @@
 soft_cap=soft_cap,
 ))
 
-kernel_out,
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+kernel_out, kernel_updated_kv_c, kernel_updated_k_pe = (
+mla.mla_ragged_paged_attention(
+ql_nope_for_kernel,
+q_pe_for_kernel,
+new_kv_c,
+new_k_pe,
+cache_kv_c.copy(),
+cache_k_pe.copy(),
+kv_lens,
+page_indices,
+cu_q_lens,
+distribution,
+sm_scale=sm_scale,
+sliding_window=sliding_window,
+soft_cap=soft_cap,
+num_kv_pages_per_block=num_kv_pages_per_block,
+num_queries_per_block=num_queries_per_block,
+vmem_limit_bytes=vmem_limit_bytes,
+))
 
 self.assertEqual(expected_out.shape,
 (total_q_len, num_heads, padded_lkv_dim))
 self.assertEqual(
-
-(total_num_pages, page_size // packing, packing,
+expected_updated_kv_c.shape,
+(total_num_pages, page_size // packing, packing, padded_lkv_dim),
+)
+self.assertEqual(
+expeceted_updated_k_pe.shape,
+(total_num_pages, page_size // packing, packing, padded_r_dim),
 )
 self.assertEqual(expected_out.dtype, kv_dtype)
-self.assertEqual(
+self.assertEqual(expected_updated_kv_c.dtype, kv_dtype)
+self.assertEqual(expeceted_updated_k_pe.dtype, kv_dtype)
 
 self.assertAllClose(expected_out, kernel_out, atol=0.2, rtol=0.2)
-self.assertAllClose(
-
+self.assertAllClose(expected_updated_kv_c,
+kernel_updated_kv_c,
+atol=0.2,
+rtol=0.2)
+self.assertAllClose(expeceted_updated_k_pe,
+kernel_updated_k_pe,
 atol=0.2,
 rtol=0.2)
-
-def test_update_kv_cache(self):
-lkv_dim = 4
-r_dim = 4
-padded_lkv_dim = align_to(lkv_dim, 128)
-padded_r_dim = align_to(r_dim, 128)
-kv_dtype = jnp.bfloat16
-new_kv_c = jnp.arange(16, dtype=kv_dtype).reshape((4, lkv_dim))
-new_k_pe = (jnp.arange(16, dtype=kv_dtype).reshape((4, r_dim)) + 100)
-total_num_pages = 2
-page_size = 4
-cache_kv_shape = mla.get_kv_cache_shape(
-total_num_pages,
-page_size,
-padded_lkv_dim + padded_r_dim,
-kv_dtype,
-)
-cache_kv = jnp.zeros(cache_kv_shape, dtype=kv_dtype)
-
-# two sequences, first with 3 tokens, second with 1 token
-kv_lens = jnp.array([3, 1], dtype=jnp.int32)
-# first seq uses page 0, second uses page 1
-page_indices = jnp.array([0, -1, 1, -1], dtype=jnp.int32)
-# three tokens for first seq, one for second
-cu_q_lens = jnp.array([0, 3, 4], dtype=jnp.int32)
-distribution = jnp.array([0, 0, 2], dtype=jnp.int32)
-
-# manually compute the expected cache
-padded_new_kv_c = jnp.pad(new_kv_c,
-((0, 0), (0, padded_lkv_dim - lkv_dim)),
-constant_values=0)
-padded_new_k_pe = jnp.pad(new_k_pe,
-((0, 0), (0, padded_r_dim - r_dim)),
-constant_values=0)
-
-expected_cache = cache_kv
-# First sequence
-# token 0
-page_idx, row, col = 0, 0, 0
-expected_cache = expected_cache.at[page_idx, row,
-col, :padded_lkv_dim].set(
-padded_new_kv_c[0])
-expected_cache = expected_cache.at[page_idx, row, col,
-padded_lkv_dim:padded_lkv_dim +
-padded_r_dim].set(
-padded_new_k_pe[0])
-# token 1
-page_idx, row, col = 0, 0, 1
-expected_cache = expected_cache.at[page_idx, row,
-col, :padded_lkv_dim].set(
-padded_new_kv_c[1])
-expected_cache = expected_cache.at[page_idx, row, col,
-padded_lkv_dim:padded_lkv_dim +
-padded_r_dim].set(
-padded_new_k_pe[1])
-# token 2
-page_idx, row, col = 0, 1, 0
-expected_cache = expected_cache.at[page_idx, row,
-col, :padded_lkv_dim].set(
-padded_new_kv_c[2])
-expected_cache = expected_cache.at[page_idx, row, col,
-padded_lkv_dim:padded_lkv_dim +
-padded_r_dim].set(
-padded_new_k_pe[2])
-
-# Second sequence
-# token 0
-page_idx, row, col = 1, 0, 0
-expected_cache = expected_cache.at[page_idx, row,
-col, :padded_lkv_dim].set(
-padded_new_kv_c[3])
-expected_cache = expected_cache.at[page_idx, row, col,
-padded_lkv_dim:padded_lkv_dim +
-padded_r_dim].set(
-padded_new_k_pe[3])
-
-updated_cache = mla.update_kv_cache(
-new_kv_c,
-new_k_pe,
-cache_kv,
-kv_lens,
-page_indices,
-cu_q_lens,
-distribution,
-)
-
-self.assertAllClose(updated_cache, expected_cache)
-
-def test_get_kv_cache_shape(self):
-total_num_pages = 10
-page_size = 16
-lkv_dim = 128
-kv_dtype = jnp.bfloat16
-# The calculation for the expected shape is as follows:
-# kv_packing is determined by the dtype, which is 2 for bfloat16.
-# The second dimension is page_size / kv_packing = 16 / 2 = 8
-# The third dimension is kv_packing = 2
-# The fourth dimension is lkv_dim aligned to 128, which is 128
-expected_shape = (10, 8, 2, 128)
-self.assertEqual(
-mla.get_kv_cache_shape(total_num_pages, page_size, lkv_dim,
-kv_dtype), expected_shape)
 
 def test_ragged_paged_attention_basic(self):
 dtype = jnp.bfloat16
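Note on the deleted test_update_kv_cache: it hard-codes a (page, row, col) slot for every token using page_size=4 and the bfloat16 packing factor of 2. The mapping those constants encode can be written out as a small sketch; this is illustrative only, and slot_for_token is not a function in the package:

def slot_for_token(pos_in_seq, page_size=4, packing=2):
    # A sequence's pages store page_size tokens laid out as
    # (page, page_size // packing, packing, features), as in the deleted test.
    page_offset = pos_in_seq // page_size      # which of the sequence's pages
    row = (pos_in_seq % page_size) // packing  # packed row within the page
    col = pos_in_seq % packing                 # slot within the packed row
    return page_offset, row, col

# Reproduces the slots hard-coded in the deleted test:
assert slot_for_token(0) == (0, 0, 0)  # first sequence, token 0
assert slot_for_token(1) == (0, 0, 1)  # first sequence, token 1
assert slot_for_token(2) == (0, 1, 0)  # first sequence, token 2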
{tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tests/kernels/quantized_matmul_kernel_test.py
RENAMED
@@ -1,5 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
+import functools
+
 import jax
 import jax.numpy as jnp
 from absl.testing import absltest, parameterized
@@ -8,7 +10,6 @@ from jax._src import test_util as jtu
 from tpu_inference.kernels.quantized_matmul import (kernel, tuned_block_sizes,
 util)
 
-xla_quantized_matmul = kernel.xla_quantized_matmul
 quantized_matmul_kernel = kernel.quantized_matmul_kernel
 quantize_tensor = util.quantize_tensor
 get_tuned_block_sizes = tuned_block_sizes.get_tuned_block_sizes
@@ -16,6 +17,37 @@ get_tuned_block_sizes = tuned_block_sizes.get_tuned_block_sizes
 jax.config.parse_flags_with_absl()
 
 
+@functools.partial(jax.jit, static_argnames=["quantize_activation"])
+def reference_quantized_matmul(
+x: jax.Array,
+w_q: jax.Array,
+w_scale: jax.Array,
+quantize_activation=True,
+):
+if quantize_activation:
+acc_dtype = jnp.float32
+if quantize_activation and jnp.issubdtype(w_q.dtype, jnp.integer):
+acc_dtype = jnp.int32
+
+x_q, x_scale = quantize_tensor(x, w_q.dtype)
+out = jax.lax.dot_general(
+x_q,
+w_q,
+dimension_numbers=(((1, ), (1, )), ((), ())),
+preferred_element_type=acc_dtype,
+).astype(jnp.float32)
+out *= x_scale
+else:
+out = jax.lax.dot_general(
+x,
+w_q,
+dimension_numbers=(((1, ), (1, )), ((), ())),
+preferred_element_type=jnp.float32,
+)
+out *= jnp.expand_dims(w_scale, 0)
+return out.astype(x.dtype)
+
+
 @jtu.with_config(jax_numpy_dtype_promotion="standard")
 class QuantizedMatmulKernelTest(jtu.JaxTestCase):
 
@@ -62,7 +94,7 @@ class QuantizedMatmulKernelTest(jtu.JaxTestCase):
 x_q_dtype=x_q_dtype,
 tuned_value=tuned_value,
 )
-expected =
+expected = reference_quantized_matmul(
 x, w_q, w_scale, quantize_activation=quantize_activation)
 
 self.assertAllClose(output,
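The deleted xla_quantized_matmul reference is replaced by the reference_quantized_matmul helper defined in the test above. A rough usage sketch follows; it assumes quantize_tensor(tensor, target_dtype) returns the quantized tensor and its scales, mirroring how the helper itself calls it, and the shapes are made up for illustration:

import jax.numpy as jnp

x = jnp.ones((8, 128), dtype=jnp.bfloat16)    # activations (tokens, in_features)
w = jnp.ones((256, 128), dtype=jnp.bfloat16)  # weights (out_features, in_features)
w_q, w_scale = quantize_tensor(w, jnp.int8)   # assumed weight-quantization call
expected = reference_quantized_matmul(x, w_q, w_scale, quantize_activation=True)
# The test then checks the Pallas quantized_matmul_kernel output against
# `expected` with assertAllClose.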
{tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py
RENAMED
@@ -176,9 +176,7 @@ class RaggedPagedAttentionHeadDim64KernelTest(jtu.JaxTestCase):
 )
 output = output[:cu_q_lens[distribution[-1]]]
 
-dtype_bits =
-dtypes, "bit_width") else dtypes.itemsize_bits(
-jnp.dtype(kv_dtype)))
+dtype_bits = dtypes.bit_width(jnp.dtype(kv_dtype))
 tols = {
 32: 0.15,
 16: 0.2,
{tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tests/kernels/ragged_paged_attention_kernel_v3_test.py
RENAMED
@@ -162,9 +162,7 @@ class RaggedPagedAttentionKernelTest(jtu.JaxTestCase):
 )
 output = output[:cu_q_lens[distribution[-1]]]
 
-dtype_bits =
-dtypes, "bit_width") else dtypes.itemsize_bits(
-jnp.dtype(kv_dtype)))
+dtype_bits = dtypes.bit_width(jnp.dtype(kv_dtype))
 tols = {
 32: 0.15,
 16: 0.2,
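Both ragged paged attention tests now read the KV dtype's width directly from jax._src.dtypes.bit_width instead of branching on a hasattr fallback, and pick a tolerance from the tols table keyed by that width. A quick check of the values involved (illustrative, not test code):

from jax._src import dtypes
import jax.numpy as jnp

assert dtypes.bit_width(jnp.dtype(jnp.float32)) == 32   # maps to tolerance 0.15
assert dtypes.bit_width(jnp.dtype(jnp.bfloat16)) == 16  # maps to tolerance 0.2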
{tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tests/lora/test_layers.py
RENAMED
@@ -18,7 +18,7 @@ from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
 ReplicatedLinearWithLoRA,
 RowParallelLinearWithLoRA)
 # yapf: enable
-from vllm.lora.
+from vllm.lora.models import LoRALayerWeights, PackedLoRALayerWeights
 from vllm.lora.punica_wrapper import get_punica_wrapper
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
 MergedColumnParallelLinear,
@@ -499,13 +499,9 @@ def _create_random_linear_parallel_layer(layer_type, vllm_config, mesh):
 return linear, lora_linear
 
 
-def _get_devices():
-return jax.devices()
-
-
 def _create_mesh():
 axis_names = ("data", "model")
-devices =
+devices = jax.devices()
 mesh_shape = (1, len(devices))
 mesh = jax.make_mesh(mesh_shape, axis_names, devices=devices)
 return mesh
@@ -517,7 +513,7 @@ def _verify_lora_linear_layer(linear, lora_linear):
 # BaseLinearLayerWithLoRA.weight property guarantees this.
 # if len(devices) != 1, `reorder_concatenated_tensor_for_sharding` function may reorder the out_features dimension of the weight matrix.
 # So the below check will fail.
-if len(
+if len(jax.devices()) == 1:
 assert torch.equal(linear.weight.data,
 lora_linear.weight.to('cpu'))
 
{tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tests/lora/test_lora.py
RENAMED
@@ -29,7 +29,7 @@ def setup_vllm(num_loras: int, tp: int = 1) -> vllm.LLM:
 
 
 # For multi-chip test, we only use TP=2 because the base model Qwen/Qwen2.5-3B-Instruct has 2 kv heads and the current attention kernel requires it to be divisible by tp_size.
-TP = [2] if os.environ.get("
+TP = [2] if os.environ.get("USE_V6E8_QUEUE", False) else [1]
 
 
 @pytest.mark.parametrize("tp", TP)
{tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tests/test_envs.py
RENAMED
@@ -60,7 +60,6 @@ def test_boolean_env_vars(monkeypatch: pytest.MonkeyPatch):
 monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "0")
 monkeypatch.setenv("VLLM_XLA_CHECK_RECOMPILATION", "0")
 monkeypatch.setenv("NEW_MODEL_DESIGN", "0")
-monkeypatch.setenv("ENABLE_QUANTIZED_MATMUL_KERNEL", "0")
 monkeypatch.setenv("USE_MOE_EP_KERNEL", "0")
 
 # Test SKIP_JAX_PRECOMPILE (default False)
@@ -87,82 +86,6 @@ def test_boolean_env_vars(monkeypatch: pytest.MonkeyPatch):
 monkeypatch.setenv("USE_MOE_EP_KERNEL", "1")
 assert envs.USE_MOE_EP_KERNEL is True
 
-# Test ENABLE_QUANTIZED_MATMUL_KERNEL (default False)
-assert envs.ENABLE_QUANTIZED_MATMUL_KERNEL is False
-monkeypatch.setenv("ENABLE_QUANTIZED_MATMUL_KERNEL", "1")
-assert envs.ENABLE_QUANTIZED_MATMUL_KERNEL is True
-
-
-def test_boolean_env_vars_string_values(monkeypatch: pytest.MonkeyPatch):
-"""Test that boolean env vars accept string values like 'True' and 'False'"""
-
-# Test NEW_MODEL_DESIGN with string "True"
-monkeypatch.setenv("NEW_MODEL_DESIGN", "True")
-assert envs.NEW_MODEL_DESIGN is True
-
-monkeypatch.setenv("NEW_MODEL_DESIGN", "true")
-assert envs.NEW_MODEL_DESIGN is True
-
-monkeypatch.setenv("NEW_MODEL_DESIGN", "False")
-assert envs.NEW_MODEL_DESIGN is False
-
-monkeypatch.setenv("NEW_MODEL_DESIGN", "false")
-assert envs.NEW_MODEL_DESIGN is False
-
-# Test SKIP_JAX_PRECOMPILE with string values
-monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "True")
-assert envs.SKIP_JAX_PRECOMPILE is True
-
-monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "false")
-assert envs.SKIP_JAX_PRECOMPILE is False
-
-# Test VLLM_XLA_CHECK_RECOMPILATION with string values
-monkeypatch.setenv("VLLM_XLA_CHECK_RECOMPILATION", "TRUE")
-assert envs.VLLM_XLA_CHECK_RECOMPILATION is True
-
-monkeypatch.setenv("VLLM_XLA_CHECK_RECOMPILATION", "FALSE")
-assert envs.VLLM_XLA_CHECK_RECOMPILATION is False
-
-# Test USE_MOE_EP_KERNEL with string values
-monkeypatch.setenv("USE_MOE_EP_KERNEL", "true")
-assert envs.USE_MOE_EP_KERNEL is True
-
-monkeypatch.setenv("USE_MOE_EP_KERNEL", "False")
-assert envs.USE_MOE_EP_KERNEL is False
-
-
-def test_boolean_env_vars_invalid_values(monkeypatch: pytest.MonkeyPatch):
-"""Test that boolean env vars raise errors for invalid values"""
-
-# Test invalid value for NEW_MODEL_DESIGN
-monkeypatch.setenv("NEW_MODEL_DESIGN", "yes")
-with pytest.raises(
-ValueError,
-match="Invalid boolean value 'yes' for NEW_MODEL_DESIGN"):
-_ = envs.NEW_MODEL_DESIGN
-
-monkeypatch.setenv("NEW_MODEL_DESIGN", "2")
-with pytest.raises(ValueError,
-match="Invalid boolean value '2' for NEW_MODEL_DESIGN"):
-_ = envs.NEW_MODEL_DESIGN
-
-# Test invalid value for SKIP_JAX_PRECOMPILE
-monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "invalid")
-with pytest.raises(
-ValueError,
-match="Invalid boolean value 'invalid' for SKIP_JAX_PRECOMPILE"):
-_ = envs.SKIP_JAX_PRECOMPILE
-
-
-def test_boolean_env_vars_empty_string(monkeypatch: pytest.MonkeyPatch):
-"""Test that empty string returns default value"""
-
-monkeypatch.setenv("NEW_MODEL_DESIGN", "")
-assert envs.NEW_MODEL_DESIGN is False # Should return default
-
-monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "")
-assert envs.SKIP_JAX_PRECOMPILE is False # Should return default
-
 
 def test_integer_env_vars(monkeypatch: pytest.MonkeyPatch):
 # Ensure clean environment for integer vars by setting to defaults
@@ -256,7 +179,7 @@ def test_disaggregated_serving_env_vars(monkeypatch: pytest.MonkeyPatch):
 
 def test_model_impl_type_default(monkeypatch: pytest.MonkeyPatch):
 monkeypatch.delenv("MODEL_IMPL_TYPE", raising=False)
-assert envs.MODEL_IMPL_TYPE == "
+assert envs.MODEL_IMPL_TYPE == "flax_nnx"
 
 
 def test_cache_preserves_values_across_env_changes(
{tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tests/test_quantization.py
RENAMED
@@ -112,8 +112,6 @@ class TestQwixQuantizeNnxModel(unittest.TestCase):
 self.mesh = Mesh(jax.devices(), ('model', ))
 self.rng = jax.random.PRNGKey(0)
 self.model = SimpleModel(rngs=nnx.Rngs(0))
-self.model.vllm_config = MagicMock()
-self.model.vllm_config.model_config.use_mla = False
 
 self.qwix_config = [
 {
@@ -133,7 +131,6 @@ class TestQwixQuantizeNnxModel(unittest.TestCase):
 """Test that qwix.quantize_model is called with the correct arguments."""
 quantized_model_mock = MagicMock(spec=nnx.Module)
 mock_quantize_model.return_value = quantized_model_mock
-self.model.vllm_config.sharding_config.total_dp_size = 1
 
 with patch(
 "tpu_inference.models.jax.utils.quantization.quantization_utils.init_logger",
{tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/distributed/tpu_connector.py
RENAMED
@@ -694,9 +694,9 @@ class TPUConnectorWorker:
 
 def get_uuid() -> int:
 int128 = uuid4().int
-# Must be
-
-return
+# Must be 64-bit int, otherwise vllm output encoder would raise error.
+int64 = int128 >> 64
+return int64
 
 
 @jax.jit
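The rewritten get_uuid keeps only the upper 64 bits of a random 128-bit UUID so that the returned id fits in a 64-bit integer, as the new comment explains. A standalone sketch of the same arithmetic (not the package's module, just the idea):

from uuid import uuid4

def get_uuid() -> int:
    int128 = uuid4().int   # 128-bit random integer
    return int128 >> 64    # keep the top 64 bits: the result is always < 2**64

assert 0 <= get_uuid() < 2**64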
{tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/envs.py
RENAMED
@@ -16,7 +16,7 @@ if TYPE_CHECKING:
 DECODE_SLICES: str = ""
 SKIP_JAX_PRECOMPILE: bool = False
 VLLM_XLA_CHECK_RECOMPILATION: bool = False
-MODEL_IMPL_TYPE: str = "
+MODEL_IMPL_TYPE: str = "flax_nnx"
 NEW_MODEL_DESIGN: bool = False
 PHASED_PROFILING_DIR: str = ""
 PYTHON_TRACER_LEVEL: int = 1
@@ -24,7 +24,6 @@ if TYPE_CHECKING:
 NUM_SLICES: int = 1
 RAY_USAGE_STATS_ENABLED: str = "0"
 VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: str = "shm"
-ENABLE_QUANTIZED_MATMUL_KERNEL: bool = False
 
 
 def env_with_choices(
@@ -70,34 +69,6 @@ def env_with_choices(
 return _get_validated_env
 
 
-def env_bool(env_name: str, default: bool = False) -> Callable[[], bool]:
-"""
-Accepts both numeric strings ("0", "1") and boolean strings
-("true", "false", "True", "False").
-
-Args:
-env_name: Name of the environment variable
-default: Default boolean value if not set
-"""
-
-def _get_bool_env() -> bool:
-value = os.getenv(env_name)
-if value is None or value == "":
-return default
-
-value_lower = value.lower()
-if value_lower in ("true", "1"):
-return True
-elif value_lower in ("false", "0"):
-return False
-else:
-raise ValueError(
-f"Invalid boolean value '{value}' for {env_name}. "
-f"Valid options: '0', '1', 'true', 'false', 'True', 'False'.")
-
-return _get_bool_env
-
-
 environment_variables: dict[str, Callable[[], Any]] = {
 # JAX platform selection (e.g., "tpu", "cpu", "proxy")
 "JAX_PLATFORMS":
@@ -122,17 +93,17 @@ environment_variables: dict[str, Callable[[], Any]] = {
 lambda: os.getenv("DECODE_SLICES", ""),
 # Skip JAX precompilation step during initialization
 "SKIP_JAX_PRECOMPILE":
-
+lambda: bool(int(os.getenv("SKIP_JAX_PRECOMPILE") or "0")),
 # Check for XLA recompilation during execution
 "VLLM_XLA_CHECK_RECOMPILATION":
-
+lambda: bool(int(os.getenv("VLLM_XLA_CHECK_RECOMPILATION") or "0")),
 # Model implementation type (e.g., "flax_nnx")
 "MODEL_IMPL_TYPE":
-env_with_choices("MODEL_IMPL_TYPE", "
-["
+env_with_choices("MODEL_IMPL_TYPE", "flax_nnx",
+["vllm", "flax_nnx", "jetpack"]),
 # Enable new experimental model design
 "NEW_MODEL_DESIGN":
-
+lambda: bool(int(os.getenv("NEW_MODEL_DESIGN") or "0")),
 # Directory to store phased profiling output
 "PHASED_PROFILING_DIR":
 lambda: os.getenv("PHASED_PROFILING_DIR", ""),
@@ -141,7 +112,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
 lambda: int(os.getenv("PYTHON_TRACER_LEVEL") or "1"),
 # Use custom expert-parallel kernel for MoE (Mixture of Experts)
 "USE_MOE_EP_KERNEL":
-
+lambda: bool(int(os.getenv("USE_MOE_EP_KERNEL") or "0")),
 # Number of TPU slices for multi-slice mesh
 "NUM_SLICES":
 lambda: int(os.getenv("NUM_SLICES") or "1"),
@@ -151,8 +122,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
 # Ray compiled DAG channel type for TPU
 "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE":
 env_with_choices("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "shm", ["shm"]),
-"ENABLE_QUANTIZED_MATMUL_KERNEL":
-lambda: bool(int(os.getenv("ENABLE_QUANTIZED_MATMUL_KERNEL") or "0")),
 }
 
 
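With the env_bool helper removed, the boolean flags above (SKIP_JAX_PRECOMPILE, VLLM_XLA_CHECK_RECOMPILATION, NEW_MODEL_DESIGN, USE_MOE_EP_KERNEL) are parsed as bool(int(value or "0")), so only numeric strings are accepted. A minimal sketch of the behavioral difference, using a hypothetical FLAG variable:

import os

os.environ["FLAG"] = "1"
assert bool(int(os.getenv("FLAG") or "0")) is True   # "0"/"1" still work

os.environ["FLAG"] = "true"
try:
    bool(int(os.getenv("FLAG") or "0"))              # accepted by the old env_bool
except ValueError:
    print("string booleans now raise ValueError")    # new lambdas reject them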
{tpu_inference-0.12.0.dev20251219 → tpu_inference-0.12.0rc1}/tpu_inference/executors/ray_distributed_executor.py
RENAMED
@@ -145,9 +145,6 @@ class RayDistributedExecutor(RayDistributedExecutorV1):
 device_str: node['Resources'][device_str]
 } for node in ray_nodes]
 else:
-assert pp_size == len(
-ray_nodes
-), f"Cannot use PP across hosts, please set --pipeline-parallel-size to 1 or {len(ray_nodes)}"
 num_devices_per_pp_rank = self.vllm_config.sharding_config.total_devices
 placement_group_specs = [{
 device_str: num_devices_per_pp_rank