tpu-inference 0.11.1.dev202511150811__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference has been flagged as potentially problematic.

Files changed (179)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +53 -0
  6. tests/core/test_dp_scheduler.py +899 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/fused_moe_v1_test.py +105 -0
  10. tests/kernels/mla_v1_test.py +396 -0
  11. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  12. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  13. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  14. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
  15. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  16. tests/lora/__init__.py +0 -0
  17. tests/lora/conftest.py +32 -0
  18. tests/lora/test_bgmv.py +43 -0
  19. tests/lora/test_layers.py +654 -0
  20. tests/lora/test_lora.py +133 -0
  21. tests/lora/utils.py +96 -0
  22. tests/test_base.py +201 -0
  23. tests/test_envs.py +182 -0
  24. tests/test_quantization.py +836 -0
  25. tests/test_tpu_info.py +120 -0
  26. tests/test_utils.py +236 -0
  27. tpu_inference/__init__.py +34 -0
  28. tpu_inference/core/__init__.py +0 -0
  29. tpu_inference/core/core_tpu.py +786 -0
  30. tpu_inference/core/disagg_executor.py +118 -0
  31. tpu_inference/core/disagg_utils.py +51 -0
  32. tpu_inference/core/sched/__init__.py +0 -0
  33. tpu_inference/core/sched/dp_scheduler.py +523 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/jax_parallel_state.py +67 -0
  36. tpu_inference/distributed/tpu_connector.py +728 -0
  37. tpu_inference/distributed/utils.py +59 -0
  38. tpu_inference/env_override.py +9 -0
  39. tpu_inference/envs.py +107 -0
  40. tpu_inference/executors/__init__.py +0 -0
  41. tpu_inference/executors/ray_distributed_executor.py +362 -0
  42. tpu_inference/experimental/__init__.py +0 -0
  43. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  44. tpu_inference/kernels/__init__.py +0 -0
  45. tpu_inference/kernels/collectives/__init__.py +0 -0
  46. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  47. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  48. tpu_inference/kernels/collectives/util.py +47 -0
  49. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  50. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  51. tpu_inference/kernels/fused_moe/__init__.py +0 -0
  52. tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
  53. tpu_inference/kernels/fused_moe/v1/kernel.py +1035 -0
  54. tpu_inference/kernels/mla/__init__.py +0 -0
  55. tpu_inference/kernels/mla/v1/__init__.py +0 -0
  56. tpu_inference/kernels/mla/v1/kernel.py +1349 -0
  57. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  58. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  59. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  60. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  61. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  62. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  66. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1478 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1482 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
  71. tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
  72. tpu_inference/layers/__init__.py +0 -0
  73. tpu_inference/layers/common/__init__.py +0 -0
  74. tpu_inference/layers/common/attention_interface.py +390 -0
  75. tpu_inference/layers/common/attention_metadata.py +34 -0
  76. tpu_inference/layers/common/binary_search.py +295 -0
  77. tpu_inference/layers/common/quant_methods.py +8 -0
  78. tpu_inference/layers/common/sharding.py +582 -0
  79. tpu_inference/layers/jax/__init__.py +0 -0
  80. tpu_inference/layers/jax/attention/__init__.py +0 -0
  81. tpu_inference/layers/jax/attention/attention.py +255 -0
  82. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  83. tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
  84. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  85. tpu_inference/layers/jax/base.py +151 -0
  86. tpu_inference/layers/jax/constants.py +88 -0
  87. tpu_inference/layers/jax/layers.py +301 -0
  88. tpu_inference/layers/jax/misc.py +16 -0
  89. tpu_inference/layers/jax/moe/__init__.py +0 -0
  90. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  91. tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
  92. tpu_inference/layers/jax/moe/moe.py +209 -0
  93. tpu_inference/layers/jax/rope.py +280 -0
  94. tpu_inference/layers/jax/rope_interface.py +214 -0
  95. tpu_inference/layers/jax/sample/__init__.py +0 -0
  96. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  97. tpu_inference/layers/jax/sample/sampling.py +96 -0
  98. tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
  99. tpu_inference/layers/jax/transformer_block.py +107 -0
  100. tpu_inference/layers/vllm/__init__.py +0 -0
  101. tpu_inference/layers/vllm/attention.py +221 -0
  102. tpu_inference/layers/vllm/fused_moe.py +507 -0
  103. tpu_inference/layers/vllm/linear_common.py +186 -0
  104. tpu_inference/layers/vllm/quantization/__init__.py +39 -0
  105. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  106. tpu_inference/layers/vllm/quantization/common.py +105 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  108. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
  109. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
  110. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  111. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  112. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  113. tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
  114. tpu_inference/layers/vllm/quantization/unquantized.py +386 -0
  115. tpu_inference/layers/vllm/sharding.py +230 -0
  116. tpu_inference/logger.py +10 -0
  117. tpu_inference/lora/__init__.py +0 -0
  118. tpu_inference/lora/torch_lora_ops.py +103 -0
  119. tpu_inference/lora/torch_punica_tpu.py +311 -0
  120. tpu_inference/mock/__init__.py +0 -0
  121. tpu_inference/mock/vllm_config_utils.py +28 -0
  122. tpu_inference/mock/vllm_envs.py +1219 -0
  123. tpu_inference/mock/vllm_logger.py +212 -0
  124. tpu_inference/mock/vllm_logging_utils.py +15 -0
  125. tpu_inference/models/__init__.py +0 -0
  126. tpu_inference/models/common/__init__.py +0 -0
  127. tpu_inference/models/common/model_loader.py +444 -0
  128. tpu_inference/models/jax/__init__.py +0 -0
  129. tpu_inference/models/jax/deepseek_v3.py +868 -0
  130. tpu_inference/models/jax/gpt_oss.py +492 -0
  131. tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
  132. tpu_inference/models/jax/llama3.py +375 -0
  133. tpu_inference/models/jax/llama4.py +629 -0
  134. tpu_inference/models/jax/llama_eagle3.py +333 -0
  135. tpu_inference/models/jax/phi3.py +376 -0
  136. tpu_inference/models/jax/qwen2.py +375 -0
  137. tpu_inference/models/jax/qwen2_5_vl.py +1103 -0
  138. tpu_inference/models/jax/qwen3.py +302 -0
  139. tpu_inference/models/jax/utils/__init__.py +0 -0
  140. tpu_inference/models/jax/utils/file_utils.py +96 -0
  141. tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
  142. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  143. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
  144. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
  145. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
  146. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
  147. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
  148. tpu_inference/models/jax/utils/quantization/quantization_utils.py +653 -0
  149. tpu_inference/models/jax/utils/weight_utils.py +529 -0
  150. tpu_inference/models/vllm/__init__.py +0 -0
  151. tpu_inference/models/vllm/vllm_model_wrapper.py +286 -0
  152. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  153. tpu_inference/platforms/__init__.py +2 -0
  154. tpu_inference/platforms/tpu_platform.py +269 -0
  155. tpu_inference/runner/__init__.py +0 -0
  156. tpu_inference/runner/block_table.py +122 -0
  157. tpu_inference/runner/compilation_manager.py +780 -0
  158. tpu_inference/runner/input_batch.py +435 -0
  159. tpu_inference/runner/kv_cache.py +132 -0
  160. tpu_inference/runner/kv_cache_manager.py +479 -0
  161. tpu_inference/runner/lora_utils.py +92 -0
  162. tpu_inference/runner/multimodal_manager.py +217 -0
  163. tpu_inference/runner/persistent_batch_manager.py +244 -0
  164. tpu_inference/runner/speculative_decoding_manager.py +248 -0
  165. tpu_inference/runner/structured_decoding_manager.py +88 -0
  166. tpu_inference/runner/tpu_runner.py +1620 -0
  167. tpu_inference/runner/utils.py +426 -0
  168. tpu_inference/spec_decode/__init__.py +0 -0
  169. tpu_inference/spec_decode/jax/__init__.py +0 -0
  170. tpu_inference/spec_decode/jax/eagle3.py +367 -0
  171. tpu_inference/tpu_info.py +77 -0
  172. tpu_inference/utils.py +317 -0
  173. tpu_inference/worker/__init__.py +0 -0
  174. tpu_inference/worker/tpu_worker.py +321 -0
  175. tpu_inference-0.11.1.dev202511150811.dist-info/METADATA +107 -0
  176. tpu_inference-0.11.1.dev202511150811.dist-info/RECORD +179 -0
  177. tpu_inference-0.11.1.dev202511150811.dist-info/WHEEL +5 -0
  178. tpu_inference-0.11.1.dev202511150811.dist-info/licenses/LICENSE +201 -0
  179. tpu_inference-0.11.1.dev202511150811.dist-info/top_level.txt +2 -0
tpu_inference/models/jax/utils/quantization/quantization_utils.py
@@ -0,0 +1,653 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import functools
+import os
+from typing import TYPE_CHECKING, Callable, List
+
+import jax
+import jax.numpy as jnp
+import qwix
+import qwix.pallas as qpl
+import yaml
+from flax import nnx
+from flax.typing import PRNGKey
+from jax.sharding import Mesh, NamedSharding
+from jax.sharding import PartitionSpec as P
+from qwix._src.core.qarray import QArray
+from qwix._src.providers import ptq
+
+if TYPE_CHECKING:
+    from vllm.config import VllmConfig
+
+from tpu_inference import utils
+from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.logger import init_logger
+from tpu_inference.runner.kv_cache import (DEFAULT_KV_CACHE_DTYPE,
+                                           create_kv_caches)
+from tpu_inference.utils import device_array
+
+logger = init_logger(__name__)
+
+QUANTIZATION_CONFIG_PATH = os.path.join(os.path.dirname(__file__), "configs")
+DEFAULT_NUM_BLOCKS_FOR_JIT_KV_CACHE = 2000
+DEFAULT_NUM_TOKENS_FOR_MODEL_INPUTS = 512
+DEFAULT_MAX_NUM_SEQS_FOR_MODEL_INPUTS = 256
+DEFAULT_MAX_NUM_BLOCKS_PER_REQ = 16
+
+DEFAULT_DEEPSEEK_FP8_CONFIG = {
+    "qwix": {
+        "use_abstract_model":
+        True,
+        "scale_dtype":
+        "bfloat16",
+        "rules": [
+            {
+                "module_path": ".*.custom_module.router.*",
+                "weight_qtype": None,
+            },
+            {
+                "module_path": ".*",
+                "weight_qtype": "float8_e4m3fn",
+                "act_qtype": "float8_e4m3fn",
+            },
+        ],
+    }
+}
+
+DEFAULT_LLAMA4_FP8_CONFIG = {
+    "qwix": {
+        "use_abstract_model":
+        True,
+        "scale_dtype":
+        "bfloat16",
+        "rules": [
+            {
+                "module_path": "layers.*.moe_ffw",
+                "op_names": "einsum",
+                "weight_qtype": "float8_e4m3fn",
+                "act_qtype": "float8_e4m3fn",
+            },
+        ],
+    }
+}
+
+# Default Qwix config for GPT-OSS MXFP4 checkpoints.
+# Notes:
+# - We quantize only the MoE expert weights by default (the router stays in BF16).
+# - We use Qwix's abstract-model path so weights can be set directly into QArray
+#   fields during weight loading (similar to DeepSeek's flow).
+# - Activation quantization is not set, but Qwix would pick up the MoE sum if it were enabled.
+DEFAULT_GPT_OSS_FP4_CONFIG = {
+    "qwix": {
+        "use_abstract_model":
+        True,
+        "scale_dtype":
+        "bfloat16",
+        "rules": [
+            {
+                "module_path": ".*custom_module",
+                "weight_qtype": "float4_e2m1fn",
+                "act_qtype": None,
+                "tile_size": 32,
+            },
+        ],
+    }
+}
+
+
+def parse_qwix_config_to_rules(
+        qwix_config: List[dict]) -> List[qwix.QuantizationRule]:
+    """
+    Parse a list of dictionaries containing Qwix quantization rules into a list of QuantizationRule objects.
+
+    Args:
+        qwix_config: a list of dictionaries, each describing one Qwix quantization rule
+
+    Returns:
+        a list of QuantizationRule objects
+    """
+    rules = []
+    for rule in qwix_config:
+        rules.append(qwix.QuantizationRule(**rule))
+
+    return rules
+
+
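(Editorial aside, not part of the packaged file: a minimal usage sketch of `parse_qwix_config_to_rules`, assuming the module is importable under the wheel path `tpu_inference.models.jax.utils.quantization.quantization_utils` shown in the file list above.)

```python
# Sketch only: turn the default DeepSeek rules into qwix.QuantizationRule objects.
from tpu_inference.models.jax.utils.quantization.quantization_utils import (
    DEFAULT_DEEPSEEK_FP8_CONFIG, parse_qwix_config_to_rules)

rules = parse_qwix_config_to_rules(DEFAULT_DEEPSEEK_FP8_CONFIG["qwix"]["rules"])
# Two rules result: router modules stay unquantized (weight_qtype=None), while
# the catch-all ".*" rule quantizes weights and activations to float8_e4m3fn.
for rule in rules:
    print(rule.module_path, rule.weight_qtype)
```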
+def qwix_quantize_nnx_model(model: nnx.Module, qwix_config: List[dict],
+                            rng: jax.Array, mesh: Mesh, num_hidden_layers: int,
+                            kv_cache_block_size: int,
+                            kv_cache_num_kv_heads: int,
+                            kv_cache_head_size: int,
+                            kv_cache_dtype: str) -> nnx.Module:
+    """
+    Quantizes a Flax NNX model using Qwix.
+
+    Args:
+        model: the model to quantize
+        qwix_config: a list of dictionaries, where each dictionary corresponds to a Qwix quantization rule
+            For example:
+            [
+                {
+                    "module_path": ".*attn.*",
+                    "weight_qtype": "int8",
+                },
+                {
+                    "module_path": ".*mlp.*",
+                    "weight_qtype": "int8",
+                    "act_qtype": "int8",
+                    "tile_size": None,
+                },
+            ]
+        rng: the random number generator to use
+        mesh: the mesh to use
+        num_hidden_layers: the number of hidden layers in the model
+        kv_cache_block_size: the block size of the kv cache
+        kv_cache_num_kv_heads: the number of kv heads
+        kv_cache_head_size: the head size of the kv cache
+        kv_cache_dtype: the dtype of the kv cache
+
+    Returns:
+        model: the quantized model
+    """
+    qwix_rules = parse_qwix_config_to_rules(qwix_config)
+    logger.info(f"Qwix rules: {qwix_rules}")
+    logger.info(f"Memory usage before applying quantization of params: "
+                f"hbm={utils.hbm_usage_gb(jax.local_devices())}Gb")
+
+    # TODO (jacobplatin): we should refactor this to pass a dtype (or config) directly
+    kv_cache_jnp_dtype = utils.get_jax_dtype_from_str_dtype(kv_cache_dtype)
+
+    # Handle the case where kv_cache_dtype is "auto"
+    if kv_cache_jnp_dtype is None:
+        assert kv_cache_dtype == "auto", "kv_cache_dtype must be 'auto' if kv_cache_jnp_dtype is None"
+        kv_cache_jnp_dtype = DEFAULT_KV_CACHE_DTYPE
+
+    kv_caches = create_kv_caches(
+        num_blocks=DEFAULT_NUM_BLOCKS_FOR_JIT_KV_CACHE,
+        block_size=kv_cache_block_size,
+        num_kv_heads=kv_cache_num_kv_heads,
+        head_size=kv_cache_head_size,
+        mesh=mesh,
+        layer_names=[f"layer.{i}" for i in range(num_hidden_layers)],
+        cache_dtype=kv_cache_jnp_dtype)
+
+    dp_size = mesh.shape.get("data", 1) * mesh.shape.get("attn", 1)
+
+    # NOTE: the inputs don't need to match the actual ones, as long as the consumed weights are the same
+    input_ids = jax.random.randint(rng,
+                                   (DEFAULT_NUM_TOKENS_FOR_MODEL_INPUTS, ),
+                                   0,
+                                   100,
+                                   dtype=jnp.int32)
+    positions = jax.random.randint(rng,
+                                   (DEFAULT_NUM_TOKENS_FOR_MODEL_INPUTS, ),
+                                   0,
+                                   100,
+                                   dtype=jnp.int32)
+    block_tables = jax.random.randint(rng,
+                                      (DEFAULT_MAX_NUM_SEQS_FOR_MODEL_INPUTS *
+                                       DEFAULT_MAX_NUM_BLOCKS_PER_REQ, ),
+                                      0,
+                                      100,
+                                      dtype=jnp.int32)
+    query_start_loc = jax.random.randint(
+        rng, (DEFAULT_MAX_NUM_SEQS_FOR_MODEL_INPUTS + dp_size, ),
+        0,
+        100,
+        dtype=jnp.int32)
+    seq_lens = jax.random.randint(rng,
+                                  (DEFAULT_MAX_NUM_SEQS_FOR_MODEL_INPUTS, ),
+                                  0,
+                                  100,
+                                  dtype=jnp.int32)
+    num_seqs = jax.random.randint(rng, (1, ), 0, 100, dtype=jnp.int32)
+    request_distribution = jnp.array([0, 0, num_seqs[0]] * dp_size,
+                                     dtype=jnp.int32)
+
+    (input_ids, positions, block_tables,
+     query_start_loc, seq_lens, request_distribution) = device_array(
+         mesh, (input_ids, positions, block_tables, query_start_loc, seq_lens,
+                request_distribution))
+
+    model_input = {
+        "kv_caches":
+        kv_caches,
+        "input_ids":
+        input_ids,
+        "attention_metadata":
+        AttentionMetadata(input_positions=positions,
+                          block_tables=block_tables,
+                          seq_lens=seq_lens,
+                          query_start_loc=query_start_loc,
+                          request_distribution=request_distribution),
+    }
+    model = qwix.quantize_model(model, qwix.PtqProvider(qwix_rules),
+                                **model_input)
+    return model
+
+
+def quantization_config_file_path_to_dict(
+        quantization_config_file_path: str) -> dict:
+    """
+    Loads a quantization config YAML file (by file name, from the bundled configs directory) into a dictionary.
+
+    The expected format of the quantization config YAML file is as follows:
+    ```yaml
+    qwix:
+      # optional, defaults to False if not specified
+      use_abstract_model: True
+      rules:
+        # NOTE: each entry corresponds to a qwix.QuantizationRule
+        - module_path: '.*attn.*'
+          weight_qtype: 'int8'
+        - module_path: '.*'
+          weight_qtype: 'int8'
+          act_qtype: 'int8'
+    ```
+
+    Args:
+        quantization_config_file_path: the file name of the quantization config YAML file
+
+    Returns:
+        a dictionary containing the quantization config
+    """
+    all_entries = os.listdir(QUANTIZATION_CONFIG_PATH)
+    for filename in all_entries:
+        if filename == quantization_config_file_path:
+            path = os.path.join(QUANTIZATION_CONFIG_PATH, filename)
+            with open(path, "r") as f:
+                return yaml.safe_load(f)
+    raise ValueError(
+        f"Could not find quantization config file with name '{quantization_config_file_path}' in 'tpu_inference/models/jax/utils/quantization/configs'."
+    )
+
+
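(Editorial aside, not part of the packaged file: the wheel bundles config files such as `int8_default.yaml` and `fp8_default.yaml` under the `configs` directory listed above, so callers pass a bare file name. A hedged sketch of the lookup, assuming those files follow the `qwix:`/`rules:` layout documented in the docstring:)

```python
from tpu_inference.models.jax.utils.quantization.quantization_utils import (
    parse_qwix_config_to_rules, quantization_config_file_path_to_dict)

# Resolves against the bundled configs directory, not an arbitrary filesystem path.
cfg = quantization_config_file_path_to_dict("int8_default.yaml")

# Fail fast if the rules do not parse into qwix.QuantizationRule objects.
rules = parse_qwix_config_to_rules(cfg["qwix"]["rules"])
```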
+def apply_qwix_quantization(
+        vllm_config: "VllmConfig", model_or_model_fn: Callable | nnx.Module,
+        rng: jax.Array, mesh: Mesh,
+        apply_to_abstract_model: bool) -> nnx.Module | Callable:
+    """
+    Applies quantization if a valid quantization config with Qwix rules is provided. See README
+    for more details on Qwix.
+
+    Note that we currently support different methods for applying Qwix quantization. The typical
+    approach is to apply quantization on the concrete model, which already has the weights
+    loaded in. However, for models like DeepSeek, which are already quantized, we need to
+    first create the abstract model, then apply Qwix quantization to the abstract model, and
+    finally load the weights in. To use the latter approach, you will need to modify the
+    model weight loading code appropriately (see deepseek_v3.py for an example) and
+    pass `use_abstract_model=True` in the quantization config.
+
+    Args:
+        vllm_config: the base vLLM config
+        model_or_model_fn: if `apply_to_abstract_model` is True, this will be a Callable that returns the abstract model
+            (e.g. _create_abstract_model). Otherwise, this will be the concrete model (nnx.Module).
+        rng: JAX RNG
+        mesh: model Mesh
+        apply_to_abstract_model: if True, we will apply Qwix quantization to the abstract model, which
+            assumes that, during weight loading, the caller will then override the QArray weights
+            (see deepseek_v3.py for an example). Otherwise, we will apply Qwix quantization to the
+            concrete model, which already has the weights loaded in.
+
+    Returns:
+        Either the concrete model (nnx.Module) or the abstract model (Callable) (if `apply_to_abstract_model` is True)
+    """
+    qwix_config = None
+    if quantization_config := vllm_config.additional_config.get(
+            "quantization"):
+        qwix_config = quantization_config.get("qwix").get("rules")
+    if not qwix_config:
+        return model_or_model_fn
+
+    logging_abstract_model_str = "abstract" if apply_to_abstract_model else "concrete"
+    logger.info(
+        f"Applying Qwix quantization on {logging_abstract_model_str} model")
+
+    block_size = vllm_config.cache_config.block_size
+    model_config = vllm_config.model_config
+
+    # Pad num_kv_heads to a multiple of the TP size
+    num_kv_heads = utils.get_padded_num_heads(
+        model_config.get_total_num_kv_heads(), mesh.shape["model"])
+
+    # Pad head_dim to a multiple of 128
+    head_size = model_config.get_head_size()
+    head_size = utils.get_padded_head_dim(head_size)
+
+    kv_cache_dtype = vllm_config.cache_config.cache_dtype
+
+    if not apply_to_abstract_model:
+        assert isinstance(model_or_model_fn, nnx.Module)
+        qwix_quantize_nnx_model_with_config = functools.partial(
+            qwix_quantize_nnx_model, qwix_config=qwix_config)
+        # NOTE: it's REALLY important that `qwix_quantize_nnx_model_with_config` is jitted,
+        # or else you'll run into hangs
+        model_or_model_fn = nnx.jit(
+            qwix_quantize_nnx_model_with_config,
+            donate_argnums=(0, ),
+            static_argnames=(
+                "mesh",
+                "num_hidden_layers",
+                "kv_cache_block_size",
+                "kv_cache_num_kv_heads",
+                "kv_cache_head_size",
+                "kv_cache_dtype",
+            ))(model=model_or_model_fn,
+               rng=rng,
+               mesh=mesh,
+               num_hidden_layers=vllm_config.model_config.hf_config.
+               num_hidden_layers,
+               kv_cache_block_size=block_size,
+               kv_cache_num_kv_heads=num_kv_heads,
+               kv_cache_head_size=head_size,
+               kv_cache_dtype=kv_cache_dtype)
+
+        return model_or_model_fn
+
+    hf_config = vllm_config.model_config.hf_config
+    if hasattr(hf_config, "text_config") and hasattr(hf_config.text_config,
+                                                     "num_hidden_layers"):
+        num_hidden_layers = hf_config.text_config.num_hidden_layers
+        logger.info(
+            f"Using num_hidden_layers from hf_config.text_config: {num_hidden_layers}"
+        )
+    elif hasattr(hf_config, "num_hidden_layers"):
+        num_hidden_layers = hf_config.num_hidden_layers
+        logger.info(
+            f"Using num_hidden_layers directly from hf_config: {num_hidden_layers}"
+        )
+    else:
+        raise AttributeError(
+            "Could not find 'num_hidden_layers' in hf_config or hf_config.text_config."
+        )
+
+    qwix_quantize_fn_for_eval = functools.partial(
+        qwix_quantize_nnx_model,
+        qwix_config=qwix_config,
+        mesh=mesh,
+        num_hidden_layers=num_hidden_layers,
+        kv_cache_block_size=block_size,
+        kv_cache_num_kv_heads=num_kv_heads,
+        kv_cache_head_size=head_size,
+        kv_cache_dtype=kv_cache_dtype)
+
+    def create_and_quantize_model_factory() -> nnx.Module:
+        """
+        Helper function to create and quantize the abstract model.
+        """
+        model = model_or_model_fn()
+        # Handle the DeepSeek case, where this needs to be called on the abstract model
+        if hasattr(model, 'initialize_cache'):
+            model.initialize_cache()
+        return qwix_quantize_fn_for_eval(model=model, rng=rng)
+
+    return create_and_quantize_model_factory
+
+
+def apply_qwix_on_abstract_model(vllm_config: "VllmConfig") -> bool:
+    """
+    Determines whether to apply Qwix quantization on the abstract model (e.g. for DeepSeek)
+    or the concrete model. See `apply_qwix_quantization` for more details on the differences
+    between these two approaches.
+
+    Args:
+        vllm_config: the vllm config
+
+    Returns:
+        whether to apply Qwix quantization on the abstract model
+    """
+    quantization_config = vllm_config.additional_config.get("quantization", {})
+    return quantization_config.get("qwix", {}).get("use_abstract_model", False)
+
+
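(Editorial aside, not part of the packaged file: a sketch of how the `use_abstract_model` flag drives this predicate. `SimpleNamespace` is a hypothetical stand-in for a real `VllmConfig`; only `additional_config` is read here.)

```python
from types import SimpleNamespace

from tpu_inference.models.jax.utils.quantization.quantization_utils import (
    apply_qwix_on_abstract_model)

vllm_config = SimpleNamespace(additional_config={
    "quantization": {
        "qwix": {
            "use_abstract_model": True,
            "rules": [{"module_path": ".*", "weight_qtype": "int8"}],
        }
    }
})

# True: quantize the abstract model first, then load weights into QArray fields
# (the DeepSeek-style flow described in apply_qwix_quantization).
assert apply_qwix_on_abstract_model(vllm_config) is True
```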
+def get_default_qwix_quantization_config(
+        model_type: str, quant_method: str,
+        skip_quantization: bool) -> dict | None:
+    """
+    Some models are pre-quantized and in those cases, we want to return a default set of
+    Qwix quantization rules (instead of forcing the user to pass in a quantization config each time).
+
+    Note that if a user passes in a quantization config (via `additional_config`), then
+    we'll use that instead of this function.
+
+    Args:
+        model_type: the name of the model
+        quant_method: the quantization method
+        skip_quantization: whether to skip quantization. In this case, we'll return None
+
+    Returns:
+        a dictionary containing the default Qwix quantization rules, or None if there is
+        no default for the given model type and quantization method
+    """
+    if skip_quantization:
+        return None
+    # TODO (jacobplatin): remove this so that we can support various quantization types
+    if model_type == "deepseek_v3" and quant_method == "fp8":
+        return DEFAULT_DEEPSEEK_FP8_CONFIG
+    elif model_type == "llama4" and quant_method == "compressed-tensors":
+        return DEFAULT_LLAMA4_FP8_CONFIG
+    # MXFP4 (GPT-OSS): provide a default configuration to quantize MoE experts via Qwix
+    elif model_type == "gpt_oss" and quant_method == "mxfp4":
+        return DEFAULT_GPT_OSS_FP4_CONFIG
+
+
+def update_vllm_config_for_qwix_quantization(vllm_config: "VllmConfig"):
+    """
+    Updates the vLLM config to unpack the Qwix quantization config if it exists.
+
+    By default, we'll check if the checkpoint is quantized and update the
+    Qwix quantization config to use the default quantization config if it exists,
+    but we'll override this if the user passes in a quantization config via `additional_config`.
+    """
+    # Automatically detect whether the checkpoint is quantized and update the
+    # Qwix quantization config accordingly
+    # NOTE: if a Qwix config is provided (via `additional_config`), we'll
+    # use that instead
+    model_type = vllm_config.model_config.hf_config.model_type.lower(
+    ) if hasattr(vllm_config.model_config.hf_config, "model_type") else None
+    quant_method = vllm_config.model_config.hf_config.quantization_config[
+        "quant_method"] if hasattr(vllm_config.model_config.hf_config,
+                                   "quantization_config") else None
+    default_quantization_config = get_default_qwix_quantization_config(
+        model_type, quant_method,
+        vllm_config.additional_config.get("skip_quantization", False))
+
+    maybe_existing_quantization_config = vllm_config.additional_config.get(
+        "quantization")
+    if maybe_existing_quantization_config:
+        logger.warning("Overwriting default Qwix quantization config with "
+                       "user provided quantization config.")
+    elif default_quantization_config is not None:
+        vllm_config.additional_config[
+            "quantization"] = default_quantization_config
+
+    # Validate additional config
+    if additional_config := vllm_config.additional_config:
+        # Try loading/parsing the quantization config so that we can fail fast
+        if quantization_config := additional_config.get("quantization"):
+            try:
+                # NOTE: Qwix quantization supports two paths:
+                # 1. a quantization config file name (which we parse to a dictionary)
+                # 2. an inline quantization config dictionary
+                if isinstance(quantization_config, str):
+                    quantization_config = quantization_config_file_path_to_dict(
+                        quantization_config)
+                # NOTE: unpack the quantization config now so we don't need to keep doing this every time
+                vllm_config.additional_config[
+                    "quantization"] = quantization_config
+                parse_qwix_config_to_rules(
+                    quantization_config["qwix"]["rules"])
+            except Exception as e:
+                raise ValueError(
+                    f"Invalid quantization config; please see README for details on quantization config: {e}"
+                )
+
+
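(Editorial aside, not part of the packaged file: the two accepted shapes of `additional_config["quantization"]` that the validation above normalizes. The file name refers to a bundled config; the inline dict mirrors the structure of the defaults defined earlier.)

```python
# Path 1: a bundled config file name, resolved via quantization_config_file_path_to_dict().
additional_config_file_form = {"quantization": "int8_default.yaml"}

# Path 2: an inline dictionary with the same qwix/rules structure as the defaults.
additional_config_dict_form = {
    "quantization": {
        "qwix": {
            "use_abstract_model": False,
            "rules": [
                {"module_path": ".*attn.*", "weight_qtype": "int8"},
                {"module_path": ".*", "weight_qtype": "int8", "act_qtype": "int8"},
            ],
        }
    }
}
# update_vllm_config_for_qwix_quantization() rewrites the file-name form into the
# dict form and raises ValueError early if the rules cannot be parsed.
```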
+def get_random_sharded_array(key: PRNGKey, mesh: Mesh, param: nnx.Param,
+                             param_shape: tuple, dtype: jnp.dtype,
+                             param_name: str) -> jax.Array:
+    """
+    Returns a random sharded array with the given shape for the given parameter.
+
+    Args:
+        key: The random key.
+        mesh: The mesh to use for sharding.
+        param: The parameter.
+        param_shape: The shape of the parameter.
+        dtype: The dtype of the parameter.
+        param_name: The name of the parameter.
+
+    Returns:
+        A random sharded array with the given shape for the given parameter.
+    """
+    is_int = jnp.issubdtype(dtype, jnp.integer)
+    if is_int:
+        # These need to be JAX arrays or else you'll run into an overflow error
+        minval = jnp.array(jnp.iinfo(dtype).min, dtype=dtype)
+        maxval = jnp.array(jnp.iinfo(dtype).max, dtype=dtype)
+        weight = jax.random.randint(key, param_shape, minval, maxval, dtype)
+    else:
+        weight = jax.random.normal(key, param_shape, dtype)
+
+    def get_slice(index):
+        return weight[index]
+
+    try:
+        sharded_array = jax.make_array_from_callback(
+            param_shape, NamedSharding(mesh, P(*param.sharding)), get_slice)
+    except (ValueError, TypeError):
+        logger.warning(
+            f"Could not create sharded scale for {param_name} with shape {param_shape} and sharding {param.sharding}, skipping sharding..."
+        )
+        sharded_array = jax.make_array_from_callback(param_shape,
+                                                     NamedSharding(mesh, P()),
+                                                     get_slice)
+
+    return sharded_array
+
+
+def load_random_weights_into_qwix_abstract_model(rng: PRNGKey,
+                                                 model: nnx.Module, mesh: Mesh,
+                                                 quantization_config: dict):
+    """
+    Loads random weights for an abstract, Qwix-quantized model.
+
+    Args:
+        rng: The random key.
+        model: The model.
+        mesh: The mesh.
+        quantization_config: The quantization config for the model.
+    """
+    logger.info("Initializing Qwix-quantized model with random weights...")
+    # TODO (jacobplatin): clean up this logic
+    scale_dtype = model.weight_loader.scale_dtype
+    scale_shape_map = model.weight_loader.scale_shap_map_for_random_weight_loading if hasattr(
+        model.weight_loader,
+        'scale_shap_map_for_random_weight_loading') else {}
+    quantization_block_sizes = quantization_config["weight_block_size"]
+    assert len(
+        quantization_block_sizes
+    ) == 2, f"Expected only 2 quantization block sizes but got {quantization_block_sizes}"
+    quantization_block_size_n, _ = quantization_block_sizes[
+        0], quantization_block_sizes[1]
+
+    # Iterate through all variables and initialize them
+    prev_param_shape = None
+    for path, param in nnx.iter_graph(model):
+        if not isinstance(param, nnx.Variable):
+            continue
+        if path[0] == 'rng' and path[-1] == "key":
+            param.value = rng
+            continue
+        is_qwix_scale = (path[-1] == 'scale' and path[-2] == "array")
+        param_dtype = scale_dtype if is_qwix_scale else param.value.dtype
+        param_shape = param.value.shape
+        # TODO (jacobplatin): clean this up
+        if is_qwix_scale:
+            param_shape = scale_shape_map.get(
+                path[3],
+                tuple(dim // quantization_block_size_n
+                      for dim in prev_param_shape))
+        param.value = get_random_sharded_array(
+            rng, mesh, param, param_shape, param_dtype,
+            ".".join([str(x) for x in path]))
+        prev_param_shape = param_shape
+
+    # Handles the DeepSeek case, where this needs to be called to make the cache weights
+    # concrete
+    if hasattr(model, 'initialize_cache'):
+        model.initialize_cache()
+    logger.info("Done initializing Qwix-quantized model with random weights")
+
+
+def manually_quantize_qwix_weight(weight: jax.Array, qtype: jnp.dtype,
+                                  channelwise_axes: List[int],
+                                  tiled_axes: dict,
+                                  calibration_method: str) -> QArray:
+    """
+    Manually quantizes a weight tensor using Qwix. Only needed for the SparseMatmul DeepSeek case right now, since
+    otherwise, Qwix will handle this automatically (through our application of `qwix.quantize_model`).
+    """
+    # TODO (jacobplatin): clean this up; this is needed because of issues with Qwix quantizing the `shard_map` in SparseMatmul
+    how_to_quantize = ptq.qarray.HowToQuantize(
+        qtype=qtype,
+        channelwise_axes=channelwise_axes,
+        tiled_axes=tiled_axes,
+        calibration_method=calibration_method)
+
+    return ptq.create_quantized_param(weight, how_to_quantize)
+
+
+def manually_quantize_qwix_activation(inputs: jax.Array, rule_name: str,
+                                      qtype: jnp.dtype,
+                                      channelwise_axes: List[int],
+                                      tiled_axes: dict,
+                                      calibration_method: str) -> QArray:
+    """
+    Manually quantizes an activation tensor using Qwix. Needed for the SparseMatmul
+    DeepSeek MoE case currently.
+
+    Args:
+        inputs: The activation tensor to quantize.
+        rule_name: The name of the quantization rule to use.
+        qtype: The quantization type.
+        channelwise_axes: The channelwise axes to quantize.
+        tiled_axes: The tiled axes to quantize.
+        calibration_method: The calibration method to use.
+
+    Returns:
+        The quantized activation tensor.
+    """
+    rule = qpl.get_current_rule(rule_name)
+    lhs_how = ptq.qarray.HowToQuantize(qtype=qtype,
+                                       channelwise_axes=channelwise_axes,
+                                       tiled_axes=tiled_axes,
+                                       calibration_method=calibration_method)
+    # This is needed because we aren't passing `act_name` right now
+    assert not rule.act_static_scale, "Static scale not supported right now"
+
+    # channelwise_axes should be set to (a subset of) non-contraction axes, e.g.
+    # for ragged_dot [m, k] x [g, k, n], they are [0] and [0, 2]
+    # TODO (jacobplatin): add support for `act_name`
+    return ptq.quantize_act(inputs, lhs_how, rule, "")
+
+
+def get_quant_dtype_from_qwix_config(
+        vllm_config: "VllmConfig") -> tuple[jnp.dtype, jnp.dtype]:
+    """
+    Gets the quantization dtype from the Qwix config.
+
+    Args:
+        vllm_config: The VllmConfig object.
+
+    Returns:
+        A tuple of the scale dtype and quant dtype.
+    """
+    qwix_config = vllm_config.additional_config.get("quantization",
+                                                    {}).get("qwix", {})
+    scale_dtype = getattr(jnp, qwix_config.get("scale_dtype", "bfloat16"))
+    quant_dtype = None
+    # TODO (jacobplatin): this needs to be much more robust
+    for rule in qwix_config.get("rules", []):
+        if rule.get("module_path") == ".*":
+            quant_dtype_str = rule.get("weight_qtype", "")
+            assert quant_dtype_str, "Quantization dtype not found in Qwix config! We currently expect your Qwix config to have a rule with module_path '.*' and a weight_qtype."
+            quant_dtype = getattr(jnp, quant_dtype_str)
+    return scale_dtype, quant_dtype
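(Editorial aside, not part of the packaged file: a sketch of what this helper resolves for the DeepSeek default config defined at the top of the module. `SimpleNamespace` is a hypothetical stand-in for a `VllmConfig`.)

```python
from types import SimpleNamespace

import jax.numpy as jnp

from tpu_inference.models.jax.utils.quantization.quantization_utils import (
    DEFAULT_DEEPSEEK_FP8_CONFIG, get_quant_dtype_from_qwix_config)

vllm_config = SimpleNamespace(
    additional_config={"quantization": DEFAULT_DEEPSEEK_FP8_CONFIG})

scale_dtype, quant_dtype = get_quant_dtype_from_qwix_config(vllm_config)
# scale_dtype == jnp.bfloat16      (from "scale_dtype": "bfloat16")
# quant_dtype == jnp.float8_e4m3fn (from the ".*" rule's weight_qtype)
```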