tpu-inference 0.12.0.dev20251213__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of tpu-inference has been flagged as potentially problematic.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +14 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +14 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +14 -0
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +180 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +406 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +320 -0
- tests/layers/vllm/test_unquantized.py +662 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +25 -8
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +14 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +1 -43
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +14 -9
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +20 -3
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +171 -163
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +20 -26
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +112 -69
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +85 -65
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +374 -194
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +13 -0
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +26 -19
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/fused_moe_gmm.py +506 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +282 -0
- tpu_inference/layers/common/sharding.py +22 -3
- tpu_inference/layers/common/utils.py +94 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +52 -27
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +19 -6
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +100 -455
- tpu_inference/layers/vllm/linear.py +64 -0
- tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
- tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
- tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
- tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
- tpu_inference/layers/vllm/quantization/__init__.py +19 -3
- tpu_inference/layers/vllm/quantization/awq.py +96 -82
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +19 -5
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +119 -132
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
- tpu_inference/layers/vllm/quantization/{common.py → configs.py} +38 -26
- tpu_inference/layers/vllm/quantization/fp8.py +119 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +133 -220
- tpu_inference/layers/vllm/quantization/unquantized.py +154 -253
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +37 -16
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +113 -124
- tpu_inference/models/jax/gpt_oss.py +23 -7
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +14 -0
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +18 -4
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +85 -24
- tpu_inference/models/jax/utils/weight_utils.py +32 -1
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +22 -4
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +27 -29
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +69 -35
- tpu_inference/runner/kv_cache.py +14 -0
- tpu_inference/runner/kv_cache_manager.py +15 -2
- tpu_inference/runner/lora_utils.py +16 -1
- tpu_inference/runner/multimodal_manager.py +16 -2
- tpu_inference/runner/persistent_batch_manager.py +14 -0
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +14 -0
- tpu_inference/runner/tpu_runner.py +30 -10
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +13 -0
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +31 -30
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +23 -7
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +1 -1
- tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
- tpu_inference/layers/vllm/linear_common.py +0 -208
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.12.0.dev20251213.dist-info/RECORD +0 -175
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
tpu_inference/layers/vllm/process_weights/linear_weights.py (new file)
@@ -0,0 +1,174 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass, fields
+
+import jax
+import torch
+from jax.sharding import Mesh, NamedSharding, PartitionSpec
+from torch.nn import ParameterList
+from torch.nn.parameter import Parameter
+from torchax.tensor import Tensor
+
+from tpu_inference.layers.common.utils import \
+    reorder_concatenated_tensor_for_sharding
+from tpu_inference.logger import init_logger
+
+P = PartitionSpec
+
+logger = init_logger(__name__)
+
+
+@jax.tree_util.register_dataclass
+@dataclass
+class LinearWeights:
+    weight: jax.Array | Tensor | list[jax.Array | Tensor]
+    weight_scale: jax.Array | Tensor | list[jax.Array | Tensor] | None
+    zero_point: jax.Array | Tensor | list[jax.Array | Tensor] | None
+    bias: jax.Array | Tensor | list[jax.Array | Tensor] | None
+
+
+MODEL_MATMUL_FUSION_TRUTH_TABLE = {
+    ("Qwen/Qwen2.5-7B-Instruct", 1024, 1, "QKVParallelLinear"):
+    True,
+    ("Qwen/Qwen2.5-7B-Instruct", 1024, 1, "MergedColumnParallelLinear"):
+    False,
+    ("Qwen/Qwen2.5-7B-Instruct", 2048, 1, "QKVParallelLinear"):
+    False,
+    ("Qwen/Qwen2.5-7B-Instruct", 2048, 1, "MergedColumnParallelLinear"):
+    False,
+    ("meta-llama/Llama-3.1-8B-Instruct", 1024, 1, "QKVParallelLinear"):
+    False,
+    ("meta-llama/Llama-3.1-8B-Instruct", 1024, 1, "MergedColumnParallelLinear"):
+    False,
+    ("meta-llama/Llama-3.1-8B-Instruct", 2048, 1, "QKVParallelLinear"):
+    False,
+    ("meta-llama/Llama-3.1-8B-Instruct", 2048, 1, "MergedColumnParallelLinear"):
+    False,
+    ("RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", 1024, 1, "QKVParallelLinear"):
+    False,
+    ("RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", 1024, 1, "MergedColumnParallelLinear"):
+    False,
+    ("RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", 2048, 1, "QKVParallelLinear"):
+    False,
+    ("RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", 2048, 1, "MergedColumnParallelLinear"):
+    False,
+}
+
+
+def to_parameter_list(tensor: list[torch.Tensor]):
+    tensor = [Parameter(t, requires_grad=False) for t in tensor]
+    return ParameterList(tensor)
+
+
+def get_model_matmul_fusion_assignment(model_name: str, batch_size: int,
+                                       tp_size: int, layer_name: str):
+    key = (model_name, batch_size, tp_size, layer_name)
+    return MODEL_MATMUL_FUSION_TRUTH_TABLE.get(key, True)
+
+
+def process_lienar_weights(
+    weights: LinearWeights,
+    fused: bool = False,
+    output_sizes: list[int] | None = None,
+    reorder_size: int | None = None,
+    transposed: bool = True,
+    per_tensor: bool = False,
+) -> LinearWeights:
+    weight = weights.weight
+    weight_scale = weights.weight_scale
+    zero_point = weights.zero_point
+    bias = weights.bias
+
+    dim = 0 if transposed else -1
+    if output_sizes is None:
+        output_sizes = [weight.shape[dim]]
+
+    if fused:
+        assert reorder_size is not None
+        weight = reorder_concatenated_tensor_for_sharding(
+            weight, output_sizes, reorder_size, dim)
+
+        if weight_scale is not None and not per_tensor:
+            weight_scale = reorder_concatenated_tensor_for_sharding(
+                weight_scale, output_sizes, reorder_size, dim)
+        if zero_point is not None:
+            zero_point = reorder_concatenated_tensor_for_sharding(
+                zero_point, output_sizes, reorder_size, dim)
+        if bias is not None:
+            bias = reorder_concatenated_tensor_for_sharding(
+                bias, output_sizes, reorder_size, dim)
+    else:
+
+        def slice_tensor(tensor):
+            tensors = []
+            start = 0
+            for size in output_sizes:
+                end = start + size
+                tensor_split = jax.lax.slice_in_dim(tensor,
+                                                    start,
+                                                    end,
+                                                    axis=dim)
+                tensors.append(tensor_split)
+                start = end
+            return tensors
+
+        weight = slice_tensor(weight)
+        if weight_scale is not None and not per_tensor:
+            weight_scale = slice_tensor(weight_scale)
+        if zero_point is not None:
+            zero_point = slice_tensor(zero_point)
+        if bias is not None:
+            bias = slice_tensor(bias)
+
+    return LinearWeights(
+        weight=weight,
+        weight_scale=weight_scale,
+        zero_point=zero_point,
+        bias=bias,
+    )
+
+
+def shard_linear_weights(
+    weights: LinearWeights,
+    mesh: Mesh,
+    weight_p_spec: PartitionSpec,
+    bias_p_spec: PartitionSpec,
+    transposed: bool = True,
+    per_tensor: bool = False,
+) -> LinearWeights:
+
+    if not transposed:
+        # By defualt, we use transposed weights. If it is not transposed,
+        # we need to transpose the sharding as well.
+        weight_p_spec = PartitionSpec(*weight_p_spec[::-1])
+        bias_p_spec = PartitionSpec(weight_p_spec[0])
+
+    weight_sharding = NamedSharding(mesh, weight_p_spec)
+    bias_sharding = NamedSharding(mesh, bias_p_spec)
+
+    weight_shardings = LinearWeights(
+        weight=weight_sharding,
+        weight_scale=NamedSharding(mesh, P()) if per_tensor else bias_sharding,
+        zero_point=bias_sharding,
+        bias=bias_sharding,
+    )
+
+    for field in fields(LinearWeights):
+        key = field.name
+        if (weight := getattr(weights, key, None)) is not None:
+            sharding = getattr(weight_shardings, key)
+            weight = jax.device_put(weight, sharding)
+            setattr(weights, key, weight)
+    return weights
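For context, a minimal usage sketch of the new helpers (not part of the package or this diff): the function and class names come from the file above, but the shapes, sizes, and the lookup key are illustrative, and it assumes the wheel is installed so the module imports.

# Usage sketch (illustrative, not from the diff); assumes tpu-inference is installed.
import jax.numpy as jnp

from tpu_inference.layers.vllm.process_weights.linear_weights import (
    LinearWeights, get_model_matmul_fusion_assignment, process_lienar_weights)

# Truth-table lookup; keys not listed in the table default to True (fuse the matmuls).
fuse = get_model_matmul_fusion_assignment(
    "Qwen/Qwen2.5-7B-Instruct", 1024, 1, "QKVParallelLinear")  # True per the table

# Split a concatenated, transposed QKV weight back into per-projection slices.
# Sizes are made up for the example.
qkv = LinearWeights(
    weight=jnp.zeros((1024 + 256 + 256, 512)),  # q, k, v stacked along dim 0
    weight_scale=None,
    zero_point=None,
    bias=jnp.zeros((1024 + 256 + 256, )),
)
split = process_lienar_weights(qkv, fused=False, output_sizes=[1024, 256, 256])
assert [w.shape[0] for w in split.weight] == [1024, 256, 256]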
tpu_inference/layers/vllm/quantization/__init__.py
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import copy
 
 from jax.sharding import Mesh
@@ -7,9 +21,10 @@ from vllm.model_executor.layers.quantization.base_config import \
 
 from tpu_inference.layers.common import quant_methods
 from tpu_inference.layers.vllm.quantization.awq import VllmAWQConfig
-from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
 from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors import \
-    VllmCompressedTensorsConfig
+    VllmCompressedTensorsConfig
+from tpu_inference.layers.vllm.quantization.configs import VllmQuantConfig
+from tpu_inference.layers.vllm.quantization.fp8 import VllmFp8Config
 from tpu_inference.layers.vllm.quantization.mxfp4 import VllmMxfp4Config
 from tpu_inference.layers.vllm.quantization.unquantized import \
     VllmUnquantizedConfig
@@ -23,6 +38,7 @@ def get_tpu_quantization_config(vllm_config: VllmConfig,
         None: VllmUnquantizedConfig,
         quant_methods.COMPRESSED_TENSORS: VllmCompressedTensorsConfig,
         quant_methods.AWQ: VllmAWQConfig,
+        quant_methods.FP8: VllmFp8Config,
         quant_methods.MXFP4: VllmMxfp4Config,
     }
     if model_config.quantization not in method_to_config:
@@ -30,7 +46,7 @@ def get_tpu_quantization_config(vllm_config: VllmConfig,
             f"{model_config.quantization} quantization method not supported."
             f" Supported methods are {method_to_config.keys()}")
     quant_config = method_to_config[model_config.quantization]
-    assert issubclass(quant_config,
+    assert issubclass(quant_config, VllmQuantConfig)
     quant_config.set_configs(vllm_config, mesh)
 
     model_config.quantization = quant_methods.get_tpu_quant_method(
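The new FP8 entry slots into the same lookup every other method uses: get_tpu_quantization_config picks a config class from method_to_config, rejects unknown methods, and asserts the class derives from VllmQuantConfig. Below is a self-contained sketch of that pattern with stand-in classes and keys; it is not the real tpu_inference or vLLM code.

# Stand-in sketch of the dispatch pattern above; classes and keys are hypothetical.
class VllmQuantConfig:  # stand-in for tpu_inference's shared base config
    pass

class VllmUnquantizedConfig(VllmQuantConfig):
    pass

class VllmFp8Config(VllmQuantConfig):
    pass

method_to_config = {
    None: VllmUnquantizedConfig,  # no quantization configured
    "fp8": VllmFp8Config,         # method newly wired up in this release
}

def pick_config(quantization):
    if quantization not in method_to_config:
        raise ValueError(
            f"{quantization} quantization method not supported."
            f" Supported methods are {method_to_config.keys()}")
    config = method_to_config[quantization]
    assert issubclass(config, VllmQuantConfig)
    return config

assert pick_config("fp8") is VllmFp8Config
assert pick_config(None) is VllmUnquantizedConfig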
tpu_inference/layers/vllm/quantization/awq.py
@@ -1,11 +1,26 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import Optional, Union
 
 import jax
 import jax.numpy as jnp
 import torch
-from jax.sharding import
+from jax.sharding import PartitionSpec
+from torch.nn.parameter import Parameter
 from torchax.interop import jax_view, torch_view
-from
+from torchax.ops.mappings import t2j
 from vllm.model_executor.layers.fused_moe.layer import FusedMoE
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization import \
@@ -14,24 +29,29 @@ from vllm.model_executor.layers.quantization.awq import (AWQConfig,
                                                          AWQLinearMethod)
 from vllm.model_executor.layers.quantization.base_config import \
     QuantizeMethodBase
-from vllm.model_executor.layers.quantization.utils.quant_utils import
-    is_layer_skipped
-from vllm.scalar_type import scalar_types
+from vllm.model_executor.layers.quantization.utils.quant_utils import \
+    is_layer_skipped
 
 from tpu_inference.layers.common.quant_methods import AWQ, get_tpu_quant_method
-from tpu_inference.layers.
-
-
-
+from tpu_inference.layers.common.quantization import awq_u32_unpack_u4
+from tpu_inference.layers.common.utils import \
+    slice_sharded_tensor_for_concatenation
+from tpu_inference.layers.vllm.process_weights.linear_weights import (
+    LinearWeights, process_lienar_weights, shard_linear_weights,
+    to_parameter_list)
+from tpu_inference.layers.vllm.quantization.configs import (
+    VllmQuantConfig, VllmQuantLinearConfig)
 from tpu_inference.layers.vllm.quantization.unquantized import \
     VllmUnquantizedLinearMethod
+from tpu_inference.logger import init_logger
 
 P = PartitionSpec
+
 logger = init_logger(__name__)
 
 
 @register_quantization_config(get_tpu_quant_method(AWQ))
-class VllmAWQConfig(AWQConfig,
+class VllmAWQConfig(AWQConfig, VllmQuantConfig):
 
     @classmethod
     def get_name(cls):
@@ -39,7 +59,7 @@ class VllmAWQConfig(AWQConfig, JaxCommonConfig):
 
     def get_supported_act_dtypes(self) -> list[torch.dtype]:
         # NOTE: AWQ checkpoint was quantized with float16. But on TPUs, using
-        # bfloat16 is
+        # bfloat16 is significantly preferred over float16. This might lead to
         # some numeric output change.
         return [torch.bfloat16]
 
@@ -60,72 +80,79 @@ class VllmAWQConfig(AWQConfig, JaxCommonConfig):
 class VllmAWQLinearMethod(AWQLinearMethod):
 
     def __init__(self, quant_config: VllmAWQConfig,
-
+                 linear_config: VllmQuantLinearConfig):
         super().__init__(quant_config)
-        self.
-
-        out_sharding, in_sharding = self.jax_config.weight_sharding[:]
-        self.jax_config.weight_sharding = P(in_sharding, None, out_sharding)
-        self.jax_config.scale_sharding = P(in_sharding, out_sharding)
+        self.linear_config = linear_config
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        qweight
-
-
-        group_size = self.quant_config.group_size
-        # Reshape so that each qweight[i] were quantized with same scales[i].
-        qweight = qweight.reshape((-1, group_size, layer.output_size))
-        qweight = torch_to_jax_param(qweight,
-                                     NamedSharding(
-                                         self.jax_config.mesh,
-                                         self.jax_config.weight_sharding),
-                                     self.jax_config.output_sizes,
-                                     self.jax_config.n_shards,
-                                     self.jax_config.fuse_matmuls,
-                                     dim=2,
-                                     jax_dtype=jnp.uint4)
+        assert layer.qweight.packed_dim == layer.qweight.ndim - 1
+        weight = t2j(layer.qweight, use_dlpack=False)
         delattr(layer, "qweight")
-
-
-        qzeros = layer.qzeros
-        qzeros = unpack_awq_weight(qzeros, qzeros.packed_dim)
-        qzeros = torch_to_jax_param(qzeros,
-                                    NamedSharding(
-                                        self.jax_config.mesh,
-                                        self.jax_config.scale_sharding),
-                                    self.jax_config.output_sizes,
-                                    self.jax_config.n_shards,
-                                    self.jax_config.fuse_matmuls,
-                                    dim=1,
-                                    jax_dtype=jnp.uint4)
-        delattr(layer, "qzeros")
-        layer.qzeros = qzeros
-
-        scales = torch_to_jax_param(layer.scales,
-                                    NamedSharding(
-                                        self.jax_config.mesh,
-                                        self.jax_config.scale_sharding),
-                                    self.jax_config.output_sizes,
-                                    self.jax_config.n_shards,
-                                    self.jax_config.fuse_matmuls,
-                                    dim=1)
+
+        weight_scale = t2j(layer.scales, use_dlpack=False)
         delattr(layer, "scales")
-
+
+        assert layer.qzeros.packed_dim == layer.qzeros.ndim - 1
+        zero_point = t2j(layer.qzeros, use_dlpack=False)
+        delattr(layer, "qzeros")
 
         if layer.bias is not None and not layer.skip_bias_add:
             if layer.return_bias:
                 logger.warning_once("Bias might return incorrect value.")
-
-            bias = torch_to_jax_param(
-                layer.bias,
-                NamedSharding(self.jax_config.mesh,
-                              self.jax_config.bias_sharding),
-                self.jax_config.output_sizes,
-                self.jax_config.n_shards,
-                self.jax_config.fuse_matmuls,
-            )
+            bias = t2j(layer.bias, use_dlpack=False)
             delattr(layer, "bias")
-
+        else:
+            bias = None
+
+        @jax.jit
+        def process_awq_linear_weights(
+            weight: jax.Array,
+            weight_scale: jax.Array,
+            zero_point: jax.Array,
+            bias: jax.Array | None,
+        ) -> LinearWeights:
+            weight = awq_u32_unpack_u4(weight)
+            group_size = self.quant_config.group_size
+            weight = weight.reshape((-1, group_size, weight.shape[-1]))
+
+            zero_point = awq_u32_unpack_u4(zero_point)
+
+            return process_lienar_weights(
+                LinearWeights(
+                    weight=weight,
+                    weight_scale=weight_scale,
+                    zero_point=zero_point,
+                    bias=bias,
+                ),
+                fused=self.linear_config.fuse_matmuls,
+                output_sizes=self.linear_config.output_sizes,
+                reorder_size=self.linear_config.n_shards,
+                transposed=False,
+            )
+
+        weights = process_awq_linear_weights(weight, weight_scale, zero_point,
+                                             bias)
+        weights = torch_view(
+            shard_linear_weights(
+                weights,
+                mesh=self.linear_config.mesh,
+                weight_p_spec=self.linear_config.weight_sharding,
+                bias_p_spec=self.linear_config.bias_sharding,
+                transposed=False,
+            ))
+
+        if self.linear_config.fuse_matmuls:
+            layer.qweight = Parameter(weights.weight, requires_grad=False)
+            layer.scales = Parameter(weights.weight_scale, requires_grad=False)
+            layer.qzeros = Parameter(weights.zero_point, requires_grad=False)
+            if bias is not None:
+                layer.bias = Parameter(weights.bias, requires_grad=False)
+        else:
+            layer.qweight = to_parameter_list(weights.weight)
+            layer.scales = to_parameter_list(weights.weight_scale)
+            layer.qzeros = to_parameter_list(weights.zero_point)
+            if bias is not None:
+                layer.bias = to_parameter_list(weights.bias)
 
     def apply(self,
               layer: torch.nn.Module,
@@ -133,7 +160,7 @@ class VllmAWQLinearMethod(AWQLinearMethod):
               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
 
         with jax.named_scope(layer._get_name()):
-            if self.
+            if self.linear_config.fuse_matmuls:
                 out = self._apply_fused(layer, x, bias)
             else:
                 out = self._apply_split(layer, x, bias)
@@ -161,7 +188,7 @@ class VllmAWQLinearMethod(AWQLinearMethod):
             outs += bias.jax()
 
         outs = slice_sharded_tensor_for_concatenation(
-            outs, self.
+            outs, self.linear_config.output_sizes, self.linear_config.n_shards)
         out = jnp.concatenate(outs, axis=-1)
         return torch_view(out)
 
@@ -192,16 +219,3 @@ class VllmAWQLinearMethod(AWQLinearMethod):
             outs.append(out)
         out = jnp.concatenate(outs, axis=-1)
         return torch_view(out)
-
-
-def unpack_awq_weight(weight: torch.Tensor, packed_dim: int):
-    weight = unpack_quantized_values_into_int32(weight, scalar_types.uint4,
-                                                packed_dim)
-
-    # AWQ packs 8 uint4 into 32-bits in this order: (0, 2, 4, 6, 1, 3, 5, 7).
-    # Following list maps the order used by AWQ into an ascending order.
-    reverse_awq_order = (0, 4, 1, 5, 2, 6, 3, 7)
-
-    orig_shape = weight.shape
-    weight = weight.reshape(orig_shape[:-1] + (-1, 8))
-    return weight[..., reverse_awq_order].reshape(orig_shape)
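The removed unpack_awq_weight helper documented AWQ's nibble layout, and the new code delegates the same job to awq_u32_unpack_u4 (imported from the new tpu_inference/layers/common/quantization.py, which is not shown in this excerpt). The NumPy sketch below only illustrates the orderings stated in the removed comments; the packing loop itself is an illustration, not package code.

# NumPy sketch of the AWQ uint4 packing order described in the removed helper.
import numpy as np

AWQ_PACK_ORDER = (0, 2, 4, 6, 1, 3, 5, 7)      # logical index stored in nibble k
REVERSE_AWQ_ORDER = (0, 4, 1, 5, 2, 6, 3, 7)   # nibble that holds logical index i

values = np.arange(8, dtype=np.uint32)          # eight logical uint4 values

# Pack the AWQ way: nibble k of the uint32 holds values[AWQ_PACK_ORDER[k]].
packed = np.uint32(0)
for k, src in enumerate(AWQ_PACK_ORDER):
    packed |= values[src] << np.uint32(4 * k)

# Unpack: extract nibbles in ascending bit order, then undo the interleaving.
nibbles = np.array([(packed >> np.uint32(4 * k)) & np.uint32(0xF) for k in range(8)])
assert (nibbles[list(REVERSE_AWQ_ORDER)] == values).all()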
New 13-line file (Apache license header only)
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py
@@ -1,9 +1,22 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import Optional
 
 import torch
 from jax.sharding import PartitionSpec
 from vllm.attention.layer import Attention
-from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.layer import FusedMoE
 from vllm.model_executor.layers.linear import LinearBase
 from vllm.model_executor.layers.quantization import \
@@ -18,22 +31,23 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
 
 from tpu_inference.layers.common.quant_methods import (COMPRESSED_TENSORS,
                                                        get_tpu_quant_method)
-from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
 from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors_moe import \
     VllmCompressedTensorsMoEMethod
 from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_fp8 import \
     VllmCompressedTensorsW8A8Fp8
 from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_int8 import \
     VllmCompressedTensorsW8A8Int8
+from tpu_inference.layers.vllm.quantization.configs import VllmQuantConfig
 from tpu_inference.layers.vllm.quantization.unquantized import \
     VllmUnquantizedConfig
+from tpu_inference.logger import init_logger
 
 P = PartitionSpec
 logger = init_logger(__name__)
 
 
 @register_quantization_config(get_tpu_quant_method(COMPRESSED_TENSORS))
-class VllmCompressedTensorsConfig(CompressedTensorsConfig,
+class VllmCompressedTensorsConfig(CompressedTensorsConfig, VllmQuantConfig):
 
     @classmethod
     def get_name(cls) -> str:
@@ -84,14 +98,14 @@ class VllmCompressedTensorsConfig(CompressedTensorsConfig, JaxCommonConfig):
             return VllmCompressedTensorsW8A8Fp8(
                 weight_quant=weight_quant,
                 is_static_input_scheme=is_static_input_scheme,
-
+                linear_config=linear_config,
             )
         if self._is_dynamic_token_w8a8(weight_quant, input_quant):
             return VllmCompressedTensorsW8A8Int8(
                 strategy=weight_quant.strategy,
                 is_static_input_scheme=False,
                 input_symmetric=input_quant.symmetric,
-
+                linear_config=linear_config,
            )
         raise NotImplementedError(
            "No compressed-tensors compatible scheme was found.")