tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic. Click here for more details.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +317 -34
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +406 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +320 -0
- tests/layers/vllm/test_unquantized.py +662 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +26 -6
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +110 -12
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +2 -45
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +15 -10
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +92 -8
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +25 -4
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +807 -230
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +218 -137
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +25 -12
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/fused_moe_gmm.py +506 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +282 -0
- tpu_inference/layers/common/sharding.py +32 -9
- tpu_inference/layers/common/utils.py +94 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +101 -494
- tpu_inference/layers/vllm/linear.py +64 -0
- tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
- tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
- tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
- tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
- tpu_inference/layers/vllm/quantization/__init__.py +19 -3
- tpu_inference/layers/vllm/quantization/awq.py +96 -82
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +23 -8
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +172 -176
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
- tpu_inference/layers/vllm/quantization/{common.py → configs.py} +42 -25
- tpu_inference/layers/vllm/quantization/fp8.py +119 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +137 -178
- tpu_inference/layers/vllm/quantization/unquantized.py +157 -233
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +112 -35
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +18 -5
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +179 -51
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +92 -32
- tpu_inference/models/jax/utils/weight_utils.py +234 -155
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +32 -8
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +51 -72
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +180 -80
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +55 -33
- tpu_inference/runner/lora_utils.py +16 -1
- tpu_inference/runner/multimodal_manager.py +16 -2
- tpu_inference/runner/persistent_batch_manager.py +54 -2
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +16 -3
- tpu_inference/runner/tpu_runner.py +124 -61
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +84 -22
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +72 -44
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +66 -52
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +8 -9
- tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
- tpu_inference/layers/vllm/linear_common.py +0 -186
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202511220812.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
|
@@ -1,11 +1,26 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
1
15
|
from typing import Optional, Union
|
|
2
16
|
|
|
3
17
|
import jax
|
|
4
18
|
import jax.numpy as jnp
|
|
5
19
|
import torch
|
|
6
|
-
from jax.sharding import
|
|
20
|
+
from jax.sharding import PartitionSpec
|
|
21
|
+
from torch.nn.parameter import Parameter
|
|
7
22
|
from torchax.interop import jax_view, torch_view
|
|
8
|
-
from
|
|
23
|
+
from torchax.ops.mappings import t2j
|
|
9
24
|
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
|
10
25
|
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
|
11
26
|
from vllm.model_executor.layers.quantization import \
|
|
@@ -14,24 +29,29 @@ from vllm.model_executor.layers.quantization.awq import (AWQConfig,
|
|
|
14
29
|
AWQLinearMethod)
|
|
15
30
|
from vllm.model_executor.layers.quantization.base_config import \
|
|
16
31
|
QuantizeMethodBase
|
|
17
|
-
from vllm.model_executor.layers.quantization.utils.quant_utils import
|
|
18
|
-
is_layer_skipped
|
|
19
|
-
from vllm.scalar_type import scalar_types
|
|
32
|
+
from vllm.model_executor.layers.quantization.utils.quant_utils import \
|
|
33
|
+
is_layer_skipped
|
|
20
34
|
|
|
21
35
|
from tpu_inference.layers.common.quant_methods import AWQ, get_tpu_quant_method
|
|
22
|
-
from tpu_inference.layers.
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
36
|
+
from tpu_inference.layers.common.quantization import awq_u32_unpack_u4
|
|
37
|
+
from tpu_inference.layers.common.utils import \
|
|
38
|
+
slice_sharded_tensor_for_concatenation
|
|
39
|
+
from tpu_inference.layers.vllm.process_weights.linear_weights import (
|
|
40
|
+
LinearWeights, process_lienar_weights, shard_linear_weights,
|
|
41
|
+
to_parameter_list)
|
|
42
|
+
from tpu_inference.layers.vllm.quantization.configs import (
|
|
43
|
+
VllmQuantConfig, VllmQuantLinearConfig)
|
|
26
44
|
from tpu_inference.layers.vllm.quantization.unquantized import \
|
|
27
45
|
VllmUnquantizedLinearMethod
|
|
46
|
+
from tpu_inference.logger import init_logger
|
|
28
47
|
|
|
29
48
|
P = PartitionSpec
|
|
49
|
+
|
|
30
50
|
logger = init_logger(__name__)
|
|
31
51
|
|
|
32
52
|
|
|
33
53
|
@register_quantization_config(get_tpu_quant_method(AWQ))
|
|
34
|
-
class VllmAWQConfig(AWQConfig,
|
|
54
|
+
class VllmAWQConfig(AWQConfig, VllmQuantConfig):
|
|
35
55
|
|
|
36
56
|
@classmethod
|
|
37
57
|
def get_name(cls):
|
|
@@ -39,7 +59,7 @@ class VllmAWQConfig(AWQConfig, JaxCommonConfig):
|
|
|
39
59
|
|
|
40
60
|
def get_supported_act_dtypes(self) -> list[torch.dtype]:
|
|
41
61
|
# NOTE: AWQ checkpoint was quantized with float16. But on TPUs, using
|
|
42
|
-
# bfloat16 is
|
|
62
|
+
# bfloat16 is significantly preferred over float16. This might lead to
|
|
43
63
|
# some numeric output change.
|
|
44
64
|
return [torch.bfloat16]
|
|
45
65
|
|
|
@@ -60,72 +80,79 @@ class VllmAWQConfig(AWQConfig, JaxCommonConfig):
|
|
|
60
80
|
class VllmAWQLinearMethod(AWQLinearMethod):
|
|
61
81
|
|
|
62
82
|
def __init__(self, quant_config: VllmAWQConfig,
|
|
63
|
-
|
|
83
|
+
linear_config: VllmQuantLinearConfig):
|
|
64
84
|
super().__init__(quant_config)
|
|
65
|
-
self.
|
|
66
|
-
|
|
67
|
-
out_sharding, in_sharding = self.jax_config.weight_sharding[:]
|
|
68
|
-
self.jax_config.weight_sharding = P(in_sharding, None, out_sharding)
|
|
69
|
-
self.jax_config.scale_sharding = P(in_sharding, out_sharding)
|
|
85
|
+
self.linear_config = linear_config
|
|
70
86
|
|
|
71
87
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
|
72
|
-
qweight
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
group_size = self.quant_config.group_size
|
|
76
|
-
# Reshape so that each qweight[i] were quantized with same scales[i].
|
|
77
|
-
qweight = qweight.reshape((-1, group_size, layer.output_size))
|
|
78
|
-
qweight = torch_to_jax_param(qweight,
|
|
79
|
-
NamedSharding(
|
|
80
|
-
self.jax_config.mesh,
|
|
81
|
-
self.jax_config.weight_sharding),
|
|
82
|
-
self.jax_config.output_sizes,
|
|
83
|
-
self.jax_config.n_shards,
|
|
84
|
-
self.jax_config.fuse_matmuls,
|
|
85
|
-
dim=2,
|
|
86
|
-
jax_dtype=jnp.uint4)
|
|
88
|
+
assert layer.qweight.packed_dim == layer.qweight.ndim - 1
|
|
89
|
+
weight = t2j(layer.qweight, use_dlpack=False)
|
|
87
90
|
delattr(layer, "qweight")
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
qzeros = layer.qzeros
|
|
91
|
-
qzeros = unpack_awq_weight(qzeros, qzeros.packed_dim)
|
|
92
|
-
qzeros = torch_to_jax_param(qzeros,
|
|
93
|
-
NamedSharding(
|
|
94
|
-
self.jax_config.mesh,
|
|
95
|
-
self.jax_config.scale_sharding),
|
|
96
|
-
self.jax_config.output_sizes,
|
|
97
|
-
self.jax_config.n_shards,
|
|
98
|
-
self.jax_config.fuse_matmuls,
|
|
99
|
-
dim=1,
|
|
100
|
-
jax_dtype=jnp.uint4)
|
|
101
|
-
delattr(layer, "qzeros")
|
|
102
|
-
layer.qzeros = qzeros
|
|
103
|
-
|
|
104
|
-
scales = torch_to_jax_param(layer.scales,
|
|
105
|
-
NamedSharding(
|
|
106
|
-
self.jax_config.mesh,
|
|
107
|
-
self.jax_config.scale_sharding),
|
|
108
|
-
self.jax_config.output_sizes,
|
|
109
|
-
self.jax_config.n_shards,
|
|
110
|
-
self.jax_config.fuse_matmuls,
|
|
111
|
-
dim=1)
|
|
91
|
+
|
|
92
|
+
weight_scale = t2j(layer.scales, use_dlpack=False)
|
|
112
93
|
delattr(layer, "scales")
|
|
113
|
-
|
|
94
|
+
|
|
95
|
+
assert layer.qzeros.packed_dim == layer.qzeros.ndim - 1
|
|
96
|
+
zero_point = t2j(layer.qzeros, use_dlpack=False)
|
|
97
|
+
delattr(layer, "qzeros")
|
|
114
98
|
|
|
115
99
|
if layer.bias is not None and not layer.skip_bias_add:
|
|
116
100
|
if layer.return_bias:
|
|
117
101
|
logger.warning_once("Bias might return incorrect value.")
|
|
118
|
-
|
|
119
|
-
bias = torch_to_jax_param(
|
|
120
|
-
layer.bias,
|
|
121
|
-
NamedSharding(self.jax_config.mesh,
|
|
122
|
-
self.jax_config.bias_sharding),
|
|
123
|
-
self.jax_config.output_sizes,
|
|
124
|
-
self.jax_config.n_shards,
|
|
125
|
-
self.jax_config.fuse_matmuls,
|
|
126
|
-
)
|
|
102
|
+
bias = t2j(layer.bias, use_dlpack=False)
|
|
127
103
|
delattr(layer, "bias")
|
|
128
|
-
|
|
104
|
+
else:
|
|
105
|
+
bias = None
|
|
106
|
+
|
|
107
|
+
@jax.jit
|
|
108
|
+
def process_awq_linear_weights(
|
|
109
|
+
weight: jax.Array,
|
|
110
|
+
weight_scale: jax.Array,
|
|
111
|
+
zero_point: jax.Array,
|
|
112
|
+
bias: jax.Array | None,
|
|
113
|
+
) -> LinearWeights:
|
|
114
|
+
weight = awq_u32_unpack_u4(weight)
|
|
115
|
+
group_size = self.quant_config.group_size
|
|
116
|
+
weight = weight.reshape((-1, group_size, weight.shape[-1]))
|
|
117
|
+
|
|
118
|
+
zero_point = awq_u32_unpack_u4(zero_point)
|
|
119
|
+
|
|
120
|
+
return process_lienar_weights(
|
|
121
|
+
LinearWeights(
|
|
122
|
+
weight=weight,
|
|
123
|
+
weight_scale=weight_scale,
|
|
124
|
+
zero_point=zero_point,
|
|
125
|
+
bias=bias,
|
|
126
|
+
),
|
|
127
|
+
fused=self.linear_config.fuse_matmuls,
|
|
128
|
+
output_sizes=self.linear_config.output_sizes,
|
|
129
|
+
reorder_size=self.linear_config.n_shards,
|
|
130
|
+
transposed=False,
|
|
131
|
+
)
|
|
132
|
+
|
|
133
|
+
weights = process_awq_linear_weights(weight, weight_scale, zero_point,
|
|
134
|
+
bias)
|
|
135
|
+
weights = torch_view(
|
|
136
|
+
shard_linear_weights(
|
|
137
|
+
weights,
|
|
138
|
+
mesh=self.linear_config.mesh,
|
|
139
|
+
weight_p_spec=self.linear_config.weight_sharding,
|
|
140
|
+
bias_p_spec=self.linear_config.bias_sharding,
|
|
141
|
+
transposed=False,
|
|
142
|
+
))
|
|
143
|
+
|
|
144
|
+
if self.linear_config.fuse_matmuls:
|
|
145
|
+
layer.qweight = Parameter(weights.weight, requires_grad=False)
|
|
146
|
+
layer.scales = Parameter(weights.weight_scale, requires_grad=False)
|
|
147
|
+
layer.qzeros = Parameter(weights.zero_point, requires_grad=False)
|
|
148
|
+
if bias is not None:
|
|
149
|
+
layer.bias = Parameter(weights.bias, requires_grad=False)
|
|
150
|
+
else:
|
|
151
|
+
layer.qweight = to_parameter_list(weights.weight)
|
|
152
|
+
layer.scales = to_parameter_list(weights.weight_scale)
|
|
153
|
+
layer.qzeros = to_parameter_list(weights.zero_point)
|
|
154
|
+
if bias is not None:
|
|
155
|
+
layer.bias = to_parameter_list(weights.bias)
|
|
129
156
|
|
|
130
157
|
def apply(self,
|
|
131
158
|
layer: torch.nn.Module,
|
|
@@ -133,7 +160,7 @@ class VllmAWQLinearMethod(AWQLinearMethod):
|
|
|
133
160
|
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
|
134
161
|
|
|
135
162
|
with jax.named_scope(layer._get_name()):
|
|
136
|
-
if self.
|
|
163
|
+
if self.linear_config.fuse_matmuls:
|
|
137
164
|
out = self._apply_fused(layer, x, bias)
|
|
138
165
|
else:
|
|
139
166
|
out = self._apply_split(layer, x, bias)
|
|
@@ -161,7 +188,7 @@ class VllmAWQLinearMethod(AWQLinearMethod):
|
|
|
161
188
|
outs += bias.jax()
|
|
162
189
|
|
|
163
190
|
outs = slice_sharded_tensor_for_concatenation(
|
|
164
|
-
outs, self.
|
|
191
|
+
outs, self.linear_config.output_sizes, self.linear_config.n_shards)
|
|
165
192
|
out = jnp.concatenate(outs, axis=-1)
|
|
166
193
|
return torch_view(out)
|
|
167
194
|
|
|
@@ -192,16 +219,3 @@ class VllmAWQLinearMethod(AWQLinearMethod):
|
|
|
192
219
|
outs.append(out)
|
|
193
220
|
out = jnp.concatenate(outs, axis=-1)
|
|
194
221
|
return torch_view(out)
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
def unpack_awq_weight(weight: torch.Tensor, packed_dim: int):
|
|
198
|
-
weight = unpack_quantized_values_into_int32(weight, scalar_types.uint4,
|
|
199
|
-
packed_dim)
|
|
200
|
-
|
|
201
|
-
# AWQ packs 8 uint4 into 32-bits in this order: (0, 2, 4, 6, 1, 3, 5, 7).
|
|
202
|
-
# Following list maps the order used by AWQ into an ascending order.
|
|
203
|
-
reverse_awq_order = (0, 4, 1, 5, 2, 6, 3, 7)
|
|
204
|
-
|
|
205
|
-
orig_shape = weight.shape
|
|
206
|
-
weight = weight.reshape(orig_shape[:-1] + (-1, 8))
|
|
207
|
-
return weight[..., reverse_awq_order].reshape(orig_shape)
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
@@ -1,9 +1,22 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
1
15
|
from typing import Optional
|
|
2
16
|
|
|
3
17
|
import torch
|
|
4
18
|
from jax.sharding import PartitionSpec
|
|
5
19
|
from vllm.attention.layer import Attention
|
|
6
|
-
from vllm.logger import init_logger
|
|
7
20
|
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
|
8
21
|
from vllm.model_executor.layers.linear import LinearBase
|
|
9
22
|
from vllm.model_executor.layers.quantization import \
|
|
@@ -18,22 +31,23 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
|
|
|
18
31
|
|
|
19
32
|
from tpu_inference.layers.common.quant_methods import (COMPRESSED_TENSORS,
|
|
20
33
|
get_tpu_quant_method)
|
|
21
|
-
from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
|
|
22
34
|
from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors_moe import \
|
|
23
|
-
|
|
35
|
+
VllmCompressedTensorsMoEMethod
|
|
24
36
|
from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_fp8 import \
|
|
25
37
|
VllmCompressedTensorsW8A8Fp8
|
|
26
38
|
from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_int8 import \
|
|
27
39
|
VllmCompressedTensorsW8A8Int8
|
|
40
|
+
from tpu_inference.layers.vllm.quantization.configs import VllmQuantConfig
|
|
28
41
|
from tpu_inference.layers.vllm.quantization.unquantized import \
|
|
29
42
|
VllmUnquantizedConfig
|
|
43
|
+
from tpu_inference.logger import init_logger
|
|
30
44
|
|
|
31
45
|
P = PartitionSpec
|
|
32
46
|
logger = init_logger(__name__)
|
|
33
47
|
|
|
34
48
|
|
|
35
49
|
@register_quantization_config(get_tpu_quant_method(COMPRESSED_TENSORS))
|
|
36
|
-
class VllmCompressedTensorsConfig(CompressedTensorsConfig,
|
|
50
|
+
class VllmCompressedTensorsConfig(CompressedTensorsConfig, VllmQuantConfig):
|
|
37
51
|
|
|
38
52
|
@classmethod
|
|
39
53
|
def get_name(cls) -> str:
|
|
@@ -84,14 +98,14 @@ class VllmCompressedTensorsConfig(CompressedTensorsConfig, JaxCommonConfig):
|
|
|
84
98
|
return VllmCompressedTensorsW8A8Fp8(
|
|
85
99
|
weight_quant=weight_quant,
|
|
86
100
|
is_static_input_scheme=is_static_input_scheme,
|
|
87
|
-
|
|
101
|
+
linear_config=linear_config,
|
|
88
102
|
)
|
|
89
103
|
if self._is_dynamic_token_w8a8(weight_quant, input_quant):
|
|
90
104
|
return VllmCompressedTensorsW8A8Int8(
|
|
91
105
|
strategy=weight_quant.strategy,
|
|
92
106
|
is_static_input_scheme=False,
|
|
93
107
|
input_symmetric=input_quant.symmetric,
|
|
94
|
-
|
|
108
|
+
linear_config=linear_config,
|
|
95
109
|
)
|
|
96
110
|
raise NotImplementedError(
|
|
97
111
|
"No compressed-tensors compatible scheme was found.")
|
|
@@ -113,8 +127,9 @@ class VllmCompressedTensorsConfig(CompressedTensorsConfig, JaxCommonConfig):
|
|
|
113
127
|
layer.scheme = scheme
|
|
114
128
|
return CompressedTensorsLinearMethod(self)
|
|
115
129
|
if isinstance(layer, FusedMoE):
|
|
116
|
-
|
|
117
|
-
|
|
130
|
+
layer.moe_config = self.get_moe_config(layer)
|
|
131
|
+
return VllmCompressedTensorsMoEMethod.get_moe_method(
|
|
132
|
+
self, layer, layer_name=prefix)
|
|
118
133
|
if isinstance(layer, Attention):
|
|
119
134
|
return CompressedTensorsKVCacheMethod(self)
|
|
120
135
|
return None
|