tpu-inference 0.12.0.dev20251207__tar.gz → 0.12.0.dev20251219__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of tpu-inference might be problematic.
- {tpu_inference-0.12.0.dev20251207/tpu_inference.egg-info → tpu_inference-0.12.0.dev20251219}/PKG-INFO +5 -7
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/README.md +4 -6
- tpu_inference-0.12.0.dev20251219/tests/kernels/gmm_test.py +191 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tests/kernels/quantized_matmul_kernel_test.py +2 -34
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +3 -1
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tests/kernels/ragged_paged_attention_kernel_v3_test.py +3 -1
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tests/lora/test_layers.py +7 -3
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tests/lora/test_lora.py +1 -1
- tpu_inference-0.12.0.dev20251219/tests/lora/test_lora_perf.py +53 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tests/test_envs.py +78 -1
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/distributed/tpu_connector.py +3 -3
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/envs.py +38 -7
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/executors/ray_distributed_executor.py +3 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/kernels/fused_moe/v1/kernel.py +357 -324
- tpu_inference-0.12.0.dev20251219/tpu_inference/kernels/megablox/common.py +41 -0
- tpu_inference-0.12.0.dev20251219/tpu_inference/kernels/megablox/gmm.py +633 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +181 -101
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +82 -78
- tpu_inference-0.12.0.dev20251219/tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4447 -0
- tpu_inference-0.12.0.dev20251219/tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +535 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/kernels/ragged_paged_attention/v3/util.py +2 -1
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/common/attention_interface.py +1 -7
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/common/quant_methods.py +1 -0
- tpu_inference-0.12.0.dev20251219/tpu_inference/layers/jax/pp_utils.py +39 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/vllm/fused_moe.py +87 -67
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/vllm/linear_common.py +43 -21
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/vllm/quantization/__init__.py +2 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/vllm/quantization/awq.py +1 -1
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/vllm/quantization/common.py +5 -5
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -3
- tpu_inference-0.12.0.dev20251219/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +252 -0
- tpu_inference-0.12.0.dev20251219/tpu_inference/layers/vllm/quantization/fp8.py +104 -0
- tpu_inference-0.12.0.dev20251219/tpu_inference/layers/vllm/quantization/mxfp4.py +448 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/vllm/quantization/unquantized.py +83 -47
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/lora/torch_lora_ops.py +8 -13
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/models/common/model_loader.py +43 -18
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/models/jax/llama3.py +79 -33
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/models/jax/utils/weight_utils.py +19 -1
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/models/vllm/vllm_model_wrapper.py +1 -1
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/platforms/tpu_platform.py +8 -34
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/runner/kv_cache.py +3 -1
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/runner/tpu_runner.py +5 -5
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/utils.py +2 -1
- tpu_inference-0.12.0.dev20251219/tpu_inference/worker/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/worker/tpu_worker.py +22 -36
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219/tpu_inference.egg-info}/PKG-INFO +5 -7
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference.egg-info/SOURCES.txt +7 -0
- tpu_inference-0.12.0.dev20251207/tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +0 -4147
- tpu_inference-0.12.0.dev20251207/tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +0 -367
- tpu_inference-0.12.0.dev20251207/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +0 -203
- tpu_inference-0.12.0.dev20251207/tpu_inference/layers/vllm/quantization/mxfp4.py +0 -331
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/LICENSE +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/MANIFEST.in +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/pyproject.toml +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/requirements.txt +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/setup.cfg +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/setup.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tests/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tests/core/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tests/core/test_core_tpu.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tests/core/test_disagg_executor.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tests/core/test_disagg_utils.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tests/core/test_dp_scheduler.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tests/core/test_init.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tests/kernels/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tests/kernels/fused_moe_v1_test.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tests/kernels/mla_v1_test.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tests/kernels/ragged_kv_cache_update_v2_test.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tests/kernels/ragged_paged_attention_kernel_v2_test.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tests/lora/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tests/lora/conftest.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tests/lora/test_bgmv.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tests/lora/utils.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tests/test_base.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tests/test_quantization.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tests/test_tpu_info.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tests/test_utils.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/core/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/core/core_tpu.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/core/disagg_executor.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/core/disagg_utils.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/core/sched/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/core/sched/dp_scheduler.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/distributed/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/distributed/jax_parallel_state.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/distributed/utils.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/env_override.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/executors/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/experimental/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/experimental/llama3_jax_stashed.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/kernels/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/kernels/collectives/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/kernels/collectives/util.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/kernels/flash_attention/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/kernels/flash_attention/kernel.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/kernels/fused_moe/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251207/tpu_inference/kernels/mla → tpu_inference-0.12.0.dev20251219/tpu_inference/kernels/megablox}/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251207/tpu_inference/kernels/mla/v1 → tpu_inference-0.12.0.dev20251219/tpu_inference/kernels/mla}/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251207/tpu_inference/kernels/quantized_matmul → tpu_inference-0.12.0.dev20251219/tpu_inference/kernels/mla/v1}/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/kernels/mla/v1/kernel.py +0 -0
- {tpu_inference-0.12.0.dev20251207/tpu_inference/kernels/ragged_paged_attention → tpu_inference-0.12.0.dev20251219/tpu_inference/kernels/quantized_matmul}/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/kernels/quantized_matmul/util.py +0 -0
- {tpu_inference-0.12.0.dev20251207/tpu_inference/kernels/ragged_paged_attention/v2 → tpu_inference-0.12.0.dev20251219/tpu_inference/kernels/ragged_paged_attention}/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251207/tpu_inference/kernels/ragged_paged_attention/v3 → tpu_inference-0.12.0.dev20251219/tpu_inference/kernels/ragged_paged_attention/v2}/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +0 -0
- {tpu_inference-0.12.0.dev20251207/tpu_inference/layers → tpu_inference-0.12.0.dev20251219/tpu_inference/kernels/ragged_paged_attention/v3}/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251207/tpu_inference/layers/common → tpu_inference-0.12.0.dev20251219/tpu_inference/layers}/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251207/tpu_inference/layers/jax → tpu_inference-0.12.0.dev20251219/tpu_inference/layers/common}/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/common/attention_metadata.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/common/binary_search.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/common/sharding.py +0 -0
- {tpu_inference-0.12.0.dev20251207/tpu_inference/layers/jax/attention → tpu_inference-0.12.0.dev20251219/tpu_inference/layers/jax}/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251207/tpu_inference/layers/jax/moe → tpu_inference-0.12.0.dev20251219/tpu_inference/layers/jax/attention}/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/jax/attention/attention.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/jax/attention/deepseek_v3_attention.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/jax/attention/gpt_oss_attention.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/jax/attention/llama4_attention.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/jax/base.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/jax/constants.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/jax/layers.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/jax/misc.py +0 -0
- {tpu_inference-0.12.0.dev20251207/tpu_inference/layers/jax/sample → tpu_inference-0.12.0.dev20251219/tpu_inference/layers/jax/moe}/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/jax/moe/deepseek_v3_moe.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/jax/moe/gpt_oss_moe.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/jax/moe/moe.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/jax/rope.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/jax/rope_interface.py +0 -0
- {tpu_inference-0.12.0.dev20251207/tpu_inference/layers/vllm → tpu_inference-0.12.0.dev20251219/tpu_inference/layers/jax/sample}/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/jax/sample/rejection_sampler.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/jax/sample/sampling.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/jax/sample/sampling_metadata.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/jax/transformer_block.py +0 -0
- {tpu_inference-0.12.0.dev20251207/tpu_inference/layers/vllm/quantization/compressed_tensors → tpu_inference-0.12.0.dev20251219/tpu_inference/layers/vllm}/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/vllm/attention.py +0 -0
- {tpu_inference-0.12.0.dev20251207/tpu_inference/layers/vllm/quantization/compressed_tensors/schemes → tpu_inference-0.12.0.dev20251219/tpu_inference/layers/vllm/quantization/compressed_tensors}/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251207/tpu_inference/lora → tpu_inference-0.12.0.dev20251219/tpu_inference/layers/vllm/quantization/compressed_tensors/schemes}/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/layers/vllm/sharding.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/logger.py +0 -0
- {tpu_inference-0.12.0.dev20251207/tpu_inference/models → tpu_inference-0.12.0.dev20251219/tpu_inference/lora}/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/lora/torch_punica_tpu.py +0 -0
- {tpu_inference-0.12.0.dev20251207/tpu_inference/models/common → tpu_inference-0.12.0.dev20251219/tpu_inference/models}/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251207/tpu_inference/models/jax → tpu_inference-0.12.0.dev20251219/tpu_inference/models/common}/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251207/tpu_inference/models/jax/utils → tpu_inference-0.12.0.dev20251219/tpu_inference/models/jax}/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/models/jax/deepseek_v3.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/models/jax/gpt_oss.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/models/jax/jax_intermediate_tensor.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/models/jax/llama4.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/models/jax/llama_eagle3.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/models/jax/llama_guard_4.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/models/jax/qwen2.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/models/jax/qwen2_5_vl.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/models/jax/qwen3.py +0 -0
- {tpu_inference-0.12.0.dev20251207/tpu_inference/models/jax/utils/quantization → tpu_inference-0.12.0.dev20251219/tpu_inference/models/jax/utils}/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/models/jax/utils/file_utils.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/models/jax/utils/multi_modal_utils.py +0 -0
- {tpu_inference-0.12.0.dev20251207/tpu_inference/models/vllm → tpu_inference-0.12.0.dev20251219/tpu_inference/models/jax/utils/quantization}/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/models/jax/utils/quantization/quantization_utils.py +0 -0
- {tpu_inference-0.12.0.dev20251207/tpu_inference/runner → tpu_inference-0.12.0.dev20251219/tpu_inference/models/vllm}/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/models/vllm/vllm_model_wrapper_context.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/platforms/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251207/tpu_inference/spec_decode → tpu_inference-0.12.0.dev20251219/tpu_inference/runner}/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/runner/block_table.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/runner/compilation_manager.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/runner/input_batch.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/runner/kv_cache_manager.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/runner/lora_utils.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/runner/multimodal_manager.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/runner/persistent_batch_manager.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/runner/speculative_decoding_manager.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/runner/structured_decoding_manager.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/runner/utils.py +0 -0
- {tpu_inference-0.12.0.dev20251207/tpu_inference/spec_decode/jax → tpu_inference-0.12.0.dev20251219/tpu_inference/spec_decode}/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251207/tpu_inference/worker → tpu_inference-0.12.0.dev20251219/tpu_inference/spec_decode/jax}/__init__.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/spec_decode/jax/eagle3.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/tpu_info.py +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference.egg-info/dependency_links.txt +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference.egg-info/requires.txt +0 -0
- {tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference.egg-info/top_level.txt +0 -0
{tpu_inference-0.12.0.dev20251207/tpu_inference.egg-info → tpu_inference-0.12.0.dev20251219}/PKG-INFO
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: tpu_inference
-Version: 0.12.0.dev20251207
+Version: 0.12.0.dev20251219
 Author: tpu_inference Contributors
 Classifier: Development Status :: 3 - Alpha
 Classifier: Intended Audience :: Developers
@@ -53,14 +53,12 @@ Dynamic: requires-python
 
 ---
 
-_Upcoming Events_ 🔥
-
-- Join us at the [PyTorch Conference, October 22-23](https://events.linuxfoundation.org/pytorch-conference/) in San Francisco!
-- Join us at [Ray Summit, November 3-5](https://www.anyscale.com/ray-summit/2025) in San Francisco!
-- Join us at [JAX DevLab on November 18th](https://rsvp.withgoogle.com/events/devlab-fall-2025) in Sunnyvale!
-
 _Latest News_ 🔥
 
+- [Pytorch Conference](https://pytorchconference.sched.com/event/27QCh/sponsored-session-everything-everywhere-all-at-once-vllm-hardware-optionality-with-spotify-and-google-brittany-rockwell-google-shireen-kheradpey-spotify) Learn how Spotify uses vLLM with both GPUs and TPUs to drive down costs and improve user experience.
+- Check back soon for a recording of our session at [Ray Summit, November 3-5](https://www.anyscale.com/ray-summit/2025) in San Francisco!
+- Check back soon for a recording of our session at [JAX DevLab on November 18th](https://rsvp.withgoogle.com/events/devlab-fall-2025) in Sunnyvale!
+
 - [2025/10] [vLLM TPU: A New Unified Backend Supporting PyTorch and JAX on TPU](https://blog.vllm.ai/2025/10/16/vllm-tpu.html)
 
 <details>
{tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/README.md
RENAMED
@@ -11,14 +11,12 @@
 
 ---
 
-_Upcoming Events_ 🔥
-
-- Join us at the [PyTorch Conference, October 22-23](https://events.linuxfoundation.org/pytorch-conference/) in San Francisco!
-- Join us at [Ray Summit, November 3-5](https://www.anyscale.com/ray-summit/2025) in San Francisco!
-- Join us at [JAX DevLab on November 18th](https://rsvp.withgoogle.com/events/devlab-fall-2025) in Sunnyvale!
-
 _Latest News_ 🔥
 
+- [Pytorch Conference](https://pytorchconference.sched.com/event/27QCh/sponsored-session-everything-everywhere-all-at-once-vllm-hardware-optionality-with-spotify-and-google-brittany-rockwell-google-shireen-kheradpey-spotify) Learn how Spotify uses vLLM with both GPUs and TPUs to drive down costs and improve user experience.
+- Check back soon for a recording of our session at [Ray Summit, November 3-5](https://www.anyscale.com/ray-summit/2025) in San Francisco!
+- Check back soon for a recording of our session at [JAX DevLab on November 18th](https://rsvp.withgoogle.com/events/devlab-fall-2025) in Sunnyvale!
+
 - [2025/10] [vLLM TPU: A New Unified Backend Supporting PyTorch and JAX on TPU](https://blog.vllm.ai/2025/10/16/vllm-tpu.html)
 
 <details>
tpu_inference-0.12.0.dev20251219/tests/kernels/gmm_test.py
ADDED
@@ -0,0 +1,191 @@
+import jax
+import jax.numpy as jnp
+from absl.testing import absltest, parameterized
+from jax._src import test_util as jtu
+
+from tpu_inference.kernels.megablox.gmm import gmm
+
+jax.config.parse_flags_with_absl()
+
+
+def quantize_tensor(x: jax.Array,
+                    dtype: jnp.dtype,
+                    axis: int = -1,
+                    block_size: int = 256):
+    if jnp.issubdtype(dtype, jnp.integer):
+        dtype_info = jnp.iinfo(dtype)
+        max_val = int(dtype_info.max)
+        min_val = int(dtype_info.min)
+    else:
+        dtype_info = jnp.finfo(dtype)
+        max_val = float(dtype_info.max)
+        min_val = float(dtype_info.min)
+
+    orig_shape = x.shape
+    blocked_shape = orig_shape[:axis] + (-1,
+                                         block_size) + orig_shape[axis + 1:]
+    x_blocked = x.reshape(blocked_shape)
+
+    x_blocked_abs_max = jnp.max(jnp.abs(x_blocked),
+                                axis=axis + 1,
+                                keepdims=True)
+    scale = x_blocked_abs_max / max_val
+    x_blocked_q = jnp.clip(x_blocked / scale, min_val, max_val).astype(dtype)
+
+    x_q = x_blocked_q.reshape(orig_shape)
+    scale = scale.squeeze(axis=axis + 1).astype(jnp.float32)
+    return x_q, scale
+
+
+def reference_gmm(
+    lhs: jax.Array,
+    rhs: jax.Array,
+    group_sizes: jax.Array,
+    rhs_scale: jax.Array | None = None,
+    rhs_bias: jax.Array | None = None,
+    group_offset: jax.Array | None = None,
+):
+    num_groups, out_size, in_size = rhs.shape
+    assert lhs.shape[1] == in_size
+
+    if group_offset is None:
+        group_offset = jnp.array(0, dtype=jnp.int32)
+    start = group_sizes[:group_offset].sum()
+    group_sizes = group_sizes[group_offset:]
+    assert len(group_sizes) == num_groups
+
+    if rhs_scale is not None:
+        num_blocks = rhs_scale.shape[1]
+    else:
+        num_blocks = 1
+    block_size = in_size // num_blocks
+
+    gmm_out = [jnp.zeros((start, out_size), lhs.dtype)]
+    for group in range(num_groups):
+        end = start + group_sizes[group]
+
+        lhs_slice = lhs[start:end]
+        rhs_slice = rhs[group]
+
+        out = 0
+        for block in range(num_blocks):
+            block_start = block * block_size
+            block_end = block_start + block_size
+            lhs_block = lhs_slice[:, block_start:block_end].astype(jnp.float32)
+            rhs_block = rhs_slice[:, block_start:block_end].astype(jnp.float32)
+
+            acc = jnp.einsum("bd,hd->bh", lhs_block, rhs_block)
+            if rhs_scale is not None:
+                acc *= rhs_scale[group][block]
+            out += acc
+        if rhs_bias is not None:
+            out = out + rhs_bias[group]
+
+        gmm_out.append(out.astype(lhs.dtype))
+        start = end
+
+    return jnp.concat(gmm_out, axis=0)
+
+
+@jtu.with_config(jax_numpy_dtype_promotion="standard")
+class GmmTest(jtu.JaxTestCase):
+
+    @parameterized.product(
+        batch_size=[128],
+        in_size=[1024],
+        out_size=[1024],
+        num_groups=[16, 32],
+        has_bias=[True, False],
+    )
+    def test_gmm(self, batch_size, in_size, out_size, num_groups, has_bias):
+        key = jax.random.key(0)
+
+        lhs = jax.random.normal(key, (batch_size, in_size), dtype=jnp.bfloat16)
+        rhs = jax.random.normal(key, (num_groups, out_size, in_size),
+                                dtype=jnp.bfloat16)
+        rhs_bias = None
+        if has_bias:
+            rhs_bias = jax.random.normal(key, (num_groups, 1, out_size),
+                                         dtype=jnp.bfloat16)
+
+        group_sizes = jax.random.randint(key, (num_groups, ),
+                                         0,
+                                         batch_size,
+                                         dtype=jnp.int32)
+
+        expected = reference_gmm(lhs, rhs, group_sizes, rhs_bias=rhs_bias)
+        actual = gmm(
+            lhs,
+            rhs,
+            group_sizes,
+            rhs_bias=rhs_bias,
+            transpose_rhs=True,
+            preferred_element_type=jnp.bfloat16,
+        )
+
+        self.assertArraysAllClose(actual, expected)
+
+    @parameterized.product(
+        batch_size=[128],
+        in_size=[1024],
+        out_size=[1024],
+        num_groups=[16, 32],
+        has_bias=[True, False],
+        weight_dtype=[jnp.int8, jnp.float8_e5m2, jnp.float4_e2m1fn],
+        block_size=[256, 512],
+    )
+    def test_gmm_weight_quantized(
+        self,
+        batch_size,
+        in_size,
+        out_size,
+        num_groups,
+        has_bias,
+        weight_dtype,
+        block_size,
+    ):
+        if weight_dtype == jnp.float4_e2m1fn and not jtu.is_device_tpu_at_least(
+                version=7):
+            self.skipTest("Expect TPUv7+")
+        key = jax.random.key(0)
+
+        lhs = jax.random.normal(key, (batch_size, in_size), dtype=jnp.bfloat16)
+        rhs = jax.random.normal(key, (num_groups, out_size, in_size),
+                                dtype=jnp.bfloat16)
+        rhs_q, rhs_scale = quantize_tensor(rhs,
+                                           weight_dtype,
+                                           axis=2,
+                                           block_size=block_size)
+        rhs_scale = jnp.swapaxes(rhs_scale, 1, 2)
+        rhs_scale = jnp.expand_dims(rhs_scale, axis=2)
+
+        rhs_bias = None
+        if has_bias:
+            rhs_bias = jax.random.normal(key, (num_groups, 1, out_size),
+                                         dtype=jnp.bfloat16)
+
+        group_sizes = jax.random.randint(key, (num_groups, ),
+                                         0,
+                                         batch_size,
+                                         dtype=jnp.int32)
+
+        expected = reference_gmm(lhs,
+                                 rhs_q,
+                                 group_sizes,
+                                 rhs_scale=rhs_scale,
+                                 rhs_bias=rhs_bias)
+        actual = gmm(
+            lhs,
+            rhs_q,
+            group_sizes,
+            rhs_scale=rhs_scale,
+            rhs_bias=rhs_bias,
+            transpose_rhs=True,
+            preferred_element_type=jnp.bfloat16,
+        )
+
+        self.assertArraysAllClose(actual, expected, atol=3e-1, rtol=3e-1)
+
+
+if __name__ == "__main__":
+    absltest.main(testLoader=jtu.JaxTestLoader())
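Note: the new test's `quantize_tensor` helper performs block-wise abs-max quantization along the contraction axis, with one float32 scale per block. A minimal shape sketch of that idea (illustrative sizes only, not taken from the test, and independent of the megablox kernel):

import jax
import jax.numpy as jnp

num_groups, out_size, in_size, block_size = 4, 8, 512, 256
w = jax.random.normal(jax.random.key(0), (num_groups, out_size, in_size),
                      dtype=jnp.bfloat16)
# Split the contraction (last) axis into blocks of block_size.
w_blocked = w.reshape(num_groups, out_size, in_size // block_size, block_size)
# One abs-max derived scale per (group, row, block).
scale = jnp.max(jnp.abs(w_blocked), axis=-1) / jnp.finfo(jnp.float8_e5m2).max
print(scale.shape)  # (4, 8, 2) = (num_groups, out_size, num_blocks)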
{tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tests/kernels/quantized_matmul_kernel_test.py
RENAMED
@@ -1,7 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
 
-import functools
-
 import jax
 import jax.numpy as jnp
 from absl.testing import absltest, parameterized
@@ -10,6 +8,7 @@ from jax._src import test_util as jtu
 from tpu_inference.kernels.quantized_matmul import (kernel, tuned_block_sizes,
                                                     util)
 
+xla_quantized_matmul = kernel.xla_quantized_matmul
 quantized_matmul_kernel = kernel.quantized_matmul_kernel
 quantize_tensor = util.quantize_tensor
 get_tuned_block_sizes = tuned_block_sizes.get_tuned_block_sizes
@@ -17,37 +16,6 @@ get_tuned_block_sizes = tuned_block_sizes.get_tuned_block_sizes
 jax.config.parse_flags_with_absl()
 
 
-@functools.partial(jax.jit, static_argnames=["quantize_activation"])
-def reference_quantized_matmul(
-    x: jax.Array,
-    w_q: jax.Array,
-    w_scale: jax.Array,
-    quantize_activation=True,
-):
-    if quantize_activation:
-        acc_dtype = jnp.float32
-        if quantize_activation and jnp.issubdtype(w_q.dtype, jnp.integer):
-            acc_dtype = jnp.int32
-
-        x_q, x_scale = quantize_tensor(x, w_q.dtype)
-        out = jax.lax.dot_general(
-            x_q,
-            w_q,
-            dimension_numbers=(((1, ), (1, )), ((), ())),
-            preferred_element_type=acc_dtype,
-        ).astype(jnp.float32)
-        out *= x_scale
-    else:
-        out = jax.lax.dot_general(
-            x,
-            w_q,
-            dimension_numbers=(((1, ), (1, )), ((), ())),
-            preferred_element_type=jnp.float32,
-        )
-    out *= jnp.expand_dims(w_scale, 0)
-    return out.astype(x.dtype)
-
-
 @jtu.with_config(jax_numpy_dtype_promotion="standard")
 class QuantizedMatmulKernelTest(jtu.JaxTestCase):
 
@@ -94,7 +62,7 @@ class QuantizedMatmulKernelTest(jtu.JaxTestCase):
             x_q_dtype=x_q_dtype,
             tuned_value=tuned_value,
         )
-        expected =
+        expected = xla_quantized_matmul(
             x, w_q, w_scale, quantize_activation=quantize_activation)
 
         self.assertAllClose(output,
{tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py
RENAMED
@@ -176,7 +176,9 @@ class RaggedPagedAttentionHeadDim64KernelTest(jtu.JaxTestCase):
         )
         output = output[:cu_q_lens[distribution[-1]]]
 
-        dtype_bits = dtypes.bit_width(jnp.dtype(kv_dtype))
+        dtype_bits = (dtypes.bit_width(jnp.dtype(kv_dtype)) if hasattr(
+            dtypes, "bit_width") else dtypes.itemsize_bits(
+                jnp.dtype(kv_dtype)))
         tols = {
             32: 0.15,
             16: 0.2,
{tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tests/kernels/ragged_paged_attention_kernel_v3_test.py
RENAMED
@@ -162,7 +162,9 @@ class RaggedPagedAttentionKernelTest(jtu.JaxTestCase):
         )
         output = output[:cu_q_lens[distribution[-1]]]
 
-        dtype_bits = dtypes.bit_width(jnp.dtype(kv_dtype))
+        dtype_bits = (dtypes.bit_width(jnp.dtype(kv_dtype)) if hasattr(
+            dtypes, "bit_width") else dtypes.itemsize_bits(
+                jnp.dtype(kv_dtype)))
         tols = {
             32: 0.15,
             16: 0.2,
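Note: both attention-test hunks (and the collectives changes further down) guard `dtypes.bit_width` with a `hasattr` check and fall back to `dtypes.itemsize_bits`, presumably so the code works across JAX versions that expose only one of the two names. A hedged sketch of that pattern as a standalone helper (the helper name is ours, not part of the package):

import jax.numpy as jnp
from jax._src import dtypes


def dtype_bit_width(dtype) -> int:
    # Use whichever name this JAX version provides.
    if hasattr(dtypes, "bit_width"):
        return dtypes.bit_width(jnp.dtype(dtype))
    return dtypes.itemsize_bits(jnp.dtype(dtype))


assert dtype_bit_width(jnp.bfloat16) == 16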
{tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tests/lora/test_layers.py
RENAMED
@@ -18,7 +18,7 @@ from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
                               ReplicatedLinearWithLoRA,
                               RowParallelLinearWithLoRA)
 # yapf: enable
-from vllm.lora.
+from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
 from vllm.lora.punica_wrapper import get_punica_wrapper
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
@@ -499,9 +499,13 @@ def _create_random_linear_parallel_layer(layer_type, vllm_config, mesh):
     return linear, lora_linear
 
 
+def _get_devices():
+    return jax.devices()
+
+
 def _create_mesh():
     axis_names = ("data", "model")
-    devices =
+    devices = _get_devices()
     mesh_shape = (1, len(devices))
     mesh = jax.make_mesh(mesh_shape, axis_names, devices=devices)
     return mesh
@@ -513,7 +517,7 @@ def _verify_lora_linear_layer(linear, lora_linear):
     # BaseLinearLayerWithLoRA.weight property guarantees this.
     # if len(devices) != 1, `reorder_concatenated_tensor_for_sharding` function may reorder the out_features dimension of the weight matrix.
     # So the below check will fail.
-    if len(
+    if len(_get_devices()) == 1:
         assert torch.equal(linear.weight.data,
                            lora_linear.weight.to('cpu'))
 
{tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tests/lora/test_lora.py
RENAMED
@@ -29,7 +29,7 @@ def setup_vllm(num_loras: int, tp: int = 1) -> vllm.LLM:
 
 
 # For multi-chip test, we only use TP=2 because the base model Qwen/Qwen2.5-3B-Instruct has 2 kv heads and the current attention kernel requires it to be divisible by tp_size.
-TP = [2] if os.environ.get("
+TP = [2] if os.environ.get("TEST_LORA_TP", False) else [1]
 
 
 @pytest.mark.parametrize("tp", TP)
tpu_inference-0.12.0.dev20251219/tests/lora/test_lora_perf.py
ADDED
@@ -0,0 +1,53 @@
+import os
+import time
+
+import pytest
+import vllm
+from vllm.lora.request import LoRARequest
+
+TP = [2] if os.environ.get("USE_V6E8_QUEUE", False) else [1]
+
+
+@pytest.mark.parametrize("tp", TP)
+def test_lora_performance(tp):
+    prompt = "What is 1+1? \n"
+    llm_without_lora = vllm.LLM(
+        model="Qwen/Qwen2.5-3B-Instruct",
+        max_model_len=256,
+        max_num_batched_tokens=64,
+        max_num_seqs=8,
+        tensor_parallel_size=tp,
+    )
+    start_time = time.time()
+    llm_without_lora.generate(
+        prompt,
+        sampling_params=vllm.SamplingParams(max_tokens=16, temperature=0),
+    )[0].outputs[0].text
+    base_time = time.time() - start_time
+
+    del llm_without_lora
+    # Waiting for TPUs to be released
+    time.sleep(10)
+
+    llm_with_lora = vllm.LLM(model="Qwen/Qwen2.5-3B-Instruct",
+                             max_model_len=256,
+                             max_num_batched_tokens=64,
+                             max_num_seqs=8,
+                             tensor_parallel_size=tp,
+                             enable_lora=True,
+                             max_loras=1,
+                             max_lora_rank=8)
+    lora_request = LoRARequest(
+        "lora_adapter_2", 2,
+        "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_2_adapter")
+    start_time = time.time()
+    llm_with_lora.generate(prompt,
+                           sampling_params=vllm.SamplingParams(max_tokens=16,
+                                                               temperature=0),
+                           lora_request=lora_request)[0].outputs[0].text
+    lora_time = time.time() - start_time
+    print(f"Base time: {base_time}, LoRA time: {lora_time}")
+    assert (base_time /
+            lora_time) < 8, f"Base time: {base_time}, LoRA time: {lora_time}"
+
+    del llm_with_lora
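Note: the new perf test brackets each `generate` call with wall-clock timestamps and then compares the two durations. A tiny helper capturing that measurement pattern (our own wrapper, not part of the test file):

import time


def timed(fn, *args, **kwargs):
    """Run fn once and return (result, elapsed_seconds)."""
    start = time.time()
    result = fn(*args, **kwargs)
    return result, time.time() - start

# e.g. outputs, base_time = timed(llm.generate, prompt, sampling_params=params)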
{tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tests/test_envs.py
RENAMED
@@ -60,6 +60,7 @@ def test_boolean_env_vars(monkeypatch: pytest.MonkeyPatch):
     monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "0")
     monkeypatch.setenv("VLLM_XLA_CHECK_RECOMPILATION", "0")
     monkeypatch.setenv("NEW_MODEL_DESIGN", "0")
+    monkeypatch.setenv("ENABLE_QUANTIZED_MATMUL_KERNEL", "0")
     monkeypatch.setenv("USE_MOE_EP_KERNEL", "0")
 
     # Test SKIP_JAX_PRECOMPILE (default False)
@@ -86,6 +87,82 @@ def test_boolean_env_vars(monkeypatch: pytest.MonkeyPatch):
     monkeypatch.setenv("USE_MOE_EP_KERNEL", "1")
     assert envs.USE_MOE_EP_KERNEL is True
 
+    # Test ENABLE_QUANTIZED_MATMUL_KERNEL (default False)
+    assert envs.ENABLE_QUANTIZED_MATMUL_KERNEL is False
+    monkeypatch.setenv("ENABLE_QUANTIZED_MATMUL_KERNEL", "1")
+    assert envs.ENABLE_QUANTIZED_MATMUL_KERNEL is True
+
+
+def test_boolean_env_vars_string_values(monkeypatch: pytest.MonkeyPatch):
+    """Test that boolean env vars accept string values like 'True' and 'False'"""
+
+    # Test NEW_MODEL_DESIGN with string "True"
+    monkeypatch.setenv("NEW_MODEL_DESIGN", "True")
+    assert envs.NEW_MODEL_DESIGN is True
+
+    monkeypatch.setenv("NEW_MODEL_DESIGN", "true")
+    assert envs.NEW_MODEL_DESIGN is True
+
+    monkeypatch.setenv("NEW_MODEL_DESIGN", "False")
+    assert envs.NEW_MODEL_DESIGN is False
+
+    monkeypatch.setenv("NEW_MODEL_DESIGN", "false")
+    assert envs.NEW_MODEL_DESIGN is False
+
+    # Test SKIP_JAX_PRECOMPILE with string values
+    monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "True")
+    assert envs.SKIP_JAX_PRECOMPILE is True
+
+    monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "false")
+    assert envs.SKIP_JAX_PRECOMPILE is False
+
+    # Test VLLM_XLA_CHECK_RECOMPILATION with string values
+    monkeypatch.setenv("VLLM_XLA_CHECK_RECOMPILATION", "TRUE")
+    assert envs.VLLM_XLA_CHECK_RECOMPILATION is True
+
+    monkeypatch.setenv("VLLM_XLA_CHECK_RECOMPILATION", "FALSE")
+    assert envs.VLLM_XLA_CHECK_RECOMPILATION is False
+
+    # Test USE_MOE_EP_KERNEL with string values
+    monkeypatch.setenv("USE_MOE_EP_KERNEL", "true")
+    assert envs.USE_MOE_EP_KERNEL is True
+
+    monkeypatch.setenv("USE_MOE_EP_KERNEL", "False")
+    assert envs.USE_MOE_EP_KERNEL is False
+
+
+def test_boolean_env_vars_invalid_values(monkeypatch: pytest.MonkeyPatch):
+    """Test that boolean env vars raise errors for invalid values"""
+
+    # Test invalid value for NEW_MODEL_DESIGN
+    monkeypatch.setenv("NEW_MODEL_DESIGN", "yes")
+    with pytest.raises(
+            ValueError,
+            match="Invalid boolean value 'yes' for NEW_MODEL_DESIGN"):
+        _ = envs.NEW_MODEL_DESIGN
+
+    monkeypatch.setenv("NEW_MODEL_DESIGN", "2")
+    with pytest.raises(ValueError,
+                       match="Invalid boolean value '2' for NEW_MODEL_DESIGN"):
+        _ = envs.NEW_MODEL_DESIGN
+
+    # Test invalid value for SKIP_JAX_PRECOMPILE
+    monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "invalid")
+    with pytest.raises(
+            ValueError,
+            match="Invalid boolean value 'invalid' for SKIP_JAX_PRECOMPILE"):
+        _ = envs.SKIP_JAX_PRECOMPILE
+
+
+def test_boolean_env_vars_empty_string(monkeypatch: pytest.MonkeyPatch):
+    """Test that empty string returns default value"""
+
+    monkeypatch.setenv("NEW_MODEL_DESIGN", "")
+    assert envs.NEW_MODEL_DESIGN is False  # Should return default
+
+    monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "")
+    assert envs.SKIP_JAX_PRECOMPILE is False  # Should return default
+
 
 def test_integer_env_vars(monkeypatch: pytest.MonkeyPatch):
     # Ensure clean environment for integer vars by setting to defaults
@@ -179,7 +256,7 @@ def test_disaggregated_serving_env_vars(monkeypatch: pytest.MonkeyPatch):
 
 def test_model_impl_type_default(monkeypatch: pytest.MonkeyPatch):
     monkeypatch.delenv("MODEL_IMPL_TYPE", raising=False)
-    assert envs.MODEL_IMPL_TYPE == "
+    assert envs.MODEL_IMPL_TYPE == "auto"
 
 
 def test_cache_preserves_values_across_env_changes(
{tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/distributed/tpu_connector.py
RENAMED
@@ -694,9 +694,9 @@ class TPUConnectorWorker:
 
 def get_uuid() -> int:
     int128 = uuid4().int
-    # Must be 64-bit int, otherwise vllm output encoder would raise error.
-
-    return
+    # Must be less than 64-bit int, otherwise vllm output encoder would raise error.
+    # use 50 bit to avoid GO trunk the int when doing JSon serialization
+    return int128 >> 78
 
 
 @jax.jit
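Note: the rewritten `get_uuid` keeps only the top 50 bits of the random 128-bit UUID (128 - 78 = 50), which fits comfortably inside a signed 64-bit field and stays below the ~2^53 range of integers that survive JSON number round-trips, matching the comment about Go truncating the value during serialization. A quick sanity check of that arithmetic:

from uuid import uuid4

int128 = uuid4().int   # random 128-bit integer
small = int128 >> 78   # keep the top 128 - 78 = 50 bits
assert small < 2**50   # well inside 2**53, the largest integer an
                       # IEEE-754 double (JSON number) represents exactly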
{tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/envs.py
RENAMED
@@ -16,7 +16,7 @@ if TYPE_CHECKING:
     DECODE_SLICES: str = ""
     SKIP_JAX_PRECOMPILE: bool = False
     VLLM_XLA_CHECK_RECOMPILATION: bool = False
-    MODEL_IMPL_TYPE: str = "
+    MODEL_IMPL_TYPE: str = "auto"
     NEW_MODEL_DESIGN: bool = False
     PHASED_PROFILING_DIR: str = ""
     PYTHON_TRACER_LEVEL: int = 1
@@ -24,6 +24,7 @@ if TYPE_CHECKING:
     NUM_SLICES: int = 1
     RAY_USAGE_STATS_ENABLED: str = "0"
     VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: str = "shm"
+    ENABLE_QUANTIZED_MATMUL_KERNEL: bool = False
 
 
 def env_with_choices(
@@ -69,6 +70,34 @@ def env_with_choices(
     return _get_validated_env
 
 
+def env_bool(env_name: str, default: bool = False) -> Callable[[], bool]:
+    """
+    Accepts both numeric strings ("0", "1") and boolean strings
+    ("true", "false", "True", "False").
+
+    Args:
+        env_name: Name of the environment variable
+        default: Default boolean value if not set
+    """
+
+    def _get_bool_env() -> bool:
+        value = os.getenv(env_name)
+        if value is None or value == "":
+            return default
+
+        value_lower = value.lower()
+        if value_lower in ("true", "1"):
+            return True
+        elif value_lower in ("false", "0"):
+            return False
+        else:
+            raise ValueError(
+                f"Invalid boolean value '{value}' for {env_name}. "
+                f"Valid options: '0', '1', 'true', 'false', 'True', 'False'.")
+
+    return _get_bool_env
+
+
 environment_variables: dict[str, Callable[[], Any]] = {
     # JAX platform selection (e.g., "tpu", "cpu", "proxy")
     "JAX_PLATFORMS":
@@ -93,17 +122,17 @@ environment_variables: dict[str, Callable[[], Any]] = {
     lambda: os.getenv("DECODE_SLICES", ""),
     # Skip JAX precompilation step during initialization
     "SKIP_JAX_PRECOMPILE":
-
+    env_bool("SKIP_JAX_PRECOMPILE", default=False),
     # Check for XLA recompilation during execution
     "VLLM_XLA_CHECK_RECOMPILATION":
-
+    env_bool("VLLM_XLA_CHECK_RECOMPILATION", default=False),
     # Model implementation type (e.g., "flax_nnx")
     "MODEL_IMPL_TYPE":
-    env_with_choices("MODEL_IMPL_TYPE", "
-                     ["vllm", "flax_nnx", "jetpack"]),
+    env_with_choices("MODEL_IMPL_TYPE", "auto",
+                     ["auto", "vllm", "flax_nnx", "jetpack"]),
     # Enable new experimental model design
     "NEW_MODEL_DESIGN":
-
+    env_bool("NEW_MODEL_DESIGN", default=False),
     # Directory to store phased profiling output
     "PHASED_PROFILING_DIR":
     lambda: os.getenv("PHASED_PROFILING_DIR", ""),
@@ -112,7 +141,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
     lambda: int(os.getenv("PYTHON_TRACER_LEVEL") or "1"),
     # Use custom expert-parallel kernel for MoE (Mixture of Experts)
     "USE_MOE_EP_KERNEL":
-
+    env_bool("USE_MOE_EP_KERNEL", default=False),
     # Number of TPU slices for multi-slice mesh
     "NUM_SLICES":
     lambda: int(os.getenv("NUM_SLICES") or "1"),
@@ -122,6 +151,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
     # Ray compiled DAG channel type for TPU
     "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE":
     env_with_choices("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "shm", ["shm"]),
+    "ENABLE_QUANTIZED_MATMUL_KERNEL":
+    lambda: bool(int(os.getenv("ENABLE_QUANTIZED_MATMUL_KERNEL") or "0")),
 }
 
 
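Note: with `env_bool` wired into `environment_variables`, the boolean flags accept both numeric and word forms and reject anything else, which is exactly what the new tests in tests/test_envs.py assert. A usage sketch, assuming attribute-style access on the envs module as the tests use it:

import os

from tpu_inference import envs

os.environ["SKIP_JAX_PRECOMPILE"] = "true"   # "1", "True", "TRUE" also work
assert envs.SKIP_JAX_PRECOMPILE is True

os.environ["NEW_MODEL_DESIGN"] = "maybe"     # not a recognised boolean form
try:
    _ = envs.NEW_MODEL_DESIGN
except ValueError as err:
    print(err)  # Invalid boolean value 'maybe' for NEW_MODEL_DESIGN. ...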
{tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/executors/ray_distributed_executor.py
RENAMED
@@ -145,6 +145,9 @@ class RayDistributedExecutor(RayDistributedExecutorV1):
                 device_str: node['Resources'][device_str]
             } for node in ray_nodes]
         else:
+            assert pp_size == len(
+                ray_nodes
+            ), f"Cannot use PP across hosts, please set --pipeline-parallel-size to 1 or {len(ray_nodes)}"
             num_devices_per_pp_rank = self.vllm_config.sharding_config.total_devices
             placement_group_specs = [{
                 device_str: num_devices_per_pp_rank
{tpu_inference-0.12.0.dev20251207 → tpu_inference-0.12.0.dev20251219}/tpu_inference/kernels/collectives/all_gather_matmul.py
RENAMED
@@ -540,12 +540,16 @@ def get_vmem_estimate_bytes(
     """Returns the total vmem bytes used by the kernel."""
     m_per_device = m // tp_size
     n_per_device = n // tp_size
-    y_vmem_bytes = n_per_device * k * dtypes.bit_width(y_dtype)
+    y_vmem_bytes = (n_per_device * k * (dtypes.bit_width(y_dtype) if hasattr(
+        dtypes, "bit_width") else dtypes.itemsize_bits(y_dtype)) // 8)
     total_bytes = (
-        2 * m_per_device * k *
-
+        2 * m_per_device * k *
+        (dtypes.bit_width(x_dtype) if hasattr(dtypes, "bit_width") else
+         dtypes.itemsize_bits(x_dtype)) // 8  # x_vmem_scratch_ref
         + y_vmem_bytes  # y_vmem_scratch_ref
-        + 2 * m * bn *
+        + 2 * m * bn *
+        (dtypes.bit_width(out_dtype) if hasattr(dtypes, "bit_width") else
+         dtypes.itemsize_bits(out_dtype)) // 8  # o_vmem_scratch_ref
         + acc_bytes  # acc_vmem_scratch_ref, jnp.float32
     )
     return total_bytes
@@ -639,8 +643,10 @@ def all_gather_matmul(
     # NOTE(chengjiyao): acc buffer is not used in the grid_k == 1 case.
     if grid_k == 1:
         acc_shape = (8, 128)
-        acc_bytes =
-
+        acc_bytes = (
+            acc_shape[0] *
+            acc_shape[1] * (dtypes.bit_width(jnp.float32) if hasattr(
+                dtypes, "bit_width") else dtypes.itemsize_bits(jnp.float32)) // 8)
     y_vmem_shape = (n_per_device, k) if rhs_transpose else (k, n_per_device)
     estimated_vmem_bytes = get_vmem_estimate_bytes(
         m,