tpu-inference 0.0.1rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of tpu-inference has been flagged as potentially problematic.

Files changed (174)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +53 -0
  6. tests/core/test_dp_scheduler.py +899 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/fused_moe_v1_test.py +374 -0
  10. tests/kernels/mla_v1_test.py +396 -0
  11. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  12. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  13. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  14. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
  15. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  16. tests/lora/__init__.py +0 -0
  17. tests/lora/conftest.py +32 -0
  18. tests/lora/test_bgmv.py +43 -0
  19. tests/lora/test_layers.py +648 -0
  20. tests/lora/test_lora.py +133 -0
  21. tests/lora/utils.py +88 -0
  22. tests/test_base.py +201 -0
  23. tests/test_envs.py +203 -0
  24. tests/test_quantization.py +836 -0
  25. tests/test_tpu_info.py +120 -0
  26. tests/test_utils.py +235 -0
  27. tpu_inference/__init__.py +53 -0
  28. tpu_inference/core/__init__.py +0 -0
  29. tpu_inference/core/core_tpu.py +786 -0
  30. tpu_inference/core/disagg_executor.py +118 -0
  31. tpu_inference/core/disagg_utils.py +49 -0
  32. tpu_inference/core/sched/__init__.py +0 -0
  33. tpu_inference/core/sched/dp_scheduler.py +523 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/jax_parallel_state.py +67 -0
  36. tpu_inference/distributed/tpu_connector.py +727 -0
  37. tpu_inference/distributed/utils.py +60 -0
  38. tpu_inference/env_override.py +9 -0
  39. tpu_inference/envs.py +160 -0
  40. tpu_inference/executors/__init__.py +0 -0
  41. tpu_inference/executors/ray_distributed_executor.py +382 -0
  42. tpu_inference/experimental/__init__.py +0 -0
  43. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  44. tpu_inference/kernels/__init__.py +0 -0
  45. tpu_inference/kernels/collectives/__init__.py +0 -0
  46. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  47. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  48. tpu_inference/kernels/collectives/util.py +47 -0
  49. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  50. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  51. tpu_inference/kernels/fused_moe/__init__.py +0 -0
  52. tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
  53. tpu_inference/kernels/fused_moe/v1/kernel.py +1566 -0
  54. tpu_inference/kernels/mla/__init__.py +0 -0
  55. tpu_inference/kernels/mla/v1/__init__.py +0 -0
  56. tpu_inference/kernels/mla/v1/kernel.py +1349 -0
  57. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  58. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  59. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  60. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  61. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  62. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  66. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1501 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1603 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
  71. tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
  72. tpu_inference/layers/__init__.py +0 -0
  73. tpu_inference/layers/common/__init__.py +0 -0
  74. tpu_inference/layers/common/attention_interface.py +396 -0
  75. tpu_inference/layers/common/attention_metadata.py +34 -0
  76. tpu_inference/layers/common/binary_search.py +295 -0
  77. tpu_inference/layers/common/quant_methods.py +8 -0
  78. tpu_inference/layers/common/sharding.py +582 -0
  79. tpu_inference/layers/jax/__init__.py +0 -0
  80. tpu_inference/layers/jax/attention/__init__.py +0 -0
  81. tpu_inference/layers/jax/attention/attention.py +255 -0
  82. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  83. tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
  84. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  85. tpu_inference/layers/jax/base.py +151 -0
  86. tpu_inference/layers/jax/constants.py +88 -0
  87. tpu_inference/layers/jax/layers.py +301 -0
  88. tpu_inference/layers/jax/misc.py +16 -0
  89. tpu_inference/layers/jax/moe/__init__.py +0 -0
  90. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  91. tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
  92. tpu_inference/layers/jax/moe/moe.py +209 -0
  93. tpu_inference/layers/jax/rope.py +280 -0
  94. tpu_inference/layers/jax/rope_interface.py +214 -0
  95. tpu_inference/layers/jax/sample/__init__.py +0 -0
  96. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  97. tpu_inference/layers/jax/sample/sampling.py +96 -0
  98. tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
  99. tpu_inference/layers/jax/transformer_block.py +107 -0
  100. tpu_inference/layers/vllm/__init__.py +0 -0
  101. tpu_inference/layers/vllm/attention.py +221 -0
  102. tpu_inference/layers/vllm/fused_moe.py +469 -0
  103. tpu_inference/layers/vllm/linear_common.py +186 -0
  104. tpu_inference/layers/vllm/quantization/__init__.py +39 -0
  105. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  106. tpu_inference/layers/vllm/quantization/common.py +110 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  108. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
  109. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
  110. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  111. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  112. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  113. tpu_inference/layers/vllm/quantization/mxfp4.py +331 -0
  114. tpu_inference/layers/vllm/quantization/unquantized.py +368 -0
  115. tpu_inference/layers/vllm/sharding.py +230 -0
  116. tpu_inference/logger.py +10 -0
  117. tpu_inference/lora/__init__.py +0 -0
  118. tpu_inference/lora/torch_lora_ops.py +103 -0
  119. tpu_inference/lora/torch_punica_tpu.py +310 -0
  120. tpu_inference/models/__init__.py +0 -0
  121. tpu_inference/models/common/__init__.py +0 -0
  122. tpu_inference/models/common/model_loader.py +478 -0
  123. tpu_inference/models/jax/__init__.py +0 -0
  124. tpu_inference/models/jax/deepseek_v3.py +868 -0
  125. tpu_inference/models/jax/gpt_oss.py +492 -0
  126. tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
  127. tpu_inference/models/jax/llama3.py +376 -0
  128. tpu_inference/models/jax/llama4.py +629 -0
  129. tpu_inference/models/jax/llama_eagle3.py +336 -0
  130. tpu_inference/models/jax/llama_guard_4.py +361 -0
  131. tpu_inference/models/jax/qwen2.py +376 -0
  132. tpu_inference/models/jax/qwen2_5_vl.py +1218 -0
  133. tpu_inference/models/jax/qwen3.py +303 -0
  134. tpu_inference/models/jax/utils/__init__.py +0 -0
  135. tpu_inference/models/jax/utils/file_utils.py +96 -0
  136. tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
  137. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  138. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
  139. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
  140. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
  141. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
  142. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
  143. tpu_inference/models/jax/utils/quantization/quantization_utils.py +650 -0
  144. tpu_inference/models/jax/utils/weight_utils.py +584 -0
  145. tpu_inference/models/vllm/__init__.py +0 -0
  146. tpu_inference/models/vllm/vllm_model_wrapper.py +293 -0
  147. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  148. tpu_inference/platforms/__init__.py +2 -0
  149. tpu_inference/platforms/tpu_platform.py +275 -0
  150. tpu_inference/runner/__init__.py +0 -0
  151. tpu_inference/runner/block_table.py +122 -0
  152. tpu_inference/runner/compilation_manager.py +865 -0
  153. tpu_inference/runner/input_batch.py +435 -0
  154. tpu_inference/runner/kv_cache.py +132 -0
  155. tpu_inference/runner/kv_cache_manager.py +478 -0
  156. tpu_inference/runner/lora_utils.py +92 -0
  157. tpu_inference/runner/multimodal_manager.py +217 -0
  158. tpu_inference/runner/persistent_batch_manager.py +282 -0
  159. tpu_inference/runner/speculative_decoding_manager.py +248 -0
  160. tpu_inference/runner/structured_decoding_manager.py +87 -0
  161. tpu_inference/runner/tpu_runner.py +1744 -0
  162. tpu_inference/runner/utils.py +426 -0
  163. tpu_inference/spec_decode/__init__.py +0 -0
  164. tpu_inference/spec_decode/jax/__init__.py +0 -0
  165. tpu_inference/spec_decode/jax/eagle3.py +417 -0
  166. tpu_inference/tpu_info.py +78 -0
  167. tpu_inference/utils.py +340 -0
  168. tpu_inference/worker/__init__.py +0 -0
  169. tpu_inference/worker/tpu_worker.py +458 -0
  170. tpu_inference-0.0.1rc1.dist-info/METADATA +108 -0
  171. tpu_inference-0.0.1rc1.dist-info/RECORD +174 -0
  172. tpu_inference-0.0.1rc1.dist-info/WHEEL +5 -0
  173. tpu_inference-0.0.1rc1.dist-info/licenses/LICENSE +201 -0
  174. tpu_inference-0.0.1rc1.dist-info/top_level.txt +2 -0
tpu_inference/models/jax/utils/quantization/quantization_utils.py (new file)
@@ -0,0 +1,650 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import os
from typing import TYPE_CHECKING, Callable, List

import jax
import jax.numpy as jnp
import qwix
import qwix.pallas as qpl
import yaml
from flax import nnx
from flax.typing import PRNGKey
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P
from qwix._src.core.qarray import QArray
from qwix._src.providers import ptq

if TYPE_CHECKING:
    from vllm.config import VllmConfig

from tpu_inference import utils
from tpu_inference.layers.common.attention_metadata import AttentionMetadata
from tpu_inference.logger import init_logger
from tpu_inference.runner.kv_cache import (DEFAULT_KV_CACHE_DTYPE,
                                           create_kv_caches)
from tpu_inference.utils import device_array

logger = init_logger(__name__)

QUANTIZATION_CONFIG_PATH = os.path.join(os.path.dirname(__file__), "configs")
DEFAULT_NUM_BLOCKS_FOR_JIT_KV_CACHE = 2000
DEFAULT_NUM_TOKENS_FOR_MODEL_INPUTS = 512
DEFAULT_MAX_NUM_SEQS_FOR_MODEL_INPUTS = 256
DEFAULT_MAX_NUM_BLOCKS_PER_REQ = 16

DEFAULT_DEEPSEEK_FP8_CONFIG = {
    "qwix": {
        "use_abstract_model": True,
        "scale_dtype": "bfloat16",
        "rules": [
            {
                "module_path": ".*.custom_module.router.*",
                "weight_qtype": None,
            },
            {
                "module_path": ".*",
                "weight_qtype": "float8_e4m3fn",
                "act_qtype": "float8_e4m3fn",
            },
        ],
    }
}

DEFAULT_LLAMA4_FP8_CONFIG = {
    "qwix": {
        "use_abstract_model": True,
        "scale_dtype": "bfloat16",
        "rules": [
            {
                "module_path": "layers.*.moe_ffw",
                "op_names": "einsum",
                "weight_qtype": "float8_e4m3fn",
                "act_qtype": "float8_e4m3fn",
            },
        ],
    }
}

# Default Qwix config for GPT-OSS MXFP4 checkpoints.
# Notes:
# - We quantize only the MoE expert weights by default (the router stays in BF16).
# - We use Qwix's abstract-model path so weights can be set directly into QArray
#   fields during weight loading (similar to DeepSeek's flow).
# - Activation quantization is not set, but Qwix would pick up the MoE sum if activated.
DEFAULT_GPT_OSS_FP4_CONFIG = {
    "qwix": {
        "use_abstract_model": True,
        "scale_dtype": "bfloat16",
        "rules": [
            {
                "module_path": ".*custom_module",
                "weight_qtype": "float4_e2m1fn",
                "act_qtype": None,
                "tile_size": 32,
            },
        ],
    }
}


def parse_qwix_config_to_rules(
        qwix_config: List[dict]) -> List[qwix.QuantizationRule]:
    """
    Parse a list of dictionaries containing Qwix quantization rules into a list
    of QuantizationRule objects.

    Args:
        qwix_config: a list of dictionaries, each describing a Qwix quantization rule

    Returns:
        a list of QuantizationRule objects
    """
    rules = []
    for rule in qwix_config:
        rules.append(qwix.QuantizationRule(**rule))

    return rules


def qwix_quantize_nnx_model(model: nnx.Module, qwix_config: List[dict],
                            rng: jax.Array, mesh: Mesh, num_hidden_layers: int,
                            kv_cache_block_size: int,
                            kv_cache_num_kv_heads: int,
                            kv_cache_head_size: int,
                            kv_cache_dtype: str) -> nnx.Module:
    """
    Quantizes a Flax NNX model using Qwix.

    Args:
        model: the model to quantize
        qwix_config: a list of dictionaries, where each dictionary corresponds to a
            Qwix quantization rule. For example:
            [
                {
                    "module_path": ".*attn.*",
                    "weight_qtype": "int8",
                },
                {
                    "module_path": ".*mlp.*",
                    "weight_qtype": "int8",
                    "act_qtype": "int8",
                    "tile_size": None,
                },
            ]
        rng: the random number generator to use
        mesh: the mesh to use
        num_hidden_layers: the number of hidden layers in the model
        kv_cache_block_size: the block size of the kv cache
        kv_cache_num_kv_heads: the number of kv heads
        kv_cache_head_size: the head size of the kv cache
        kv_cache_dtype: the dtype of the kv cache

    Returns:
        model: the quantized model
    """
    qwix_rules = parse_qwix_config_to_rules(qwix_config)
    logger.info(f"Qwix rules: {qwix_rules}")
    logger.info(f"Memory usage before applying quantization of params: "
                f"hbm={utils.hbm_usage_gb(jax.local_devices())}Gb")

    if kv_cache_dtype != "auto":
        kv_cache_jnp_dtype = utils.to_jax_dtype(kv_cache_dtype)
    else:
        kv_cache_jnp_dtype = DEFAULT_KV_CACHE_DTYPE

    kv_caches = create_kv_caches(
        num_blocks=DEFAULT_NUM_BLOCKS_FOR_JIT_KV_CACHE,
        block_size=kv_cache_block_size,
        num_kv_heads=kv_cache_num_kv_heads,
        head_size=kv_cache_head_size,
        mesh=mesh,
        layer_names=[f"layer.{i}" for i in range(num_hidden_layers)],
        cache_dtype=kv_cache_jnp_dtype)

    dp_size = mesh.shape.get("data", 1) * mesh.shape.get("attn", 1)

    # NOTE: the inputs don't need to match the actual ones, as long as the
    # consumed weights are the same.
    input_ids = jax.random.randint(rng,
                                   (DEFAULT_NUM_TOKENS_FOR_MODEL_INPUTS, ),
                                   0,
                                   100,
                                   dtype=jnp.int32)
    positions = jax.random.randint(rng,
                                   (DEFAULT_NUM_TOKENS_FOR_MODEL_INPUTS, ),
                                   0,
                                   100,
                                   dtype=jnp.int32)
    block_tables = jax.random.randint(rng,
                                      (DEFAULT_MAX_NUM_SEQS_FOR_MODEL_INPUTS *
                                       DEFAULT_MAX_NUM_BLOCKS_PER_REQ, ),
                                      0,
                                      100,
                                      dtype=jnp.int32)
    query_start_loc = jax.random.randint(
        rng, (DEFAULT_MAX_NUM_SEQS_FOR_MODEL_INPUTS + dp_size, ),
        0,
        100,
        dtype=jnp.int32)
    seq_lens = jax.random.randint(rng,
                                  (DEFAULT_MAX_NUM_SEQS_FOR_MODEL_INPUTS, ),
                                  0,
                                  100,
                                  dtype=jnp.int32)
    num_seqs = jax.random.randint(rng, (1, ), 0, 100, dtype=jnp.int32)
    request_distribution = jnp.array([0, 0, num_seqs[0]] * dp_size,
                                     dtype=jnp.int32)

    (input_ids, positions, block_tables,
     query_start_loc, seq_lens, request_distribution) = device_array(
         mesh, (input_ids, positions, block_tables, query_start_loc, seq_lens,
                request_distribution))

    model_input = {
        "kv_caches": kv_caches,
        "input_ids": input_ids,
        "attention_metadata": AttentionMetadata(
            input_positions=positions,
            block_tables=block_tables,
            seq_lens=seq_lens,
            query_start_loc=query_start_loc,
            request_distribution=request_distribution),
    }
    model = qwix.quantize_model(model, qwix.PtqProvider(qwix_rules),
                                **model_input)
    return model


def quantization_config_file_path_to_dict(
        quantization_config_file_path: str) -> dict:
    """
    Converts a quantization config YAML file path to a dictionary.

    The expected format of the quantization config YAML file is as follows:
    ```yaml
    qwix:
      # optional, defaults to False if not specified
      use_abstract_model: True
      rules:
        # NOTE: each entry corresponds to a qwix.QuantizationRule
        - module_path: '.*attn.*'
          weight_qtype: 'int8'
        - module_path: '.*'
          weight_qtype: 'int8'
          act_qtype: 'int8'
    ```

    Args:
        quantization_config_file_path: the path to the quantization config YAML file

    Returns:
        a dictionary containing the quantization config
    """
    all_entries = os.listdir(QUANTIZATION_CONFIG_PATH)
    for filename in all_entries:
        if filename == quantization_config_file_path:
            path = os.path.join(QUANTIZATION_CONFIG_PATH, filename)
            with open(path, "r") as f:
                return yaml.safe_load(f)
    raise ValueError(
        f"Could not find quantization config file with name '{quantization_config_file_path}' in 'tpu_inference/models/jax/utils/quantization/configs'."
    )


def apply_qwix_quantization(
        vllm_config: "VllmConfig", model_or_model_fn: Callable | nnx.Module,
        rng: jax.Array, mesh: Mesh,
        apply_to_abstract_model: bool) -> nnx.Module | Callable:
    """
    Applies quantization if a valid quantization config with Qwix rules is provided.
    See the README for more details on Qwix.

    Note that we currently support different methods for applying Qwix quantization.
    The typical approach is to apply quantization on the concrete model, which
    already has the weights loaded in. However, for models like DeepSeek, which are
    already quantized, we need to first create the abstract model, then apply Qwix
    quantization to the abstract model, and finally load the weights in. To use the
    latter approach, you will need to modify the model weight loading code
    appropriately (see deepseek_v3.py for an example) and pass
    `use_abstract_model=True` in the quantization config.

    Args:
        vllm_config: the base vLLM config
        model_or_model_fn: if `apply_to_abstract_model` is True, this will be a
            Callable that returns the abstract model (e.g. _create_abstract_model).
            Otherwise, this will be the concrete model (nnx.Module).
        rng: JAX RNG
        mesh: model Mesh
        apply_to_abstract_model: if True, we will apply Qwix quantization to the
            abstract model, which assumes that, during weight loading, the caller
            will override the QArray weights (see deepseek_v3.py for an example).
            Otherwise, we will apply Qwix quantization to the concrete model, which
            already has the weights loaded in.

    Returns:
        Either the concrete model (nnx.Module) or, if `apply_to_abstract_model` is
        True, a Callable that builds the quantized abstract model.
    """
    qwix_config = None
    if quantization_config := vllm_config.additional_config.get(
            "quantization"):
        qwix_config = quantization_config.get("qwix").get("rules")
    if not qwix_config:
        return model_or_model_fn

    logging_abstract_model_str = "abstract" if apply_to_abstract_model else "concrete"
    logger.info(
        f"Applying Qwix quantization on {logging_abstract_model_str} model")

    block_size = vllm_config.cache_config.block_size
    model_config = vllm_config.model_config

    # Pad num_kv_heads to a multiple of the TP size
    num_kv_heads = utils.get_padded_num_heads(
        model_config.get_total_num_kv_heads(), mesh.shape["model"])

    # Pad head_dim to a multiple of 128
    head_size = model_config.get_head_size()
    head_size = utils.get_padded_head_dim(head_size)

    kv_cache_dtype = vllm_config.cache_config.cache_dtype

    if not apply_to_abstract_model:
        assert isinstance(model_or_model_fn, nnx.Module)
        qwix_quantize_nnx_model_with_config = functools.partial(
            qwix_quantize_nnx_model, qwix_config=qwix_config)
        # NOTE: it's REALLY important that `qwix_quantize_nnx_model_with_config`
        # is jitted, or else you'll run into hanging.
        model_or_model_fn = nnx.jit(
            qwix_quantize_nnx_model_with_config,
            donate_argnums=(0, ),
            static_argnames=(
                "mesh",
                "num_hidden_layers",
                "kv_cache_block_size",
                "kv_cache_num_kv_heads",
                "kv_cache_head_size",
                "kv_cache_dtype",
            ))(model=model_or_model_fn,
               rng=rng,
               mesh=mesh,
               num_hidden_layers=vllm_config.model_config.hf_config.
               num_hidden_layers,
               kv_cache_block_size=block_size,
               kv_cache_num_kv_heads=num_kv_heads,
               kv_cache_head_size=head_size,
               kv_cache_dtype=kv_cache_dtype)

        return model_or_model_fn

    hf_config = vllm_config.model_config.hf_config
    if hasattr(hf_config, "text_config") and hasattr(hf_config.text_config,
                                                     "num_hidden_layers"):
        num_hidden_layers = hf_config.text_config.num_hidden_layers
        logger.info(
            f"Using num_hidden_layers from hf_config.text_config: {num_hidden_layers}"
        )
    elif hasattr(hf_config, "num_hidden_layers"):
        num_hidden_layers = hf_config.num_hidden_layers
        logger.info(
            f"Using num_hidden_layers directly from hf_config: {num_hidden_layers}"
        )
    else:
        raise AttributeError(
            "Could not find 'num_hidden_layers' in hf_config or hf_config.text_config."
        )

    qwix_quantize_fn_for_eval = functools.partial(
        qwix_quantize_nnx_model,
        qwix_config=qwix_config,
        mesh=mesh,
        num_hidden_layers=num_hidden_layers,
        kv_cache_block_size=block_size,
        kv_cache_num_kv_heads=num_kv_heads,
        kv_cache_head_size=head_size,
        kv_cache_dtype=kv_cache_dtype)

    def create_and_quantize_model_factory() -> Callable:
        """
        Helper function to create and quantize the abstract model.
        """
        model = model_or_model_fn()
        # Handle the DeepSeek case, where this needs to be called on the abstract model
        if hasattr(model, 'initialize_cache'):
            model.initialize_cache()
        return qwix_quantize_fn_for_eval(model=model, rng=rng)

    return create_and_quantize_model_factory


def apply_qwix_on_abstract_model(vllm_config: "VllmConfig") -> bool:
    """
    Determines whether to apply Qwix quantization on the abstract model (e.g. for
    DeepSeek) or the concrete model. See `apply_qwix_quantization` for more details
    on the differences between these two approaches.

    Args:
        vllm_config: the vllm config

    Returns:
        whether to apply Qwix quantization on the abstract model
    """
    quantization_config = vllm_config.additional_config.get("quantization", {})
    return quantization_config.get("qwix", {}).get("use_abstract_model", False)


def get_default_qwix_quantization_config(
        model_type: str, quant_method: str,
        skip_quantization: bool) -> dict | None:
    """
    Some models are pre-quantized, and in those cases we want to return a default
    set of Qwix quantization rules (instead of forcing the user to pass in a
    quantization config each time).

    Note that if a user passes in a quantization config (via `additional_config`),
    then we'll use that instead of this function.

    Args:
        model_type: the name of the model
        quant_method: the quantization method
        skip_quantization: whether to skip quantization. In this case, we'll return None

    Returns:
        a dictionary containing the default Qwix quantization rules
    """
    if skip_quantization:
        return None
    # TODO (jacobplatin): remove this so that we can support various quantization types
    if model_type == "deepseek_v3" and quant_method == "fp8":
        return DEFAULT_DEEPSEEK_FP8_CONFIG
    elif model_type == "llama4" and quant_method == "compressed-tensors":
        return DEFAULT_LLAMA4_FP8_CONFIG
    # MXFP4 (GPT-OSS): provide a default configuration to quantize MoE experts via Qwix
    elif model_type == "gpt_oss" and quant_method == "mxfp4":
        return DEFAULT_GPT_OSS_FP4_CONFIG


def update_vllm_config_for_qwix_quantization(vllm_config: "VllmConfig"):
    """
    Updates the vLLM config to unpack the Qwix quantization config if it exists.
    By default, we check whether the checkpoint is quantized and, if a default
    quantization config exists for it, use that default; a quantization config
    passed by the user via `additional_config` overrides this.
    """
    # Automatically detect whether the checkpoint is quantized and update the
    # Qwix quantization config accordingly.
    # NOTE: if a Qwix config is provided (via the `additional_config`), we'll
    # use that instead.
    model_type = vllm_config.model_config.hf_config.model_type.lower(
    ) if hasattr(vllm_config.model_config.hf_config, "model_type") else None
    quant_method = vllm_config.model_config.hf_config.quantization_config[
        "quant_method"] if hasattr(vllm_config.model_config.hf_config,
                                   "quantization_config") else None
    default_quantization_config = get_default_qwix_quantization_config(
        model_type, quant_method,
        vllm_config.additional_config.get("skip_quantization", False))

    maybe_existing_quantization_config = vllm_config.additional_config.get(
        "quantization")
    if maybe_existing_quantization_config:
        logger.warning("Overwriting default Qwix quantization config with "
                       "user provided quantization config.")
    elif default_quantization_config is not None:
        vllm_config.additional_config[
            "quantization"] = default_quantization_config

    # Validate the additional config
    if additional_config := vllm_config.additional_config:
        # Try loading/parsing the quantization config so that we can fail fast
        if quantization_config := additional_config.get("quantization"):
            try:
                # NOTE: Qwix quantization supports two paths:
                # 1. a quantization config file (which we need to parse to a dictionary)
                # 2. a quantization config JSON
                if isinstance(quantization_config, str):
                    quantization_config = quantization_config_file_path_to_dict(
                        quantization_config)
                # NOTE: unpack the quantization config now so we don't need to
                # keep doing this every time
                vllm_config.additional_config[
                    "quantization"] = quantization_config
                parse_qwix_config_to_rules(
                    quantization_config["qwix"]["rules"])
            except Exception as e:
                raise ValueError(
                    f"Invalid quantization config; please see README for details on quantization config: {e}"
                )


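# ---------------------------------------------------------------------------
# NOTE (editor illustration, NOT part of quantization_utils.py in the wheel):
# a minimal sketch of the two `additional_config["quantization"]` forms that
# `update_vllm_config_for_qwix_quantization` above accepts, assuming the config
# files bundled in tpu_inference/models/jax/utils/quantization/configs.
# 1) The name of a bundled YAML file:
additional_config_from_file = {"quantization": "fp8_default.yaml"}
# 2) An inline dict with the same structure the YAML would produce:
additional_config_inline = {
    "quantization": {
        "qwix": {
            "use_abstract_model": False,  # optional, defaults to False
            "rules": [
                {
                    "module_path": ".*",
                    "weight_qtype": "int8",
                    "act_qtype": "int8",
                },
            ],
        }
    }
}
# ---------------------------------------------------------------------------

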
def get_random_sharded_array(key: PRNGKey, mesh: Mesh, param: nnx.Param,
                             param_shape: tuple, dtype: jnp.dtype,
                             param_name: str) -> jax.Array:
    """
    Returns a random sharded array of the given shape for the given parameter.

    Args:
        key: The random key.
        mesh: The mesh to use for sharding.
        param: The parameter.
        param_shape: The shape of the parameter.
        dtype: The dtype of the parameter.
        param_name: The name of the parameter.

    Returns:
        A random sharded array of the given shape for the given parameter.
    """
    is_int = jnp.issubdtype(dtype, jnp.integer)
    if is_int:
        # These need to be JAX arrays or else you'll run into an overflow error
        minval = jnp.array(jnp.iinfo(dtype).min, dtype=dtype)
        maxval = jnp.array(jnp.iinfo(dtype).max, dtype=dtype)
        weight = jax.random.randint(key, param_shape, minval, maxval, dtype)
    else:
        weight = jax.random.normal(key, param_shape, dtype)

    def get_slice(index):
        return weight[index]

    try:
        sharded_array = jax.make_array_from_callback(
            param_shape, NamedSharding(mesh, P(*param.sharding)), get_slice)
    except (ValueError, TypeError):
        logger.warning(
            f"Could not create sharded scale for {param_name} with shape {param_shape} and sharding {param.sharding}, skipping sharding..."
        )
        sharded_array = jax.make_array_from_callback(param_shape,
                                                     NamedSharding(mesh, P()),
                                                     get_slice)

    return sharded_array


def load_random_weights_into_qwix_abstract_model(rng: PRNGKey,
                                                 model: nnx.Module, mesh: Mesh,
                                                 quantization_config: dict):
    """
    Loads random weights into an abstract, Qwix-quantized model.

    Args:
        rng: The random key.
        model: The model.
        mesh: The mesh.
        quantization_config: The quantization config for the model.
    """
    logger.info("Initializing Qwix-quantized model with random weights...")
    # TODO (jacobplatin): clean up this logic
    scale_dtype = model.weight_loader.scale_dtype
    scale_shape_map = model.weight_loader.scale_shap_map_for_random_weight_loading if hasattr(
        model.weight_loader,
        'scale_shap_map_for_random_weight_loading') else {}
    quantization_block_sizes = quantization_config["weight_block_size"]
    assert len(
        quantization_block_sizes
    ) == 2, f"Expected only 2 quantization block sizes but got {quantization_block_sizes}"
    quantization_block_size_n, _ = quantization_block_sizes[
        0], quantization_block_sizes[1]

    # Iterate through all variables and initialize them
    prev_param_shape = None
    for path, param in nnx.iter_graph(model):
        if not isinstance(param, nnx.Variable):
            continue
        if path[0] == 'rng' and path[-1] == "key":
            param.value = rng
            continue
        is_qwix_scale = (path[-1] == 'scale' and path[-2] == "array")
        param_dtype = scale_dtype if is_qwix_scale else param.value.dtype
        param_shape = param.value.shape
        # TODO (jacobplatin): clean this up
        if is_qwix_scale:
            param_shape = scale_shape_map.get(
                path[3],
                tuple(dim // quantization_block_size_n
                      for dim in prev_param_shape))
        param.value = get_random_sharded_array(
            rng, mesh, param, param_shape, param_dtype,
            ".".join([str(x) for x in path]))
        prev_param_shape = param_shape

    # Handles the DeepSeek case, where this needs to be called to make the
    # cache weights concrete
    if hasattr(model, 'initialize_cache'):
        model.initialize_cache()
    logger.info("Done initializing Qwix-quantized model with random weights")


def manually_quantize_qwix_weight(weight: jax.Array, qtype: jnp.dtype,
                                  channelwise_axes: List[int],
                                  tiled_axes: dict,
                                  calibration_method: str) -> QArray:
    """
    Manually quantizes a weight tensor using Qwix. Only needed for the SparseMatmul
    DeepSeek case right now; otherwise, Qwix handles this automatically (through
    our application of `qwix.quantize_model`).
    """
    # TODO (jacobplatin): clean this up; this is needed because of issues with
    # Qwix quantizing the `shard_map` in SparseMatmul
    how_to_quantize = ptq.qarray.HowToQuantize(
        qtype=qtype,
        channelwise_axes=channelwise_axes,
        tiled_axes=tiled_axes,
        calibration_method=calibration_method)

    return ptq.create_quantized_param(weight, how_to_quantize)


def manually_quantize_qwix_activation(inputs: jax.Array, rule_name: str,
                                      qtype: jnp.dtype,
                                      channelwise_axes: List[int],
                                      tiled_axes: dict,
                                      calibration_method: str) -> QArray:
    """
    Manually quantizes an activation tensor using Qwix. Needed for the SparseMatmul
    DeepSeek MoE case currently.

    Args:
        inputs: The activation tensor to quantize.
        rule_name: The name of the quantization rule to use.
        qtype: The quantization type.
        channelwise_axes: The channelwise axes to quantize.
        tiled_axes: The tiled axes to quantize.
        calibration_method: The calibration method to use.

    Returns:
        The quantized activation tensor.
    """
    rule = qpl.get_current_rule(rule_name)
    lhs_how = ptq.qarray.HowToQuantize(qtype=qtype,
                                       channelwise_axes=channelwise_axes,
                                       tiled_axes=tiled_axes,
                                       calibration_method=calibration_method)
    # This is needed because we aren't passing `act_name` right now
    assert not rule.act_static_scale, "Static scale not supported right now"

    # channelwise_axes should be set to (a subset of) the non-contraction axes,
    # e.g. for ragged_dot [m, k] x [g, k, n], they are [0] and [0, 2]
    # TODO (jacobplatin): add support for `act_name`
    return ptq.quantize_act(inputs, lhs_how, rule, "")


def get_quant_dtype_from_qwix_config(
        vllm_config: "VllmConfig") -> tuple[jnp.dtype, jnp.dtype]:
    """
    Gets the quantization dtype from the Qwix config.

    Args:
        vllm_config: The VllmConfig object.

    Returns:
        A tuple of the scale dtype and quant dtype.
    """
    qwix_config = vllm_config.additional_config.get("quantization",
                                                    {}).get("qwix", {})
    scale_dtype = getattr(jnp, qwix_config.get("scale_dtype", "bfloat16"))
    quant_dtype = None
    # TODO (jacobplatin): this needs to be much more robust
    for rule in qwix_config.get("rules", []):
        if rule.get("module_path") == ".*":
            quant_dtype_str = rule.get("weight_qtype", "")
            assert quant_dtype_str, "Quantization dtype not found in Qwix config! We currently expect your Qwix config to have a rule with module_path '.*' and a weight_qtype."
            quant_dtype = getattr(jnp, quant_dtype_str)
    return scale_dtype, quant_dtype
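
As a usage sketch (editor illustration, not shipped in the wheel): the module's own helper can turn the default rule dictionaries defined at the top of this file into Qwix rule objects; the import path assumes the package layout shown in the file list above.

# Editor illustration, not part of the package: build qwix.QuantizationRule
# objects from the default DeepSeek FP8 config defined in quantization_utils.py.
from tpu_inference.models.jax.utils.quantization.quantization_utils import (
    DEFAULT_DEEPSEEK_FP8_CONFIG, parse_qwix_config_to_rules)

rules = parse_qwix_config_to_rules(DEFAULT_DEEPSEEK_FP8_CONFIG["qwix"]["rules"])
for rule in rules:
    print(rule)  # each entry is a qwix.QuantizationRule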