tpu-inference 0.11.1.dev202511270815__py3-none-any.whl → 0.13.0rc2.post7__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +14 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +312 -0
- tests/layers/vllm/test_unquantized.py +651 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +21 -3
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +110 -12
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +2 -45
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +15 -10
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +92 -8
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +22 -1
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +167 -97
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +26 -19
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +31 -9
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +210 -260
- tpu_inference/layers/vllm/linear_common.py +57 -22
- tpu_inference/layers/vllm/quantization/__init__.py +16 -0
- tpu_inference/layers/vllm/quantization/awq.py +15 -1
- tpu_inference/layers/vllm/quantization/common.py +33 -18
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +280 -210
- tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
- tpu_inference/layers/vllm/sharding.py +21 -4
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +77 -36
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +14 -0
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +18 -4
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +91 -31
- tpu_inference/models/jax/utils/weight_utils.py +39 -2
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +20 -4
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +47 -71
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +158 -63
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +53 -30
- tpu_inference/runner/lora_utils.py +14 -0
- tpu_inference/runner/multimodal_manager.py +15 -1
- tpu_inference/runner/persistent_batch_manager.py +54 -2
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +14 -0
- tpu_inference/runner/tpu_runner.py +105 -57
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +65 -19
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +72 -44
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +65 -52
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/METADATA +11 -9
- tpu_inference-0.13.0rc2.post7.dist-info/RECORD +261 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202511270815.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/top_level.txt +0 -0
tests/layers/vllm/test_compressed_tensors_w8a8_int8.py
@@ -0,0 +1,441 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tempfile
+from typing import Optional
+
+import jax
+import pytest
+import torch
+import torchax
+from jax.sharding import NamedSharding, PartitionSpec
+from torchax.interop import torch_view
+from torchax.ops.mappings import j2t, t2j
+from vllm.config import set_current_vllm_config
+from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
+                                             init_distributed_environment)
+from vllm.engine.arg_utils import EngineArgs
+from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                               LinearBase,
+                                               MergedColumnParallelLinear,
+                                               QKVParallelLinear,
+                                               RowParallelLinear)
+from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import \
+    CompressedTensorsLinearMethod
+from vllm.model_executor.model_loader import get_model as vllm_get_model
+
+from tpu_inference.layers.vllm.quantization import get_tpu_quantization_config
+from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors import \
+    VllmCompressedTensorsConfig
+from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_int8 import \
+    VllmCompressedTensorsW8A8Int8
+
+from . import utils as test_utils
+
+P = PartitionSpec
+MODELS = ["RedHatAI/Qwen2.5-1.5B-quantized.w8a8"]
+
+
+def ref_quantize_int8(x: torch.Tensor):
+    x_abs_max = torch.amax(torch.abs(x), dim=1, keepdim=True)
+    x_s = x_abs_max / 127
+    x_q = torch.round(x / x_s).to(torch.int8)
+    return x_q, x_s.to(torch.float32)
+
+
+def ref_w8a8_int8(x: torch.Tensor, w_q: torch.Tensor, w_s: torch.Tensor,
+                  b: Optional[torch.Tensor]):
+    x_q, x_s = ref_quantize_int8(x)
+    out = torch.einsum('bd,fd->bf', x_q.to(torch.float32),
+                       w_q.to(torch.float32))
+    out = (out * x_s) * w_s.T
+    if b is not None:
+        out += b
+    return out.to(x.dtype)
+
+
+@pytest.fixture(autouse=True)
+def setup_environment():
+    # This is a fake config used for init dist env.
+    # RowParallelLinear needs dist env to be initialized.
+    engine_args = EngineArgs(
+        model=MODELS[0],
+        max_model_len=64,
+        max_num_batched_tokens=64,
+        max_num_seqs=4,
+    )
+
+    vllm_config = engine_args.create_engine_config()
+
+    with set_current_vllm_config(vllm_config):
+        temp_file = tempfile.mkstemp()[1]
+        init_distributed_environment(
+            1,
+            0,
+            local_rank=0,
+            distributed_init_method=f"file://{temp_file}",
+            backend="gloo")
+        ensure_model_parallel_initialized(1, 1)
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("mesh", [
+    test_utils.get_spmd_mesh(1),
+    test_utils.get_spmd_mesh(jax.local_device_count())
+])
+def test_quant_override(model, mesh):
+
+    engine_args = EngineArgs(
+        model=model,
+        max_model_len=64,
+        max_num_batched_tokens=64,
+        max_num_seqs=4,
+    )
+    vllm_config = engine_args.create_engine_config()
+    vllm_config.model_config.dtype = torch.bfloat16
+
+    quant_config = get_tpu_quantization_config(vllm_config, mesh)
+    assert isinstance(quant_config, VllmCompressedTensorsConfig)
+    assert quant_config.vllm_config == vllm_config
+    assert quant_config.mesh == mesh
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("mesh", [
+    test_utils.get_spmd_mesh(1),
+    test_utils.get_spmd_mesh(jax.local_device_count())
+])
+def test_loading_model(model, mesh):
+    engine_args = EngineArgs(
+        model=model,
+        max_model_len=64,
+        max_num_batched_tokens=64,
+        max_num_seqs=4,
+    )
+    vllm_config = engine_args.create_engine_config()
+    vllm_config.model_config.dtype = torch.bfloat16
+    vllm_config.quant_config = get_tpu_quantization_config(vllm_config, mesh)
+    vllm_config.device_config.device = "cpu"
+
+    vllm_model = vllm_get_model(vllm_config=vllm_config)
+    layers = test_utils.find_all_layer_type(vllm_model, LinearBase)
+    for layer in layers:
+        assert isinstance(layer.quant_config, VllmCompressedTensorsConfig)
+        assert isinstance(layer.quant_method, CompressedTensorsLinearMethod)
+        assert isinstance(layer.scheme, VllmCompressedTensorsW8A8Int8)
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("bias", [False, True])
+@pytest.mark.parametrize("num_devices", [1, jax.local_device_count()])
+@pytest.mark.parametrize("enable_sp", [False, True])
+@pytest.mark.parametrize("enable_attn_dp", [False, True])
+def test_row_parallel_linear(model, bias, num_devices, enable_sp,
+                             enable_attn_dp):
+    # Skip if enable_attn_dp is True but we don't have enough devices
+    if enable_attn_dp and num_devices < 2:
+        pytest.skip("enable_attn_dp requires at least 2 devices")
+
+    mesh = test_utils.get_spmd_mesh(num_devices, enable_attn_dp)
+
+    dtype = torch.bfloat16
+
+    engine_args = EngineArgs(
+        model=model,
+        max_model_len=64,
+        max_num_batched_tokens=64,
+        max_num_seqs=4,
+    )
+    vllm_config = engine_args.create_engine_config()
+    vllm_config.compilation_config.pass_config.enable_sp = enable_sp
+
+    # Call tpu_inference code
+    vllm_config.model_config.dtype = dtype
+    quant_config = get_tpu_quantization_config(vllm_config, mesh)
+    with set_current_vllm_config(vllm_config):
+        jax_row_linear = RowParallelLinear(
+            input_size=4096,
+            output_size=8192,
+            bias=bias,
+            params_dtype=dtype,
+            return_bias=False,
+            quant_config=quant_config,
+        )
+
+    weight_data_float = torch.rand(
+        (jax_row_linear.output_size, jax_row_linear.input_size),
+        dtype=dtype) / 10
+    weight_data, weight_scale_data = ref_quantize_int8(weight_data_float)
+    if bias:
+        bias_data = torch.rand_like(jax_row_linear.bias.data)
+
+    jax_row_linear.weight.data = weight_data
+    jax_row_linear.weight_scale.data = weight_scale_data
+    if bias:
+        jax_row_linear.bias.data = bias_data
+
+    input_tensor = torch.rand(10, 4096, dtype=dtype) / 10
+    input_tensor = input_tensor.to('cpu')
+
+    jax_input_tensor = torch_view(t2j(input_tensor, use_dlpack=False))
+    jax_input_tensor.apply_jax_(jax.device_put,
+                                NamedSharding(mesh, P(None, None)))
+    with torchax.default_env():
+        assert isinstance(jax_row_linear.quant_method,
+                          CompressedTensorsLinearMethod)
+        assert isinstance(jax_row_linear.scheme, VllmCompressedTensorsW8A8Int8)
+        jax_row_linear.quant_method.process_weights_after_loading(
+            jax_row_linear)
+        jax_output = jax_row_linear(jax_input_tensor)
+        jax_output = j2t(jax_output.to(torch.float32)).to(dtype)
+
+    # Call reference w8a8 int8
+    output = ref_w8a8_int8(
+        input_tensor,
+        weight_data,
+        weight_scale_data,
+        bias_data if bias else None,
+    )
+
+    torch.testing.assert_close(output, jax_output)
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("bias", [False, True])
+@pytest.mark.parametrize("num_devices", [1, jax.local_device_count()])
+@pytest.mark.parametrize("enable_sp", [False, True])
+@pytest.mark.parametrize("enable_attn_dp", [False, True])
+def test_column_parallel_linear(model, bias, num_devices, enable_sp,
+                                enable_attn_dp):
+    # Skip if enable_attn_dp is True but we don't have enough devices
+    if enable_attn_dp and num_devices < 2:
+        pytest.skip("enable_attn_dp requires at least 2 devices")
+
+    mesh = test_utils.get_spmd_mesh(num_devices, enable_attn_dp)
+    dtype = torch.bfloat16
+
+    engine_args = EngineArgs(
+        model=model,
+        max_model_len=64,
+        max_num_batched_tokens=64,
+        max_num_seqs=4,
+    )
+    vllm_config = engine_args.create_engine_config()
+    vllm_config.compilation_config.pass_config.enable_sp = enable_sp
+
+    # Call tpu_inference code
+    vllm_config.model_config.dtype = torch.bfloat16
+    quant_config = get_tpu_quantization_config(vllm_config, mesh)
+    with set_current_vllm_config(vllm_config):
+        jax_column_linear = ColumnParallelLinear(
+            input_size=4096,
+            output_size=8192,
+            bias=bias,
+            params_dtype=dtype,
+            return_bias=False,
+            quant_config=quant_config,
+        )
+
+    weight_data_float = torch.rand(
+        (jax_column_linear.output_size, jax_column_linear.input_size),
+        dtype=dtype) / 10
+    weight_data, weight_scale_data = ref_quantize_int8(weight_data_float)
+    if bias:
+        bias_data = torch.rand_like(jax_column_linear.bias.data)
+
+    jax_column_linear.weight.data = weight_data
+    jax_column_linear.weight_scale.data = weight_scale_data
+    if bias:
+        jax_column_linear.bias.data = bias_data
+
+    input_tensor = torch.rand(10, 4096, dtype=dtype) / 10
+    input_tensor = input_tensor.to('cpu')
+
+    jax_input_tensor = torch_view(t2j(input_tensor, use_dlpack=False))
+    jax_input_tensor.apply_jax_(jax.device_put,
+                                NamedSharding(mesh, P(None, None)))
+    with torchax.default_env():
+        assert isinstance(jax_column_linear.quant_method,
+                          CompressedTensorsLinearMethod)
+        assert isinstance(jax_column_linear.scheme,
+                          VllmCompressedTensorsW8A8Int8)
+        jax_column_linear.quant_method.process_weights_after_loading(
+            jax_column_linear)
+        jax_output = jax_column_linear(jax_input_tensor)
+        jax_output = j2t(jax_output.to(torch.float32)).to(dtype)
+
+    # Call reference w8a8 int8
+    output = ref_w8a8_int8(
+        input_tensor,
+        weight_data,
+        weight_scale_data,
+        bias_data if bias else None,
+    )
+
+    torch.testing.assert_close(output, jax_output)
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("bias", [False, True])
+@pytest.mark.parametrize("num_devices", [1, jax.local_device_count()])
+@pytest.mark.parametrize("enable_sp", [False, True])
+@pytest.mark.parametrize("fuse_matmuls", [False, True])
+@pytest.mark.parametrize("enable_attn_dp", [False, True])
+def test_qkv_parallel_linear(model, bias, num_devices, enable_sp, fuse_matmuls,
+                             enable_attn_dp):
+    # Skip if enable_attn_dp is True but we don't have enough devices
+    if enable_attn_dp and num_devices < 2:
+        pytest.skip("enable_attn_dp requires at least 2 devices")
+
+    mesh = test_utils.get_spmd_mesh(num_devices, enable_attn_dp)
+    dtype = torch.bfloat16
+
+    engine_args = EngineArgs(
+        model=model,
+        max_model_len=64,
+        max_num_batched_tokens=64,
+        max_num_seqs=4,
+    )
+    vllm_config = engine_args.create_engine_config()
+    vllm_config.compilation_config.pass_config.enable_sp = enable_sp
+
+    # Call tpu_inference code
+    vllm_config.model_config.dtype = torch.bfloat16
+    quant_config = get_tpu_quantization_config(vllm_config, mesh)
+    with set_current_vllm_config(vllm_config):
+        jax_qkv_linear = QKVParallelLinear(
+            hidden_size=4096,
+            head_size=128,
+            total_num_heads=32,
+            total_num_kv_heads=8,
+            bias=bias,
+            params_dtype=dtype,
+            return_bias=False,
+            quant_config=quant_config,
+        )
+    jax_qkv_linear.quant_method.fuse_matmuls = fuse_matmuls
+
+    weight_data_float = torch.rand(
+        (jax_qkv_linear.output_size, jax_qkv_linear.input_size),
+        dtype=dtype) / 10
+    weight_data, weight_scale_data = ref_quantize_int8(weight_data_float)
+    if bias:
+        bias_data = torch.rand_like(jax_qkv_linear.bias.data)
+
+    jax_qkv_linear.weight.data = weight_data
+    jax_qkv_linear.weight_scale.data = weight_scale_data
+    if bias:
+        jax_qkv_linear.bias.data = bias_data
+
+    input_tensor = torch.rand(10, 4096, dtype=dtype) / 10
+    input_tensor = input_tensor.to('cpu')
+
+    jax_input_tensor = torch_view(t2j(input_tensor, use_dlpack=False))
+    jax_input_tensor.apply_jax_(jax.device_put,
+                                NamedSharding(mesh, P(None, None)))
+    with torchax.default_env():
+        assert isinstance(jax_qkv_linear.quant_method,
+                          CompressedTensorsLinearMethod)
+        assert isinstance(jax_qkv_linear.scheme, VllmCompressedTensorsW8A8Int8)
+        jax_qkv_linear.quant_method.process_weights_after_loading(
+            jax_qkv_linear)
+        jax_output = jax_qkv_linear(jax_input_tensor)
+        jax_output = j2t(jax_output.to(torch.float32)).to(dtype)
+
+    # Call reference w8a8 int8
+    output = ref_w8a8_int8(
+        input_tensor,
+        weight_data,
+        weight_scale_data,
+        bias_data if bias else None,
+    )
+
+    torch.testing.assert_close(output, jax_output)
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("bias", [False, True])
+@pytest.mark.parametrize("num_devices", [1, jax.local_device_count()])
+@pytest.mark.parametrize("fuse_matmuls", [False, True])
+@pytest.mark.parametrize("enable_sp", [False, True])
+@pytest.mark.parametrize("enable_attn_dp", [False, True])
+def test_merged_column_parallel_linear(model, bias, num_devices, fuse_matmuls,
+                                       enable_sp, enable_attn_dp):
+    # Skip if enable_attn_dp is True but we don't have enough devices
+    if enable_attn_dp and num_devices < 2:
+        pytest.skip("enable_attn_dp requires at least 2 devices")
+
+    mesh = test_utils.get_spmd_mesh(num_devices, enable_attn_dp)
+    dtype = torch.bfloat16
+
+    engine_args = EngineArgs(
+        model=model,
+        max_model_len=64,
+        max_num_batched_tokens=64,
+        max_num_seqs=4,
+    )
+    vllm_config = engine_args.create_engine_config()
+    vllm_config.compilation_config.pass_config.enable_sp = enable_sp
+
+    # Call tpu_inference code
+    vllm_config.model_config.dtype = torch.bfloat16
+    quant_config = get_tpu_quantization_config(vllm_config, mesh)
+    with set_current_vllm_config(vllm_config):
+        jax_merged_column_linear = MergedColumnParallelLinear(
+            input_size=4096,
+            output_sizes=[14336] * 2,
+            bias=bias,
+            params_dtype=dtype,
+            return_bias=False,
+            quant_config=quant_config,
+        )
+    jax_merged_column_linear.quant_method.fuse_matmuls = fuse_matmuls
+
+    weight_data_float = torch.rand((jax_merged_column_linear.output_size,
+                                    jax_merged_column_linear.input_size),
+                                   dtype=dtype) / 10
+    weight_data, weight_scale_data = ref_quantize_int8(weight_data_float)
+    if bias:
+        bias_data = torch.rand_like(jax_merged_column_linear.bias.data)
+
+    jax_merged_column_linear.weight.data = weight_data
+    jax_merged_column_linear.weight_scale.data = weight_scale_data
+    if bias:
+        jax_merged_column_linear.bias.data = bias_data
+
+    input_tensor = torch.rand(10, 4096, dtype=dtype) / 10
+    input_tensor = input_tensor.to('cpu')
+
+    jax_input_tensor = torch_view(t2j(input_tensor, use_dlpack=False))
+    jax_input_tensor.apply_jax_(jax.device_put,
+                                NamedSharding(mesh, P(None, None)))
+    with torchax.default_env():
+        assert isinstance(jax_merged_column_linear.quant_method,
+                          CompressedTensorsLinearMethod)
+        assert isinstance(jax_merged_column_linear.scheme,
+                          VllmCompressedTensorsW8A8Int8)
+        jax_merged_column_linear.quant_method.process_weights_after_loading(
+            jax_merged_column_linear)
+        jax_output = jax_merged_column_linear(jax_input_tensor)
+        jax_output = j2t(jax_output.to(torch.float32)).to(dtype)
+
+    # Call reference w8a8 int8
+    output = ref_w8a8_int8(
+        input_tensor,
+        weight_data,
+        weight_scale_data,
+        bias_data if bias else None,
+    )
+
+    torch.testing.assert_close(output, jax_output)
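As a reading aid for the test file above: the reference path the tests compare against is plain symmetric INT8 quantization with a per-row abs-max/127 scale for both weights and activations, followed by a matmul and a rescale by both scales (ref_quantize_int8 / ref_w8a8_int8 in the diff). Below is a minimal standalone sketch of that arithmetic using torch only; the function names and tensor sizes are illustrative, not part of the package.

import torch

def quantize_int8(x: torch.Tensor):
    # One scale per row (per output channel for weights): symmetric abs-max / 127.
    s = torch.amax(torch.abs(x), dim=1, keepdim=True) / 127
    q = torch.round(x / s).to(torch.int8)
    return q, s.to(torch.float32)

def w8a8_int8(x: torch.Tensor, w_q: torch.Tensor, w_s: torch.Tensor, b=None):
    # Quantize activations per row on the fly, multiply the integer values
    # (emulated in float32 here), then undo both the activation and weight scales.
    x_q, x_s = quantize_int8(x)
    out = x_q.to(torch.float32) @ w_q.to(torch.float32).T
    out = out * x_s * w_s.T
    if b is not None:
        out = out + b
    return out.to(x.dtype)

# Quick self-check against the unquantized matmul (illustrative sizes, loose tolerance).
x = torch.rand(4, 64) / 10
w = torch.rand(128, 64) / 10
w_q, w_s = quantize_int8(w)
torch.testing.assert_close(w8a8_int8(x, w_q, w_s), x @ w.T, atol=1e-2, rtol=1e-2)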
tests/layers/vllm/test_fp8.py
@@ -0,0 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+pytest.skip("FP8 implementation not complete yet", allow_module_level=True)
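The new vLLM-layer tests in this release all share the same torch/JAX plumbing: a CPU torch tensor is converted to a jax.Array with torchax, wrapped back into a torch view, placed on a mesh with a NamedSharding, and results are pulled back with j2t. The following is a rough standalone sketch of that round trip, assuming only the torchax helpers the tests already import; the mesh construction is an assumption standing in for the tests' get_spmd_mesh helper, which is not part of this diff.

import jax
import numpy as np
import torch
import torchax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from torchax.interop import torch_view
from torchax.ops.mappings import j2t, t2j

# Hypothetical 1-D mesh over all local devices; the real tests build theirs
# via test_utils.get_spmd_mesh, which may use different axis names.
mesh = Mesh(np.array(jax.devices()), axis_names=("model",))

x = torch.rand(10, 4096, dtype=torch.bfloat16) / 10  # plain CPU torch tensor
jax_x = torch_view(t2j(x, use_dlpack=False))          # torch view over a jax.Array
jax_x.apply_jax_(jax.device_put,
                 NamedSharding(mesh, P(None, None)))  # replicate across the mesh

with torchax.default_env():
    y = jax_x * 2.0                                   # torch ops now execute via JAX
    y_torch = j2t(y.to(torch.float32))                # back to a regular torch tensor

The tests drive real vLLM linear layers through this path and then compare the pulled-back result against the pure-torch reference computation shown after the first hunk.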