tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl
This diff shows the changes between publicly available package versions as they appear in their respective public registries, and is provided for informational purposes only.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +317 -34
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +406 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +320 -0
- tests/layers/vllm/test_unquantized.py +662 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +26 -6
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +110 -12
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +2 -45
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +15 -10
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +92 -8
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +25 -4
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +807 -230
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +218 -137
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +25 -12
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/fused_moe_gmm.py +506 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +282 -0
- tpu_inference/layers/common/sharding.py +32 -9
- tpu_inference/layers/common/utils.py +94 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +101 -494
- tpu_inference/layers/vllm/linear.py +64 -0
- tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
- tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
- tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
- tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
- tpu_inference/layers/vllm/quantization/__init__.py +19 -3
- tpu_inference/layers/vllm/quantization/awq.py +96 -82
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +23 -8
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +172 -176
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
- tpu_inference/layers/vllm/quantization/{common.py → configs.py} +42 -25
- tpu_inference/layers/vllm/quantization/fp8.py +119 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +137 -178
- tpu_inference/layers/vllm/quantization/unquantized.py +157 -233
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +112 -35
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +18 -5
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +179 -51
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +92 -32
- tpu_inference/models/jax/utils/weight_utils.py +234 -155
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +32 -8
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +51 -72
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +180 -80
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +55 -33
- tpu_inference/runner/lora_utils.py +16 -1
- tpu_inference/runner/multimodal_manager.py +16 -2
- tpu_inference/runner/persistent_batch_manager.py +54 -2
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +16 -3
- tpu_inference/runner/tpu_runner.py +124 -61
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +84 -22
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +72 -44
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +66 -52
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +8 -9
- tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
- tpu_inference/layers/vllm/linear_common.py +0 -186
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202511220812.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
tests/layers/vllm/test_awq.py
@@ -0,0 +1,406 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tempfile
+from typing import Optional
+
+import jax
+import pytest
+import torch
+import torchax
+from jax.sharding import PartitionSpec
+from torchax.interop import torch_view
+from torchax.ops.mappings import j2t, t2j
+from vllm.config import set_current_vllm_config
+from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
+                                             init_distributed_environment)
+from vllm.engine.arg_utils import EngineArgs
+from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                               LinearBase,
+                                               MergedColumnParallelLinear,
+                                               QKVParallelLinear,
+                                               RowParallelLinear)
+from vllm.model_executor.layers.quantization.utils.quant_utils import \
+    pack_quantized_values_into_int32
+from vllm.model_executor.model_loader import get_model as vllm_get_model
+from vllm.scalar_type import scalar_types
+
+from tpu_inference.layers.vllm.quantization import get_tpu_quantization_config
+from tpu_inference.layers.vllm.quantization.awq import (VllmAWQConfig,
+                                                        VllmAWQLinearMethod)
+from tpu_inference.layers.vllm.quantization.configs import \
+    VllmQuantLinearConfig
+
+from . import utils as test_utils
+
+P = PartitionSpec
+MODELS = ["Qwen/Qwen2.5-1.5B-Instruct-AWQ"]
+
+
+def ref_quantize_uint4(x: torch.Tensor, group_size: int):
+    uint4_max = 15
+
+    # For group quantization, we reshape so that x[0], x[1], ... x[i] are
+    # quantized with different scale values.
+    x = torch.reshape(x, (-1, group_size) + (x.shape[1:]))
+
+    # Equation for asymmetric quantization is x_q = (x + x_z) / scale where
+    # x_z is calculated to ensure x + x_z does not contain any negative values.
+    offset = torch.clamp(-torch.amin(x, dim=1, keepdim=True), min=0)
+    x += offset
+    # After adding offset, x will not contain any negative values.
+    assert x.min() >= 0
+
+    x_abs_max = torch.amax(x, dim=1, keepdim=True)
+    x_s = x_abs_max / uint4_max
+    # torch does not support uint4, therefore, we cast to int32 instead.
+    x_q = torch.clip(x / x_s, 0, uint4_max).to(torch.int32)
+    x_z = torch.clip(offset / x_s, 0, uint4_max).to(torch.int32)
+    return x_q, x_z, x_s.to(torch.float32)
+
+
+def ref_w4a16(x: torch.Tensor, w_q: torch.Tensor, w_z: torch.Tensor,
+              w_s: torch.Tensor, b: Optional[torch.Tensor]):
+    # Dequantize asymetric quantized weight.
+    w = (w_q.to(torch.float32) - w_z.to(torch.float32)) * w_s
+    w = w.reshape((-1, w.shape[-1]))
+    out = torch.einsum('bd,df->bf', x.to(torch.float32), w)
+    if b is not None:
+        out += b
+    return out.to(x.dtype)
+
+
+def pack_awq_weight_into_int32(weight: torch.Tensor):
+    # AWQ packs 8 uint4 into 32-bits in this order.
+    awq_order = (0, 2, 4, 6, 1, 3, 5, 7)
+
+    orig_shape = weight.shape
+    weight = weight.reshape(orig_shape[:-1] + (-1, 8))
+    weight = weight[..., awq_order].reshape(orig_shape)
+
+    return pack_quantized_values_into_int32(weight, scalar_types.uint4, 1)
+
+
+def return_ref_and_layer_output(
+    layer: torch.nn.Module,
+    qweight: torch.Tensor,
+    qzeros: torch.Tensor,
+    scales: torch.Tensor,
+    batch_size: int = 16,
+):
+    assert isinstance(layer, LinearBase)
+    quant_method = layer.quant_method
+    assert isinstance(quant_method, VllmAWQLinearMethod)
+    quant_config = quant_method.quant_config
+    assert isinstance(quant_config, VllmAWQConfig)
+    jax_config = quant_method.linear_config
+    assert isinstance(jax_config, VllmQuantLinearConfig)
+
+    input_tensor = torch.rand(
+        batch_size, layer.input_size, dtype=torch.bfloat16) / 10
+    input_tensor = input_tensor.to('cpu')
+
+    ref_output = ref_w4a16(
+        input_tensor,
+        qweight,
+        qzeros,
+        scales,
+        layer.bias,
+    )
+
+    # Run torchax/jax function
+    quant_method.process_weights_after_loading(layer)
+    with torchax.default_env():
+        jax_input_tensor = torch_view(t2j(input_tensor, use_dlpack=False))
+        layer_output = layer(jax_input_tensor)
+        layer_output = j2t(layer_output.to(torch.float32)).to(torch.bfloat16)
+
+    return ref_output, layer_output
+
+
+def initialize_and_return_layer_weights(layer: torch.nn.Module):
+    assert isinstance(layer, LinearBase)
+    quant_method = layer.quant_method
+    assert isinstance(quant_method, VllmAWQLinearMethod)
+    quant_config = quant_method.quant_config
+    assert isinstance(quant_config, VllmAWQConfig)
+    jax_config = quant_method.linear_config
+    assert isinstance(jax_config, VllmQuantLinearConfig)
+
+    # torch.rand returns value in the range of [0, 1). We subtract by 0.2 to
+    # simulate asymmetry
+    weight = torch.rand((layer.input_size, layer.output_size)) - 0.2
+    qweight, qzeros, scales = ref_quantize_uint4(weight,
+                                                 quant_config.group_size)
+
+    # We modify uint4 quantized weights into AWQ format.
+    layer_qweight = qweight.reshape((-1, layer.output_size))
+    layer_qzeros = qzeros.reshape((-1, layer.output_size))
+    layer_scales = scales.reshape((-1, layer.output_size))
+
+    layer_qweight = pack_awq_weight_into_int32(layer_qweight)
+    layer_qzeros = pack_awq_weight_into_int32(layer_qzeros)
+
+    assert layer.qweight.data.shape == layer_qweight.shape
+    assert layer.qzeros.data.shape == layer_qzeros.shape
+    assert layer.scales.data.shape == layer_scales.shape
+
+    layer.qweight.data = layer_qweight
+    layer.qzeros.data = layer_qzeros
+    layer.scales.data = layer_scales
+
+    bias = None
+    if layer.bias is not None:
+        bias = torch.rand_like(layer.bias.data)
+        layer.bias.data = bias
+
+    return qweight, qzeros, scales, bias
+
+
+@pytest.fixture(autouse=True)
+def setup_environment():
+    # This is a fake config used for init dist env.
+    # RowParallelLinear needs dist env to be initialized.
+    engine_args = EngineArgs(
+        model=MODELS[0],
+        max_model_len=64,
+        max_num_batched_tokens=64,
+        max_num_seqs=4,
+    )
+
+    vllm_config = engine_args.create_engine_config()
+
+    with set_current_vllm_config(vllm_config):
+        temp_file = tempfile.mkstemp()[1]
+        init_distributed_environment(
+            1,
+            0,
+            local_rank=0,
+            distributed_init_method=f"file://{temp_file}",
+            backend="gloo")
+        ensure_model_parallel_initialized(1, 1)
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("mesh", [
+    test_utils.get_spmd_mesh(1),
+    test_utils.get_spmd_mesh(jax.local_device_count())
+])
+def test_quant_override(model, mesh):
+
+    engine_args = EngineArgs(
+        model=model,
+        max_model_len=64,
+        max_num_batched_tokens=64,
+        max_num_seqs=4,
+    )
+    vllm_config = engine_args.create_engine_config()
+    vllm_config.model_config.dtype = torch.bfloat16
+
+    quant_config = get_tpu_quantization_config(vllm_config, mesh)
+    assert isinstance(quant_config, VllmAWQConfig)
+    assert quant_config.vllm_config == vllm_config
+    assert quant_config.mesh == mesh
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize(
+    "mesh",
+    [
+        test_utils.get_spmd_mesh(1),
+        # We limit device count by 2 instead of using all devices (like 8) since
+        # AWQ requires n_groups to be divisible by number of shards. Qwen uses
+        # group size of 128 and one of the layer has input size of 1536, meaning
+        # n_groups = 1536//128 = 12 - which is not divisible by 8.
+        test_utils.get_spmd_mesh(min(jax.local_device_count(), 2))
+    ])
+def test_loading_model(model, mesh):
+    engine_args = EngineArgs(
+        model=model,
+        max_model_len=64,
+        max_num_batched_tokens=64,
+        max_num_seqs=4,
+    )
+    vllm_config = engine_args.create_engine_config()
+    vllm_config.model_config.dtype = torch.bfloat16
+    vllm_config.quant_config = get_tpu_quantization_config(vllm_config, mesh)
+    vllm_config.device_config.device = "cpu"
+
+    vllm_model = vllm_get_model(vllm_config=vllm_config)
+    layers = test_utils.find_all_layer_type(vllm_model, LinearBase)
+    for layer in layers:
+        assert isinstance(layer.quant_config, VllmAWQConfig)
+        assert isinstance(layer.quant_method, VllmAWQLinearMethod)
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("bias", [False, True])
+@pytest.mark.parametrize("mesh", [
+    test_utils.get_spmd_mesh(1),
+    test_utils.get_spmd_mesh(jax.local_device_count())
+])
+@pytest.mark.parametrize("enable_sp", [False, True])
+def test_row_parallel_linear(model, bias, mesh, enable_sp):
+    dtype = torch.bfloat16
+
+    engine_args = EngineArgs(
+        model=model,
+        max_model_len=64,
+        max_num_batched_tokens=64,
+        max_num_seqs=4,
+    )
+    vllm_config = engine_args.create_engine_config()
+    vllm_config.compilation_config.pass_config.enable_sp = enable_sp
+
+    vllm_config.model_config.dtype = dtype
+    quant_config = get_tpu_quantization_config(vllm_config, mesh)
+    with set_current_vllm_config(vllm_config):
+        linear_layer = RowParallelLinear(
+            input_size=4096,
+            output_size=8192,
+            bias=bias,
+            params_dtype=dtype,
+            return_bias=False,
+            quant_config=quant_config,
+        )
+
+    qweight, qzeros, scales, _ = initialize_and_return_layer_weights(
+        linear_layer)
+    ref_output, layer_output = return_ref_and_layer_output(
+        linear_layer, qweight, qzeros, scales)
+    torch.testing.assert_close(ref_output, layer_output)
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("bias", [False, True])
+@pytest.mark.parametrize("mesh", [
+    test_utils.get_spmd_mesh(1),
+    test_utils.get_spmd_mesh(jax.local_device_count())
+])
+@pytest.mark.parametrize("enable_sp", [False, True])
+def test_column_parallel_linear(model, bias, mesh, enable_sp):
+    dtype = torch.bfloat16
+
+    engine_args = EngineArgs(
+        model=model,
+        max_model_len=64,
+        max_num_batched_tokens=64,
+        max_num_seqs=4,
+    )
+    vllm_config = engine_args.create_engine_config()
+    vllm_config.compilation_config.pass_config.enable_sp = enable_sp
+
+    # Call tpu_inference code
+    vllm_config.model_config.dtype = torch.bfloat16
+    quant_config = get_tpu_quantization_config(vllm_config, mesh)
+    with set_current_vllm_config(vllm_config):
+        linear_layer = ColumnParallelLinear(
+            input_size=4096,
+            output_size=8192,
+            bias=bias,
+            params_dtype=dtype,
+            return_bias=False,
+            quant_config=quant_config,
+        )
+
+    qweight, qzeros, scales, _ = initialize_and_return_layer_weights(
+        linear_layer)
+    ref_output, layer_output = return_ref_and_layer_output(
+        linear_layer, qweight, qzeros, scales)
+    torch.testing.assert_close(ref_output, layer_output)
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("bias", [False, True])
+@pytest.mark.parametrize("mesh", [
+    test_utils.get_spmd_mesh(1),
+    test_utils.get_spmd_mesh(jax.local_device_count())
+])
+@pytest.mark.parametrize("enable_sp", [False, True])
+@pytest.mark.parametrize("fuse_matmuls", [False, True])
+def test_qkv_parallel_linear(model, bias, mesh, enable_sp, fuse_matmuls):
+    dtype = torch.bfloat16
+
+    engine_args = EngineArgs(
+        model=model,
+        max_model_len=64,
+        max_num_batched_tokens=64,
+        max_num_seqs=4,
+    )
+    vllm_config = engine_args.create_engine_config()
+    vllm_config.compilation_config.pass_config.enable_sp = enable_sp
+
+    # Call tpu_inference code
+    vllm_config.model_config.dtype = torch.bfloat16
+    quant_config = get_tpu_quantization_config(vllm_config, mesh)
+    with set_current_vllm_config(vllm_config):
+        linear_layer = QKVParallelLinear(
+            hidden_size=4096,
+            head_size=128,
+            total_num_heads=32,
+            total_num_kv_heads=8,
+            bias=bias,
+            params_dtype=dtype,
+            return_bias=False,
+            quant_config=quant_config,
+        )
+    linear_layer.quant_method.fuse_matmuls = fuse_matmuls
+
+    qweight, qzeros, scales, _ = initialize_and_return_layer_weights(
+        linear_layer)
+    ref_output, layer_output = return_ref_and_layer_output(
+        linear_layer, qweight, qzeros, scales)
+    torch.testing.assert_close(ref_output, layer_output)
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("bias", [False, True])
+@pytest.mark.parametrize("mesh", [
+    test_utils.get_spmd_mesh(1),
+    test_utils.get_spmd_mesh(jax.local_device_count())
+])
+@pytest.mark.parametrize("fuse_matmuls", [False, True])
+@pytest.mark.parametrize("enable_sp", [False, True])
+def test_merged_column_parallel_linear(model, bias, mesh, fuse_matmuls,
+                                       enable_sp):
+    dtype = torch.bfloat16
+
+    engine_args = EngineArgs(
+        model=model,
+        max_model_len=64,
+        max_num_batched_tokens=64,
+        max_num_seqs=4,
+    )
+    vllm_config = engine_args.create_engine_config()
+    vllm_config.compilation_config.pass_config.enable_sp = enable_sp
+
+    # Call tpu_inference code
+    vllm_config.model_config.dtype = torch.bfloat16
+    quant_config = get_tpu_quantization_config(vllm_config, mesh)
+    with set_current_vllm_config(vllm_config):
+        linear_layer = MergedColumnParallelLinear(
+            input_size=4096,
+            output_sizes=[14336] * 2,
+            bias=bias,
+            params_dtype=dtype,
+            return_bias=False,
+            quant_config=quant_config,
+        )
+    linear_layer.quant_method.fuse_matmuls = fuse_matmuls
+
+    qweight, qzeros, scales, _ = initialize_and_return_layer_weights(
+        linear_layer)
+    ref_output, layer_output = return_ref_and_layer_output(
+        linear_layer, qweight, qzeros, scales)
+    torch.testing.assert_close(ref_output, layer_output)
tests/layers/vllm/test_compressed_tensors_moe.py
@@ -0,0 +1,199 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tempfile
+
+import jax.numpy as jnp
+import pytest
+import torch
+import torch.nn.functional as F
+import torchax
+from compressed_tensors.quantization import QuantizationArgs
+from jax.sharding import PartitionSpec
+from vllm.config import set_current_vllm_config
+from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
+                                             init_distributed_environment)
+from vllm.engine.arg_utils import EngineArgs
+from vllm.model_executor.layers.fused_moe import FusedMoE
+# yapf: disable
+from vllm.model_executor.layers.fused_moe.config import (
+    FusedMoEConfig, FusedMoEParallelConfig)
+
+from tpu_inference.layers.vllm.quantization import get_tpu_quantization_config
+from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors import \
+    VllmCompressedTensorsConfig
+from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors_moe import \
+    VllmCompressedTensorsW8A8Fp8MoEMethod
+
+from . import utils as test_utils
+
+# yapf: enable
+
+P = PartitionSpec
+
+MODEL = 'BCCard/Qwen3-30B-A3B-FP8-Dynamic'
+
+
+@pytest.fixture(autouse=True)
+def setup_environment():
+    # This is a fake config used for init dist env.
+    # RowParallelLinear needs dist env to be initialized.
+    engine_args = EngineArgs(
+        model=MODEL,
+        max_model_len=64,
+        max_num_batched_tokens=64,
+        max_num_seqs=4,
+    )
+
+    vllm_config = engine_args.create_engine_config()
+
+    with set_current_vllm_config(vllm_config):
+        temp_file = tempfile.mkstemp()[1]
+        init_distributed_environment(
+            1,
+            0,
+            local_rank=0,
+            distributed_init_method=f"file://{temp_file}",
+            backend="gloo")
+        ensure_model_parallel_initialized(1, 1)
+
+
+def _ref_math_in_bf16(w1, w2, w3, x, router_logits, top_k):
+    seqlen = x.shape[0]
+    expert_weights = F.softmax(router_logits, dim=-1)
+    expert_weights, expert_indices = torch.topk(expert_weights, top_k, dim=-1)
+    expert_weights /= expert_weights.sum(dim=-1, keepdim=True)
+
+    # cond ffn
+    # e = total num of exp = 160
+    # t = seqlen
+    # o = config.imtermediate size
+    # i = config.dim
+    x1 = torch.einsum("ti, eoi -> teo", x, w1)
+    x1 = F.silu(x1)
+    x3 = torch.einsum("ti, eoi -> teo", x, w3)
+    expert_outs = torch.einsum("teo, eio -> tei", (x1 * x3), w2)
+
+    seq_indexes = torch.arange(seqlen, device='jax').unsqueeze(1)
+    expert_outs = expert_outs[seq_indexes, expert_indices]
+    out = torch.einsum("tai,ta -> ti", expert_outs, expert_weights)
+    return out
+
+
+@pytest.mark.parametrize(
+    "mesh", [test_utils.get_spmd_mesh(1),
+             test_utils.get_spmd_mesh(2)])
+@pytest.mark.parametrize("num_tokens", [8])
+@pytest.mark.parametrize("intermediate_size", [1024])
+@pytest.mark.parametrize("hidden_size", [128])
+@pytest.mark.parametrize("num_experts", [8])
+@pytest.mark.parametrize("topk", [2])
+@pytest.mark.parametrize("use_ep", [True, False])
+def test_fused_moe_method(mesh, num_tokens, intermediate_size, hidden_size,
+                          num_experts, topk, use_ep):
+    engine_args = EngineArgs(
+        model=MODEL,
+        max_model_len=64,
+        max_num_batched_tokens=64,
+        max_num_seqs=4,
+    )
+    vllm_config = engine_args.create_engine_config()
+    vllm_config.compilation_config.pass_config.enable_sp = False
+
+    # Call tpu_inference code
+    vllm_config.model_config.dtype = torch.bfloat16
+    quant_config = get_tpu_quantization_config(vllm_config, mesh)
+
+    with set_current_vllm_config(vllm_config):
+        layer = FusedMoE(num_experts=num_experts,
+                         top_k=topk,
+                         hidden_size=hidden_size,
+                         intermediate_size=intermediate_size)
+    quant_config = VllmCompressedTensorsConfig(
+        target_scheme_map={
+            'Linear': {
+                'weights':
+                QuantizationArgs(num_bits=8,
+                                 type='float',
+                                 symmetric=True,
+                                 group_size=None,
+                                 strategy='channel',
+                                 block_structure=None,
+                                 dynamic=False,
+                                 actorder=None,
+                                 observer='minmax',
+                                 observer_kwargs={}),
+                'input_activations':
+                QuantizationArgs(num_bits=8,
+                                 type='float',
+                                 symmetric=True,
+                                 group_size=None,
+                                 strategy='token',
+                                 block_structure=None,
+                                 dynamic=True,
+                                 actorder=None,
+                                 observer=None,
+                                 observer_kwargs={}),
+                'format':
+                None
+            }
+        },
+        ignore=[],
+        quant_format='compressed-tensors',
+        sparsity_scheme_map={},
+        sparsity_ignore_list=[],
+    )
+    moe = FusedMoEConfig(
+        num_experts=num_experts,
+        experts_per_token=topk,
+        hidden_dim=hidden_size,
+        num_local_experts=num_experts,
+        moe_parallel_config=FusedMoEParallelConfig(
+            tp_size=1,
+            dp_size=1,
+            ep_size=1,
+            tp_rank=0,
+            dp_rank=0,
+            ep_rank=0,
+            use_ep=use_ep,
+            all2all_backend='',
+        ),
+        in_dtype=torch.bfloat16,
+    )
+    method = VllmCompressedTensorsW8A8Fp8MoEMethod(quant_config, moe, mesh)
+    method.create_weights(layer,
+                          num_experts,
+                          hidden_size,
+                          intermediate_size,
+                          params_dtype=torch.float8_e4m3fn)
+    method.process_weights_after_loading(layer)
+
+    seqlen = num_tokens
+    with torchax.default_env():
+        x = torch.ones((seqlen, hidden_size), dtype=torch.bfloat16).to('jax')
+        router_logits = torch.randn((seqlen, num_experts),
+                                    dtype=torch.bfloat16).to('jax')
+        result = method.apply(layer,
+                              x,
+                              router_logits,
+                              top_k=topk,
+                              renormalize=True)
+
+        result_reference = _ref_math_in_bf16(
+            layer.w13_weight.to(torch.bfloat16) * layer.w13_weight_scale,
+            layer.w2_weight.to(torch.bfloat16) * layer.w2_weight_scale,
+            layer.w3_weight.to(torch.bfloat16) * layer.w3_weight_scale, x,
+            router_logits, topk)
+
+        assert jnp.allclose(result.jax(), result_reference.jax())