tpu_inference-0.11.1.dev202511150811-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/__init__.py +0 -0
- tests/core/__init__.py +0 -0
- tests/core/test_core_tpu.py +513 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +53 -0
- tests/core/test_dp_scheduler.py +899 -0
- tests/core/test_init.py +49 -0
- tests/kernels/__init__.py +0 -0
- tests/kernels/fused_moe_v1_test.py +105 -0
- tests/kernels/mla_v1_test.py +396 -0
- tests/kernels/quantized_matmul_kernel_test.py +191 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
- tests/lora/__init__.py +0 -0
- tests/lora/conftest.py +32 -0
- tests/lora/test_bgmv.py +43 -0
- tests/lora/test_layers.py +654 -0
- tests/lora/test_lora.py +133 -0
- tests/lora/utils.py +96 -0
- tests/test_base.py +201 -0
- tests/test_envs.py +182 -0
- tests/test_quantization.py +836 -0
- tests/test_tpu_info.py +120 -0
- tests/test_utils.py +236 -0
- tpu_inference/__init__.py +34 -0
- tpu_inference/core/__init__.py +0 -0
- tpu_inference/core/core_tpu.py +786 -0
- tpu_inference/core/disagg_executor.py +118 -0
- tpu_inference/core/disagg_utils.py +51 -0
- tpu_inference/core/sched/__init__.py +0 -0
- tpu_inference/core/sched/dp_scheduler.py +523 -0
- tpu_inference/distributed/__init__.py +0 -0
- tpu_inference/distributed/jax_parallel_state.py +67 -0
- tpu_inference/distributed/tpu_connector.py +728 -0
- tpu_inference/distributed/utils.py +59 -0
- tpu_inference/env_override.py +9 -0
- tpu_inference/envs.py +107 -0
- tpu_inference/executors/__init__.py +0 -0
- tpu_inference/executors/ray_distributed_executor.py +362 -0
- tpu_inference/experimental/__init__.py +0 -0
- tpu_inference/experimental/llama3_jax_stashed.py +258 -0
- tpu_inference/kernels/__init__.py +0 -0
- tpu_inference/kernels/collectives/__init__.py +0 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +0 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/fused_moe/__init__.py +0 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +1035 -0
- tpu_inference/kernels/mla/__init__.py +0 -0
- tpu_inference/kernels/mla/v1/__init__.py +0 -0
- tpu_inference/kernels/mla/v1/kernel.py +1349 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1478 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
- tpu_inference/layers/__init__.py +0 -0
- tpu_inference/layers/common/__init__.py +0 -0
- tpu_inference/layers/common/attention_interface.py +390 -0
- tpu_inference/layers/common/attention_metadata.py +34 -0
- tpu_inference/layers/common/binary_search.py +295 -0
- tpu_inference/layers/common/quant_methods.py +8 -0
- tpu_inference/layers/common/sharding.py +582 -0
- tpu_inference/layers/jax/__init__.py +0 -0
- tpu_inference/layers/jax/attention/__init__.py +0 -0
- tpu_inference/layers/jax/attention/attention.py +255 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
- tpu_inference/layers/jax/base.py +151 -0
- tpu_inference/layers/jax/constants.py +88 -0
- tpu_inference/layers/jax/layers.py +301 -0
- tpu_inference/layers/jax/misc.py +16 -0
- tpu_inference/layers/jax/moe/__init__.py +0 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
- tpu_inference/layers/jax/moe/moe.py +209 -0
- tpu_inference/layers/jax/rope.py +280 -0
- tpu_inference/layers/jax/rope_interface.py +214 -0
- tpu_inference/layers/jax/sample/__init__.py +0 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
- tpu_inference/layers/jax/sample/sampling.py +96 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
- tpu_inference/layers/jax/transformer_block.py +107 -0
- tpu_inference/layers/vllm/__init__.py +0 -0
- tpu_inference/layers/vllm/attention.py +221 -0
- tpu_inference/layers/vllm/fused_moe.py +507 -0
- tpu_inference/layers/vllm/linear_common.py +186 -0
- tpu_inference/layers/vllm/quantization/__init__.py +39 -0
- tpu_inference/layers/vllm/quantization/awq.py +207 -0
- tpu_inference/layers/vllm/quantization/common.py +105 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +386 -0
- tpu_inference/layers/vllm/sharding.py +230 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +0 -0
- tpu_inference/lora/torch_lora_ops.py +103 -0
- tpu_inference/lora/torch_punica_tpu.py +311 -0
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1219 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/__init__.py +0 -0
- tpu_inference/models/common/__init__.py +0 -0
- tpu_inference/models/common/model_loader.py +444 -0
- tpu_inference/models/jax/__init__.py +0 -0
- tpu_inference/models/jax/deepseek_v3.py +868 -0
- tpu_inference/models/jax/gpt_oss.py +492 -0
- tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
- tpu_inference/models/jax/llama3.py +375 -0
- tpu_inference/models/jax/llama4.py +629 -0
- tpu_inference/models/jax/llama_eagle3.py +333 -0
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +375 -0
- tpu_inference/models/jax/qwen2_5_vl.py +1103 -0
- tpu_inference/models/jax/qwen3.py +302 -0
- tpu_inference/models/jax/utils/__init__.py +0 -0
- tpu_inference/models/jax/utils/file_utils.py +96 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +653 -0
- tpu_inference/models/jax/utils/weight_utils.py +529 -0
- tpu_inference/models/vllm/__init__.py +0 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +286 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
- tpu_inference/platforms/__init__.py +2 -0
- tpu_inference/platforms/tpu_platform.py +269 -0
- tpu_inference/runner/__init__.py +0 -0
- tpu_inference/runner/block_table.py +122 -0
- tpu_inference/runner/compilation_manager.py +780 -0
- tpu_inference/runner/input_batch.py +435 -0
- tpu_inference/runner/kv_cache.py +132 -0
- tpu_inference/runner/kv_cache_manager.py +479 -0
- tpu_inference/runner/lora_utils.py +92 -0
- tpu_inference/runner/multimodal_manager.py +217 -0
- tpu_inference/runner/persistent_batch_manager.py +244 -0
- tpu_inference/runner/speculative_decoding_manager.py +248 -0
- tpu_inference/runner/structured_decoding_manager.py +88 -0
- tpu_inference/runner/tpu_runner.py +1620 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +0 -0
- tpu_inference/spec_decode/jax/__init__.py +0 -0
- tpu_inference/spec_decode/jax/eagle3.py +367 -0
- tpu_inference/tpu_info.py +77 -0
- tpu_inference/utils.py +317 -0
- tpu_inference/worker/__init__.py +0 -0
- tpu_inference/worker/tpu_worker.py +321 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/METADATA +107 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/RECORD +179 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/WHEEL +5 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/top_level.txt +2 -0
@@ -0,0 +1,653 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import functools
+import os
+from typing import TYPE_CHECKING, Callable, List
+
+import jax
+import jax.numpy as jnp
+import qwix
+import qwix.pallas as qpl
+import yaml
+from flax import nnx
+from flax.typing import PRNGKey
+from jax.sharding import Mesh, NamedSharding
+from jax.sharding import PartitionSpec as P
+from qwix._src.core.qarray import QArray
+from qwix._src.providers import ptq
+
+if TYPE_CHECKING:
+    from vllm.config import VllmConfig
+
+from tpu_inference import utils
+from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.logger import init_logger
+from tpu_inference.runner.kv_cache import (DEFAULT_KV_CACHE_DTYPE,
+                                           create_kv_caches)
+from tpu_inference.utils import device_array
+
+logger = init_logger(__name__)
+
+QUANTIZATION_CONFIG_PATH = os.path.join(os.path.dirname(__file__), "configs")
+DEFAULT_NUM_BLOCKS_FOR_JIT_KV_CACHE = 2000
+DEFAULT_NUM_TOKENS_FOR_MODEL_INPUTS = 512
+DEFAULT_MAX_NUM_SEQS_FOR_MODEL_INPUTS = 256
+DEFAULT_MAX_NUM_BLOCKS_PER_REQ = 16
+
+DEFAULT_DEEPSEEK_FP8_CONFIG = {
+    "qwix": {
+        "use_abstract_model":
+        True,
+        "scale_dtype":
+        "bfloat16",
+        "rules": [
+            {
+                "module_path": ".*.custom_module.router.*",
+                "weight_qtype": None,
+            },
+            {
+                "module_path": ".*",
+                "weight_qtype": "float8_e4m3fn",
+                "act_qtype": "float8_e4m3fn",
+            },
+        ],
+    }
+}
+
+DEFAULT_LLAMA4_FP8_CONFIG = {
+    "qwix": {
+        "use_abstract_model":
+        True,
+        "scale_dtype":
+        "bfloat16",
+        "rules": [
+            {
+                "module_path": "layers.*.moe_ffw",
+                "op_names": "einsum",
+                "weight_qtype": "float8_e4m3fn",
+                "act_qtype": "float8_e4m3fn",
+            },
+        ],
+    }
+}
+
+# Default Qwix config for GPT-OSS MXFP4 checkpoints.
+# Notes:
+# - We quantize only the MoE expert weights by default (router stays in BF16).
+# - We use Qwix's abstract-model path so weights can be set directly into QArray
+#   fields during weight loading (similar to DeepSeek's flow).
+# - Activation quantization is not set, but Qwix would pick up the MoE sum if activated.
+DEFAULT_GPT_OSS_FP4_CONFIG = {
+    "qwix": {
+        "use_abstract_model":
+        True,
+        "scale_dtype":
+        "bfloat16",
+        "rules": [
+            {
+                "module_path": ".*custom_module",
+                "weight_qtype": "float4_e2m1fn",
+                "act_qtype": None,
+                "tile_size": 32,
+            },
+        ],
+    }
+}
+
+
+def parse_qwix_config_to_rules(
+        qwix_config: List[dict]) -> List[qwix.QuantizationRule]:
+    """
+    Parse a list of dictionaries containing Qwix quantization rules into a list of QuantizationRule objects.
+
+    Args:
+        qwix_config: a list of dictionaries containing the Qwix quantization rules
+
+    Returns:
+        a list of QuantizationRule objects
+    """
+    rules = []
+    for rule in qwix_config:
+        rules.append(qwix.QuantizationRule(**rule))
+
+    return rules
+
+
+def qwix_quantize_nnx_model(model: nnx.Module, qwix_config: List[dict],
+                            rng: jax.Array, mesh: Mesh, num_hidden_layers: int,
+                            kv_cache_block_size: int,
+                            kv_cache_num_kv_heads: int,
+                            kv_cache_head_size: int,
+                            kv_cache_dtype: str) -> nnx.Module:
+    """
+    Quantizes a Flax NNX model using Qwix.
+
+    Args:
+        model: the model to quantize
+        qwix_config: a list of dictionaries, where each dictionary corresponds to a Qwix quantization rule
+            For example:
+            [
+                {
+                    "module_path": ".*attn.*",
+                    "weight_qtype": "int8",
+                },
+                {
+                    "module_path": ".*mlp.*",
+                    "weight_qtype": "int8",
+                    "act_qtype": "int8",
+                    "tile_size": None,
+                },
+            ]
+        rng: the random number generator to use
+        mesh: the mesh to use
+        num_hidden_layers: the number of hidden layers in the model
+        kv_cache_block_size: the block size of the kv cache
+        kv_cache_num_kv_heads: the number of kv heads
+        kv_cache_head_size: the head size of the kv cache
+        kv_cache_dtype: the dtype of the kv cache
+
+    Returns:
+        model: the quantized model
+    """
+    qwix_rules = parse_qwix_config_to_rules(qwix_config)
+    logger.info(f"Qwix rules: {qwix_rules}")
+    logger.info(f"Memory usage before applying quantization of params: "
+                f"hbm={utils.hbm_usage_gb(jax.local_devices())}Gb")
+
+    # TODO (jacobplatin): we should refactor this to pass a dtype (or config) directly
+    kv_cache_jnp_dtype = utils.get_jax_dtype_from_str_dtype(kv_cache_dtype)
+
+    # Handle the case where kv_cache_dtype is "auto"
+    if kv_cache_jnp_dtype is None:
+        assert kv_cache_dtype == "auto", "kv_cache_dtype must be 'auto' if kv_cache_jnp_dtype is None"
+        kv_cache_jnp_dtype = DEFAULT_KV_CACHE_DTYPE
+
+    kv_caches = create_kv_caches(
+        num_blocks=DEFAULT_NUM_BLOCKS_FOR_JIT_KV_CACHE,
+        block_size=kv_cache_block_size,
+        num_kv_heads=kv_cache_num_kv_heads,
+        head_size=kv_cache_head_size,
+        mesh=mesh,
+        layer_names=[f"layer.{i}" for i in range(num_hidden_layers)],
+        cache_dtype=kv_cache_jnp_dtype)
+
+    dp_size = mesh.shape.get("data", 1) * mesh.shape.get("attn", 1)
+
+    # NOTE: the inputs don't need to match the actual ones, as long as the consumed weights are the same
+    input_ids = jax.random.randint(rng,
+                                   (DEFAULT_NUM_TOKENS_FOR_MODEL_INPUTS, ),
+                                   0,
+                                   100,
+                                   dtype=jnp.int32)
+    positions = jax.random.randint(rng,
+                                   (DEFAULT_NUM_TOKENS_FOR_MODEL_INPUTS, ),
+                                   0,
+                                   100,
+                                   dtype=jnp.int32)
+    block_tables = jax.random.randint(rng,
+                                      (DEFAULT_MAX_NUM_SEQS_FOR_MODEL_INPUTS *
+                                       DEFAULT_MAX_NUM_BLOCKS_PER_REQ, ),
+                                      0,
+                                      100,
+                                      dtype=jnp.int32)
+    query_start_loc = jax.random.randint(
+        rng, (DEFAULT_MAX_NUM_SEQS_FOR_MODEL_INPUTS + dp_size, ),
+        0,
+        100,
+        dtype=jnp.int32)
+    seq_lens = jax.random.randint(rng,
+                                  (DEFAULT_MAX_NUM_SEQS_FOR_MODEL_INPUTS, ),
+                                  0,
+                                  100,
+                                  dtype=jnp.int32)
+    num_seqs = jax.random.randint(rng, (1, ), 0, 100, dtype=jnp.int32)
+    request_distribution = jnp.array([0, 0, num_seqs[0]] * dp_size,
+                                     dtype=jnp.int32)
+
+    (input_ids, positions, block_tables,
+     query_start_loc, seq_lens, request_distribution) = device_array(
+         mesh, (input_ids, positions, block_tables, query_start_loc, seq_lens,
+                request_distribution))
+
+    model_input = {
+        "kv_caches":
+        kv_caches,
+        "input_ids":
+        input_ids,
+        "attention_metadata":
+        AttentionMetadata(input_positions=positions,
+                          block_tables=block_tables,
+                          seq_lens=seq_lens,
+                          query_start_loc=query_start_loc,
+                          request_distribution=request_distribution),
+    }
+    model = qwix.quantize_model(model, qwix.PtqProvider(qwix_rules),
+                                **model_input)
+    return model
+
+
+def quantization_config_file_path_to_dict(
+        quantization_config_file_path: str) -> dict:
+    """
+    Converts a quantization config YAML file path to a dictionary.
+
+    The expected format of the quantization config YAML file is as follows:
+    ```yaml
+    qwix:
+        # optional, defaults to False if not specified
+        use_abstract_model: True
+        rules:
+            # NOTE: each entry corresponds to a qwix.QuantizationRule
+            - module_path: '.*attn.*'
+              weight_qtype: 'int8'
+            - module_path: '.*'
+              weight_qtype: 'int8'
+              act_qtype: 'int8'
+    ```
+
+    Args:
+        quantization_config_file_path: the path to the quantization config YAML file
+
+    Returns:
+        a dictionary containing the quantization config
+    """
+    all_entries = os.listdir(QUANTIZATION_CONFIG_PATH)
+    for filename in all_entries:
+        if filename == quantization_config_file_path:
+            path = os.path.join(QUANTIZATION_CONFIG_PATH, filename)
+            with open(path, "r") as f:
+                return yaml.safe_load(f)
+    raise ValueError(
+        f"Could not find quantization config file with name '{quantization_config_file_path}' in 'tpu_inference/models/jax/utils/quantization/configs'."
+    )
+
+
+def apply_qwix_quantization(
+        vllm_config: "VllmConfig", model_or_model_fn: Callable | nnx.Module,
+        rng: jax.Array, mesh: Mesh,
+        apply_to_abstract_model: bool) -> nnx.Module | Callable:
+    """
+    Will apply quantization if a valid quantization config with Qwix rules is provided. See README
+    for more details on Qwix.
+
+    Note that we currently support different methods for applying Qwix quantization. The typical
+    approach is to apply quantization on the concrete model, which already has the weights
+    loaded in. However, for models like DeepSeek, which are already quantized, we need to
+    first create the abstract model, then apply Qwix quantization to the abstract model, and
+    finally load the weights in. To use the latter approach, you will need to modify the
+    model weight loading code appropriately (see deepseek_v3.py for an example) and
+    pass `use_abstract_model=True` in the quantization config.
+
+    Args:
+        vllm_config: the base VLLM config
+        model_or_model_fn: if `apply_to_abstract_model` is True, this will be a Callable that returns the abstract model
+            (e.g. _create_abstract_model). Otherwise, this will be the concrete model (nnx.Module).
+        rng: JAX RNG
+        mesh: model Mesh
+        apply_to_abstract_model: if True, we will apply Qwix quantization to the abstract model, which
+            assumes that, during weight loading, the caller will then override the QArray weights
+            (see deepseek_v3.py for an example). Otherwise, we will apply Qwix quantization to the
+            concrete model, which already has the weights loaded in.
+
+    Returns:
+        Either the concrete model (nnx.Module) or the abstract model (Callable) (if `apply_to_abstract_model` is True)
+    """
+    qwix_config = None
+    if quantization_config := vllm_config.additional_config.get(
+            "quantization"):
+        qwix_config = quantization_config.get("qwix").get("rules")
+    if not qwix_config:
+        return model_or_model_fn
+
+    logging_abstract_model_str = "abstract" if apply_to_abstract_model else "concrete"
+    logger.info(
+        f"Applying Qwix quantization on {logging_abstract_model_str} model")
+
+    block_size = vllm_config.cache_config.block_size
+    model_config = vllm_config.model_config
+
+    # Pad num_kv_heads to multiple of TP size
+    num_kv_heads = utils.get_padded_num_heads(
+        model_config.get_total_num_kv_heads(), mesh.shape["model"])
+
+    # Pad head_dim to multiple of 128
+    head_size = model_config.get_head_size()
+    head_size = utils.get_padded_head_dim(head_size)
+
+    kv_cache_dtype = vllm_config.cache_config.cache_dtype
+
+    if not apply_to_abstract_model:
+        assert isinstance(model_or_model_fn, nnx.Module)
+        qwix_quantize_nnx_model_with_config = functools.partial(
+            qwix_quantize_nnx_model, qwix_config=qwix_config)
+        # NOTE: it's REALLY important `qwix_quantize_nnx_model_with_config` is jitted
+        # or else you'll run into hanging
+        model_or_model_fn = nnx.jit(
+            qwix_quantize_nnx_model_with_config,
+            donate_argnums=(0, ),
+            static_argnames=(
+                "mesh",
+                "num_hidden_layers",
+                "kv_cache_block_size",
+                "kv_cache_num_kv_heads",
+                "kv_cache_head_size",
+                "kv_cache_dtype",
+            ))(model=model_or_model_fn,
+               rng=rng,
+               mesh=mesh,
+               num_hidden_layers=vllm_config.model_config.hf_config.
+               num_hidden_layers,
+               kv_cache_block_size=block_size,
+               kv_cache_num_kv_heads=num_kv_heads,
+               kv_cache_head_size=head_size,
+               kv_cache_dtype=kv_cache_dtype)
+
+        return model_or_model_fn
+
+    hf_config = vllm_config.model_config.hf_config
+    if hasattr(hf_config, "text_config") and hasattr(hf_config.text_config,
+                                                     "num_hidden_layers"):
+        num_hidden_layers = hf_config.text_config.num_hidden_layers
+        logger.info(
+            f"Using num_hidden_layers from hf_config.text_config: {num_hidden_layers}"
+        )
+    elif hasattr(hf_config, "num_hidden_layers"):
+        num_hidden_layers = hf_config.num_hidden_layers
+        logger.info(
+            f"Using num_hidden_layers directly from hf_config: {num_hidden_layers}"
+        )
+    else:
+        raise AttributeError(
+            "Could not find 'num_hidden_layers' in hf_config or hf_config.text_config."
+        )
+
+    qwix_quantize_fn_for_eval = functools.partial(
+        qwix_quantize_nnx_model,
+        qwix_config=qwix_config,
+        mesh=mesh,
+        num_hidden_layers=num_hidden_layers,
+        kv_cache_block_size=block_size,
+        kv_cache_num_kv_heads=num_kv_heads,
+        kv_cache_head_size=head_size,
+        kv_cache_dtype=kv_cache_dtype)
+
+    def create_and_quantize_model_factory() -> Callable:
+        """
+        Helper function to create and quantize the abstract model.
+        """
+        model = model_or_model_fn()
+        # Handle the DeepSeek case, where this needs to be called in the abstract model
+        if hasattr(model, 'initialize_cache'):
+            model.initialize_cache()
+        return qwix_quantize_fn_for_eval(model=model, rng=rng)
+
+    return create_and_quantize_model_factory
+
+
+def apply_qwix_on_abstract_model(vllm_config: "VllmConfig") -> bool:
+    """
+    Determines whether to apply Qwix quantization on the abstract model (e.g. for DeepSeek)
+    or the concrete model. See `apply_qwix_quantization` for more details on the differences
+    between these two approaches.
+    Args:
+        vllm_config: the vllm config
+    Returns:
+        whether to apply Qwix quantization on the abstract model
+    """
+    quantization_config = vllm_config.additional_config.get("quantization", {})
+    return quantization_config.get("qwix", {}).get("use_abstract_model", False)
+
+
+def get_default_qwix_quantization_config(
+        model_type: str, quant_method: str,
+        skip_quantization: bool) -> dict | None:
+    """
+    Some models are pre-quantized and in those cases, we want to return a default set of
+    Qwix quantization rules (instead of forcing the user to pass in a quantization config each time).
+
+    Note that if a user passes in a quantization config (via `additional_config`), then
+    we'll use that instead of this function.
+
+    Args:
+        model_type: the name of the model
+        quant_method: the quantization method
+        skip_quantization: whether to skip quantization. In this case, we'll return None
+
+    Returns:
+        a dictionary containing the default Qwix quantization rules
+    """
+    if skip_quantization:
+        return None
+    # TODO (jacobplatin): remove this so that we can support various quantization types
+    if model_type == "deepseek_v3" and quant_method == "fp8":
+        return DEFAULT_DEEPSEEK_FP8_CONFIG
+    elif model_type == "llama4" and quant_method == "compressed-tensors":
+        return DEFAULT_LLAMA4_FP8_CONFIG
+    # MXFP4 (GPT-OSS): provide a default configuration to quantize MoE experts via Qwix
+    elif model_type == "gpt_oss" and quant_method == "mxfp4":
+        return DEFAULT_GPT_OSS_FP4_CONFIG
+
+
+def update_vllm_config_for_qwix_quantization(vllm_config: "VllmConfig"):
+    """
+    Updates the vLLM config to unpack the Qwix quantization config if it exists.
+    By default, we'll check if the checkpoint is quantized and update the
+    Qwix quantization config to use the default quantization config if it exists,
+    but we'll override this if the user passes in a quantization config via `additional_config`.
+    """
+    # Automatically detect whether checkpoint is quantized and update the
+    # Qwix quantization config accordingly
+    # NOTE: if a Qwix config is provided (via the `additional_config`), we'll
+    # use that instead
+    model_type = vllm_config.model_config.hf_config.model_type.lower(
+    ) if hasattr(vllm_config.model_config.hf_config, "model_type") else None
+    quant_method = vllm_config.model_config.hf_config.quantization_config[
+        "quant_method"] if hasattr(vllm_config.model_config.hf_config,
+                                   "quantization_config") else None
+    default_quantization_config = get_default_qwix_quantization_config(
+        model_type, quant_method,
+        vllm_config.additional_config.get("skip_quantization", False))
+
+    maybe_existing_quantization_config = vllm_config.additional_config.get(
+        "quantization")
+    if maybe_existing_quantization_config:
+        logger.warning("Overwriting default Qwix quantization config with "
+                       "user provided quantization config.")
+    elif default_quantization_config is not None:
+        vllm_config.additional_config[
+            "quantization"] = default_quantization_config
+
+    # Validate additional config
+    if additional_config := vllm_config.additional_config:
+        # Try loading/parsing the quantization config so that we can fail fast
+        if quantization_config := additional_config.get("quantization"):
+            try:
+                # NOTE: Qwix quantization supports two paths:
+                # 1. quantization config file (which we need to parse to a dictionary)
+                # 2. quantization config JSON
+                if isinstance(quantization_config, str):
+                    quantization_config = quantization_config_file_path_to_dict(
+                        quantization_config)
+                # NOTE: unpack the quantization config now so we don't need to keep doing this every time
+                vllm_config.additional_config[
+                    "quantization"] = quantization_config
+                parse_qwix_config_to_rules(
+                    quantization_config["qwix"]["rules"])
+            except Exception as e:
+                raise ValueError(
+                    f"Invalid quantization config; please see README for details on quantization config: {e}"
+                )
+
+
+def get_random_sharded_array(key: PRNGKey, mesh: Mesh, param: nnx.Param,
+                             param_shape: tuple, dtype: jnp.dtype,
+                             param_name: str) -> jax.Array:
+    """
+    Returns a random sharded array for the given parameter for the given shape.
+
+    Args:
+        key: The random key.
+        mesh: The mesh to use for sharding.
+        param: The parameter.
+        param_shape: The shape of the parameter.
+        dtype: The dtype of the parameter.
+        param_name: The name of the parameter.
+
+    Returns:
+        A random sharded array for the given parameter for the given shape.
+    """
+    is_int = jnp.issubdtype(dtype, jnp.integer)
+    if is_int:
+        # These need to be JAX arrays or else you'll run into an overflow error
+        minval = jnp.array(jnp.iinfo(dtype).min, dtype=dtype)
+        maxval = jnp.array(jnp.iinfo(dtype).max, dtype=dtype)
+        weight = jax.random.randint(key, param_shape, minval, maxval, dtype)
+    else:
+        weight = jax.random.normal(key, param_shape, dtype)
+
+    def get_slice(index):
+        return weight[index]
+
+    try:
+        sharded_array = jax.make_array_from_callback(
+            param_shape, NamedSharding(mesh, P(*param.sharding)), get_slice)
+    except (ValueError, TypeError):
+        logger.warning(
+            f"Could not create sharded scale for {param_name} with shape {param_shape} and sharding {param.sharding}, skipping sharding..."
+        )
+        sharded_array = jax.make_array_from_callback(param_shape,
+                                                     NamedSharding(mesh, P()),
+                                                     get_slice)
+
+    return sharded_array
+
+
+def load_random_weights_into_qwix_abstract_model(rng: PRNGKey,
+                                                 model: nnx.Module, mesh: Mesh,
+                                                 quantization_config: dict):
+    """
+    Loads random weights for an abstract, Qwix-quantized model.
+
+    Args:
+        rng: The random key.
+        model: The (abstract, Qwix-quantized) model whose weights will be
+            randomly initialized.
+        mesh: The mesh.
+        quantization_config: The quantization config for the model.
+    """
+    logger.info("Initializing Qwix-quantized model with random weights...")
+    # TODO (jacobplatin): clean up this logic
+    scale_dtype = model.weight_loader.scale_dtype
+    scale_shape_map = model.weight_loader.scale_shap_map_for_random_weight_loading if hasattr(
+        model.weight_loader,
+        'scale_shap_map_for_random_weight_loading') else {}
+    quantization_block_sizes = quantization_config["weight_block_size"]
+    assert len(
+        quantization_block_sizes
+    ) == 2, f"Expected only 2 quantization block sizes but got {quantization_block_sizes}"
+    quantization_block_size_n, _ = quantization_block_sizes[
+        0], quantization_block_sizes[1]
+
+    # Iterate through all variables and initialize them
+    prev_param_shape = None
+    for path, param in nnx.iter_graph(model):
+        if not isinstance(param, nnx.Variable):
+            continue
+        if path[0] == 'rng' and path[-1] == "key":
+            param.value = rng
+            continue
+        is_qwix_scale = (path[-1] == 'scale' and path[-2] == "array")
+        param_dtype = scale_dtype if is_qwix_scale else param.value.dtype
+        param_shape = param.value.shape
+        # TODO (jacobplatin): clean this up
+        if is_qwix_scale:
+            param_shape = scale_shape_map.get(
+                path[3],
+                tuple(dim // quantization_block_size_n
+                      for dim in prev_param_shape))
+        param.value = get_random_sharded_array(
+            rng, mesh, param, param_shape, param_dtype,
+            ".".join([str(x) for x in path]))
+        prev_param_shape = param_shape
+
+    # Handles the DeepSeek case, where this needs to be called to make the cache weights
+    # concrete
+    if hasattr(model, 'initialize_cache'):
+        model.initialize_cache()
+    logger.info("Done initializing Qwix-quantized model with random weights")
+
+
+def manually_quantize_qwix_weight(weight: jax.Array, qtype: jnp.dtype,
+                                  channelwise_axes: List[int],
+                                  tiled_axes: dict,
+                                  calibration_method: str) -> QArray:
+    """
+    Manually quantizes a weight tensor using Qwix. Only needed for the SparseMatmul DeepSeek case right now, since
+    otherwise, Qwix will handle this automatically (through our application of `qwix.quantize_model`).
+    """
+    # TODO (jacobplatin): clean this up; this is needed because of issues with Qwix quantizing the `shard_map` in SparseMatmul
+    how_to_quantize = ptq.qarray.HowToQuantize(
+        qtype=qtype,
+        channelwise_axes=channelwise_axes,
+        tiled_axes=tiled_axes,
+        calibration_method=calibration_method)
+
+    return ptq.create_quantized_param(weight, how_to_quantize)
+
+
+def manually_quantize_qwix_activation(inputs: jax.Array, rule_name: str,
+                                      qtype: jnp.dtype,
+                                      channelwise_axes: List[int],
+                                      tiled_axes: dict,
+                                      calibration_method: str) -> QArray:
+    """
+    Manually quantizes an activation tensor using Qwix. Needed for the SparseMatmul
+    DeepSeek MoE case currently.
+
+    Args:
+        inputs: The activation tensor to quantize.
+        rule_name: The name of the quantization rule to use.
+        qtype: The quantization type.
+        channelwise_axes: The channelwise axes to quantize.
+        tiled_axes: The tiled axes to quantize.
+        calibration_method: The calibration method to use.
+
+    Returns:
+        The quantized activation tensor.
+    """
+    rule = qpl.get_current_rule(rule_name)
+    lhs_how = ptq.qarray.HowToQuantize(qtype=qtype,
+                                       channelwise_axes=channelwise_axes,
+                                       tiled_axes=tiled_axes,
+                                       calibration_method=calibration_method)
+    # This is needed because we aren't passing `act_name` right now
+    assert not rule.act_static_scale, "Static scale not supported right now"
+
+    # channelwise_axes should be set to (a subset of) non-contraction axes. e.g.
+    # for ragged_dot [m, k] x [g, k, n], they are [0] and [0, 2]
+    # TODO (jacobplatin): add support for `act_name`
+    return ptq.quantize_act(inputs, lhs_how, rule, "")
+
+
+def get_quant_dtype_from_qwix_config(
+        vllm_config: "VllmConfig") -> tuple[jnp.dtype, jnp.dtype]:
+    """
+    Gets the quantization dtype from the Qwix config.
+
+    Args:
+        vllm_config: The VllmConfig object.
+
+    Returns:
+        A tuple of the scale dtype and quant dtype.
+    """
+    qwix_config = vllm_config.additional_config.get("quantization",
+                                                    {}).get("qwix", {})
+    scale_dtype = getattr(jnp, qwix_config.get("scale_dtype", "bfloat16"))
+    quant_dtype = None
+    # TODO (jacobplatin): this needs to be much more robust
+    for rule in qwix_config.get("rules", []):
+        if rule.get("module_path") == ".*":
+            quant_dtype_str = rule.get("weight_qtype", "")
+            assert quant_dtype_str, "Quantization dtype not found in Qwix config! We currently expect your Qwix config to have a rule with module_path '.*' and a weight_qtype."
+            quant_dtype = getattr(jnp, quant_dtype_str)
+    return scale_dtype, quant_dtype
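For reference, the following is a minimal, illustrative sketch (not part of the package) of the configuration shape this module consumes: a `quantization` entry under vLLM's `additional_config` whose `qwix.rules` list is turned into `qwix.QuantizationRule` objects by `parse_qwix_config_to_rules` and passed to `qwix.PtqProvider`. The rule values below are hypothetical examples, not defaults shipped in this wheel.

```python
# Illustrative only; assumes the qwix package and the helpers defined in
# tpu_inference/models/jax/utils/quantization/quantization_utils.py above.
import qwix

# Same shape as the DEFAULT_*_CONFIG dictionaries and the YAML files under
# tpu_inference/models/jax/utils/quantization/configs (values are examples).
quantization_config = {
    "qwix": {
        "use_abstract_model": False,
        "scale_dtype": "bfloat16",
        "rules": [
            {"module_path": ".*attn.*", "weight_qtype": "int8"},
            {"module_path": ".*", "weight_qtype": "int8", "act_qtype": "int8"},
        ],
    }
}

# parse_qwix_config_to_rules builds one qwix.QuantizationRule per entry;
# apply_qwix_quantization reads the same structure from
# vllm_config.additional_config["quantization"]["qwix"]["rules"].
rules = [
    qwix.QuantizationRule(**rule)
    for rule in quantization_config["qwix"]["rules"]
]
provider = qwix.PtqProvider(rules)  # later handed to qwix.quantize_model
```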