tpu-inference 0.11.1__py3-none-any.whl

This diff shows the content of a publicly available package version released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in that registry.

Potentially problematic release: this version of tpu-inference has been flagged as possibly problematic.

Files changed (168)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_adapters.py +83 -0
  4. tests/core/test_core_tpu.py +523 -0
  5. tests/core/test_disagg_executor.py +60 -0
  6. tests/core/test_disagg_utils.py +53 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  10. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  11. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  12. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  13. tests/lora/__init__.py +0 -0
  14. tests/lora/test_lora.py +123 -0
  15. tests/test_base.py +201 -0
  16. tests/test_quantization.py +836 -0
  17. tests/test_tpu_info.py +120 -0
  18. tests/test_utils.py +218 -0
  19. tests/tpu_backend_test.py +59 -0
  20. tpu_inference/__init__.py +30 -0
  21. tpu_inference/adapters/__init__.py +0 -0
  22. tpu_inference/adapters/vllm_adapters.py +42 -0
  23. tpu_inference/adapters/vllm_config_adapters.py +134 -0
  24. tpu_inference/backend.py +69 -0
  25. tpu_inference/core/__init__.py +0 -0
  26. tpu_inference/core/adapters.py +153 -0
  27. tpu_inference/core/core_tpu.py +776 -0
  28. tpu_inference/core/disagg_executor.py +117 -0
  29. tpu_inference/core/disagg_utils.py +51 -0
  30. tpu_inference/di/__init__.py +0 -0
  31. tpu_inference/di/abstracts.py +28 -0
  32. tpu_inference/di/host.py +76 -0
  33. tpu_inference/di/interfaces.py +51 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/tpu_connector.py +699 -0
  36. tpu_inference/distributed/utils.py +59 -0
  37. tpu_inference/executors/__init__.py +0 -0
  38. tpu_inference/executors/ray_distributed_executor.py +346 -0
  39. tpu_inference/experimental/__init__.py +0 -0
  40. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  41. tpu_inference/interfaces/__init__.py +0 -0
  42. tpu_inference/interfaces/cache.py +31 -0
  43. tpu_inference/interfaces/config.py +47 -0
  44. tpu_inference/interfaces/config_parts.py +117 -0
  45. tpu_inference/interfaces/engine.py +51 -0
  46. tpu_inference/interfaces/outputs.py +22 -0
  47. tpu_inference/interfaces/params.py +21 -0
  48. tpu_inference/interfaces/platform.py +74 -0
  49. tpu_inference/interfaces/request.py +39 -0
  50. tpu_inference/interfaces/scheduler.py +31 -0
  51. tpu_inference/kernels/__init__.py +0 -0
  52. tpu_inference/kernels/collectives/__init__.py +0 -0
  53. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  54. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  55. tpu_inference/kernels/collectives/util.py +47 -0
  56. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  57. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  58. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  59. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  60. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  61. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  62. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  66. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1447 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3834 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/util.py +47 -0
  71. tpu_inference/layers/__init__.py +0 -0
  72. tpu_inference/layers/common/__init__.py +0 -0
  73. tpu_inference/layers/common/attention_metadata.py +34 -0
  74. tpu_inference/layers/jax/__init__.py +0 -0
  75. tpu_inference/layers/jax/attention/__init__.py +0 -0
  76. tpu_inference/layers/jax/attention/attention.py +254 -0
  77. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  78. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  79. tpu_inference/layers/jax/attention_interface.py +356 -0
  80. tpu_inference/layers/jax/base.py +151 -0
  81. tpu_inference/layers/jax/binary_search.py +295 -0
  82. tpu_inference/layers/jax/constants.py +88 -0
  83. tpu_inference/layers/jax/layers.py +301 -0
  84. tpu_inference/layers/jax/misc.py +16 -0
  85. tpu_inference/layers/jax/moe/__init__.py +0 -0
  86. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  87. tpu_inference/layers/jax/moe/moe.py +209 -0
  88. tpu_inference/layers/jax/rope.py +172 -0
  89. tpu_inference/layers/jax/rope_interface.py +214 -0
  90. tpu_inference/layers/jax/sample/__init__.py +0 -0
  91. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  92. tpu_inference/layers/jax/sample/sampling.py +95 -0
  93. tpu_inference/layers/jax/sample/sampling_metadata.py +69 -0
  94. tpu_inference/layers/jax/sharding.py +406 -0
  95. tpu_inference/layers/jax/transformer_block.py +76 -0
  96. tpu_inference/layers/vllm/__init__.py +0 -0
  97. tpu_inference/layers/vllm/attention.py +184 -0
  98. tpu_inference/layers/vllm/fused_moe.py +399 -0
  99. tpu_inference/layers/vllm/linear_common.py +186 -0
  100. tpu_inference/layers/vllm/quantization/__init__.py +34 -0
  101. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  102. tpu_inference/layers/vllm/quantization/common.py +105 -0
  103. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  104. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +121 -0
  105. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  106. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  108. tpu_inference/layers/vllm/quantization/unquantized.py +263 -0
  109. tpu_inference/layers/vllm/sharding.py +151 -0
  110. tpu_inference/logger.py +10 -0
  111. tpu_inference/lora/__init__.py +0 -0
  112. tpu_inference/lora/torch_lora_ops.py +103 -0
  113. tpu_inference/lora/torch_punica_tpu.py +308 -0
  114. tpu_inference/mock/__init__.py +0 -0
  115. tpu_inference/mock/vllm_config_utils.py +28 -0
  116. tpu_inference/mock/vllm_envs.py +1233 -0
  117. tpu_inference/mock/vllm_logger.py +212 -0
  118. tpu_inference/mock/vllm_logging_utils.py +15 -0
  119. tpu_inference/models/__init__.py +0 -0
  120. tpu_inference/models/common/__init__.py +0 -0
  121. tpu_inference/models/common/model_loader.py +433 -0
  122. tpu_inference/models/jax/__init__.py +0 -0
  123. tpu_inference/models/jax/deepseek_v3.py +868 -0
  124. tpu_inference/models/jax/llama3.py +366 -0
  125. tpu_inference/models/jax/llama4.py +473 -0
  126. tpu_inference/models/jax/llama_eagle3.py +333 -0
  127. tpu_inference/models/jax/phi3.py +376 -0
  128. tpu_inference/models/jax/qwen2.py +375 -0
  129. tpu_inference/models/jax/qwen2_5_vl.py +976 -0
  130. tpu_inference/models/jax/qwen3.py +302 -0
  131. tpu_inference/models/jax/utils/__init__.py +0 -0
  132. tpu_inference/models/jax/utils/file_utils.py +96 -0
  133. tpu_inference/models/jax/utils/multi_modal_utils.py +164 -0
  134. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  135. tpu_inference/models/jax/utils/quantization/quantization_utils.py +588 -0
  136. tpu_inference/models/jax/utils/weight_utils.py +510 -0
  137. tpu_inference/models/vllm/__init__.py +0 -0
  138. tpu_inference/models/vllm/vllm_model_wrapper.py +272 -0
  139. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  140. tpu_inference/platforms/__init__.py +2 -0
  141. tpu_inference/platforms/tpu_jax.py +257 -0
  142. tpu_inference/runner/__init__.py +0 -0
  143. tpu_inference/runner/block_table_jax.py +122 -0
  144. tpu_inference/runner/compilation_manager.py +672 -0
  145. tpu_inference/runner/input_batch_jax.py +435 -0
  146. tpu_inference/runner/kv_cache.py +119 -0
  147. tpu_inference/runner/kv_cache_manager.py +460 -0
  148. tpu_inference/runner/lora_utils.py +92 -0
  149. tpu_inference/runner/multimodal_manager.py +208 -0
  150. tpu_inference/runner/persistent_batch_manager.py +244 -0
  151. tpu_inference/runner/speculative_decoding_manager.py +250 -0
  152. tpu_inference/runner/structured_decoding_manager.py +89 -0
  153. tpu_inference/runner/tpu_jax_runner.py +771 -0
  154. tpu_inference/runner/utils.py +426 -0
  155. tpu_inference/spec_decode/__init__.py +0 -0
  156. tpu_inference/spec_decode/jax/__init__.py +0 -0
  157. tpu_inference/spec_decode/jax/eagle3.py +334 -0
  158. tpu_inference/tpu_info.py +77 -0
  159. tpu_inference/utils.py +294 -0
  160. tpu_inference/worker/__init__.py +0 -0
  161. tpu_inference/worker/_temporary_vllm_compat.py +129 -0
  162. tpu_inference/worker/base.py +100 -0
  163. tpu_inference/worker/tpu_worker_jax.py +321 -0
  164. tpu_inference-0.11.1.dist-info/METADATA +101 -0
  165. tpu_inference-0.11.1.dist-info/RECORD +168 -0
  166. tpu_inference-0.11.1.dist-info/WHEEL +5 -0
  167. tpu_inference-0.11.1.dist-info/licenses/LICENSE +201 -0
  168. tpu_inference-0.11.1.dist-info/top_level.txt +2 -0
tpu_inference/models/jax/utils/quantization/quantization_utils.py
@@ -0,0 +1,588 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import functools
+import os
+from typing import TYPE_CHECKING, Callable, List
+
+import jax
+import jax.numpy as jnp
+import qwix
+import qwix.pallas as qpl
+import yaml
+from flax import nnx
+from flax.typing import PRNGKey
+from jax.sharding import Mesh, NamedSharding
+from jax.sharding import PartitionSpec as P
+from qwix._src.core.qarray import QArray
+from qwix._src.providers import ptq
+
+if TYPE_CHECKING:
+    from vllm.config import VllmConfig
+
+from tpu_inference import utils
+from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.logger import init_logger
+from tpu_inference.runner.kv_cache import (DEFAULT_KV_CACHE_DTYPE,
+                                           create_kv_caches)
+from tpu_inference.utils import device_array
+
+logger = init_logger(__name__)
+
+QUANTIZATION_CONFIG_PATH = os.path.join(os.path.dirname(__file__), "configs")
+DEFAULT_NUM_BLOCKS_FOR_JIT_KV_CACHE = 2000
+DEFAULT_NUM_TOKENS_FOR_MODEL_INPUTS = 512
+DEFAULT_MAX_NUM_SEQS_FOR_MODEL_INPUTS = 256
+DEFAULT_MAX_NUM_BLOCKS_PER_REQ = 16
+
+DEFAULT_DEEPSEEK_FP8_CONFIG = {
+    "qwix": {
+        "use_abstract_model": True,
+        "scale_dtype": "bfloat16",
+        "rules": [
+            {
+                "module_path": ".*.custom_module.router.*",
+                "weight_qtype": None,
+            },
+            {
+                "module_path": ".*",
+                "weight_qtype": "float8_e4m3fn",
+                "act_qtype": "float8_e4m3fn",
+            },
+        ],
+    }
+}
+
+
+def parse_qwix_config_to_rules(
+        qwix_config: List[dict]) -> List[qwix.QuantizationRule]:
+    """
+    Parse a list of dictionaries containing Qwix quantization rules into a
+    list of QuantizationRule objects.
+
+    Args:
+        qwix_config: a list of dictionaries, one per Qwix quantization rule
+
+    Returns:
+        a list of QuantizationRule objects
+    """
+    rules = []
+    for rule in qwix_config:
+        rules.append(qwix.QuantizationRule(**rule))
+
+    return rules
+
+
+def qwix_quantize_nnx_model(model: nnx.Module, qwix_config: List[dict],
+                            rng: jax.Array, mesh: Mesh, num_hidden_layers: int,
+                            kv_cache_block_size: int,
+                            kv_cache_num_kv_heads: int,
+                            kv_cache_head_size: int,
+                            kv_cache_dtype: str) -> nnx.Module:
+    """
+    Quantizes a Flax NNX model using Qwix.
+
+    Args:
+        model: the model to quantize
+        qwix_config: a list of dictionaries, where each dictionary corresponds
+            to a Qwix quantization rule. For example:
+            [
+                {
+                    "module_path": ".*attn.*",
+                    "weight_qtype": "int8",
+                },
+                {
+                    "module_path": ".*mlp.*",
+                    "weight_qtype": "int8",
+                    "act_qtype": "int8",
+                    "tile_size": None,
+                },
+            ]
+        rng: the random number generator to use
+        mesh: the mesh to use
+        num_hidden_layers: the number of hidden layers in the model
+        kv_cache_block_size: the block size of the kv cache
+        kv_cache_num_kv_heads: the number of kv heads
+        kv_cache_head_size: the head size of the kv cache
+        kv_cache_dtype: the dtype of the kv cache
+
+    Returns:
+        model: the quantized model
+    """
+    qwix_rules = parse_qwix_config_to_rules(qwix_config)
+    logger.info(f"Qwix rules: {qwix_rules}")
+    logger.info(f"Memory usage before applying quantization of params: "
+                f"hbm={utils.hbm_usage_gb(jax.local_devices())}Gb")
+
+    # TODO (jacobplatin): we should refactor this to pass a dtype (or config)
+    # directly
+    kv_cache_jnp_dtype = utils.get_jax_dtype_from_str_dtype(kv_cache_dtype)
+
+    # Handle the case where kv_cache_dtype is "auto"
+    if kv_cache_jnp_dtype is None:
+        assert kv_cache_dtype == "auto", (
+            "kv_cache_dtype must be 'auto' if kv_cache_jnp_dtype is None")
+        kv_cache_jnp_dtype = DEFAULT_KV_CACHE_DTYPE
+
+    kv_caches = create_kv_caches(
+        num_blocks=DEFAULT_NUM_BLOCKS_FOR_JIT_KV_CACHE,
+        block_size=kv_cache_block_size,
+        num_kv_heads=kv_cache_num_kv_heads,
+        head_size=kv_cache_head_size,
+        mesh=mesh,
+        layer_names=[f"layer.{i}" for i in range(num_hidden_layers)],
+        cache_dtype=kv_cache_jnp_dtype)
+
+    # NOTE: the inputs don't need to match the actual ones, as long as the
+    # consumed weights are the same.
+    input_ids = jax.random.randint(rng,
+                                   (DEFAULT_NUM_TOKENS_FOR_MODEL_INPUTS, ),
+                                   0,
+                                   100,
+                                   dtype=jnp.int32)
+    positions = jax.random.randint(rng,
+                                   (DEFAULT_NUM_TOKENS_FOR_MODEL_INPUTS, ),
+                                   0,
+                                   100,
+                                   dtype=jnp.int32)
+    block_tables = jax.random.randint(rng,
+                                      (DEFAULT_MAX_NUM_SEQS_FOR_MODEL_INPUTS *
+                                       DEFAULT_MAX_NUM_BLOCKS_PER_REQ, ),
+                                      0,
+                                      100,
+                                      dtype=jnp.int32)
+    query_start_loc = jax.random.randint(
+        rng, (DEFAULT_MAX_NUM_SEQS_FOR_MODEL_INPUTS + 1, ),
+        0,
+        100,
+        dtype=jnp.int32)
+    seq_lens = jax.random.randint(rng,
+                                  (DEFAULT_MAX_NUM_SEQS_FOR_MODEL_INPUTS, ),
+                                  0,
+                                  100,
+                                  dtype=jnp.int32)
+    num_seqs = jax.random.randint(rng, (1, ), 0, 100, dtype=jnp.int32)
+    request_distribution = jnp.array([0, 0, num_seqs[0]], dtype=jnp.int32)
+
+    (input_ids, positions, block_tables,
+     query_start_loc, seq_lens, request_distribution) = device_array(
+         mesh, (input_ids, positions, block_tables, query_start_loc, seq_lens,
+                request_distribution))
+
+    model_input = {
+        "kv_caches": kv_caches,
+        "input_ids": input_ids,
+        "attention_metadata":
+        AttentionMetadata(input_positions=positions,
+                          block_tables=block_tables,
+                          seq_lens=seq_lens,
+                          query_start_loc=query_start_loc,
+                          request_distribution=request_distribution),
+    }
+    model = qwix.quantize_model(model, qwix.PtqProvider(qwix_rules),
+                                **model_input)
+    return model
+
+
+def quantization_config_file_path_to_dict(
+        quantization_config_file_path: str) -> dict:
+    """
+    Converts a quantization config YAML file path to a dictionary.
+
+    The expected format of the quantization config YAML file is as follows:
+    ```yaml
+    qwix:
+      # optional, defaults to False if not specified
+      use_abstract_model: True
+      rules:
+        # NOTE: each entry corresponds to a qwix.QuantizationRule
+        - module_path: '.*attn.*'
+          weight_qtype: 'int8'
+        - module_path: '.*'
+          weight_qtype: 'int8'
+          act_qtype: 'int8'
+    ```
+
+    Args:
+        quantization_config_file_path: the path to the quantization config
+            YAML file
+
+    Returns:
+        a dictionary containing the quantization config
+    """
+    all_entries = os.listdir(QUANTIZATION_CONFIG_PATH)
+    for filename in all_entries:
+        if filename == quantization_config_file_path:
+            path = os.path.join(QUANTIZATION_CONFIG_PATH, filename)
+            with open(path, "r") as f:
+                return yaml.safe_load(f)
+    raise ValueError(
+        "Could not find quantization config file with name "
+        f"'{quantization_config_file_path}' in "
+        "'tpu_inference/models/jax/utils/quantization/configs'.")
+
+
+def apply_qwix_quantization(
+        vllm_config: "VllmConfig", model_or_model_fn: Callable | nnx.Module,
+        rng: jax.Array, mesh: Mesh,
+        apply_to_abstract_model: bool) -> nnx.Module | Callable:
+    """
+    Applies quantization if a valid quantization config with Qwix rules is
+    provided. See the README for more details on Qwix.
+
+    Note that we currently support two methods for applying Qwix quantization.
+    The typical approach is to apply quantization to the concrete model, which
+    already has the weights loaded in. However, for models like DeepSeek,
+    which are already quantized, we need to first create the abstract model,
+    then apply Qwix quantization to the abstract model, and finally load the
+    weights in. To use the latter approach, you will need to modify the model
+    weight loading code appropriately (see deepseek_v3.py for an example) and
+    pass `use_abstract_model: True` in the quantization config.
+
+    Args:
+        vllm_config: the base vLLM config
+        model_or_model_fn: if `apply_to_abstract_model` is True, this will be
+            a Callable that returns the abstract model (e.g.
+            _create_abstract_model). Otherwise, this will be the concrete
+            model (nnx.Module).
+        rng: JAX RNG
+        mesh: model Mesh
+        apply_to_abstract_model: if True, we will apply Qwix quantization to
+            the abstract model, which assumes that, during weight loading, the
+            caller will override the QArray weights (see deepseek_v3.py for an
+            example). Otherwise, we will apply Qwix quantization to the
+            concrete model, which already has the weights loaded in.
+
+    Returns:
+        Either the concrete model (nnx.Module) or, if
+        `apply_to_abstract_model` is True, a factory Callable that builds and
+        quantizes the abstract model.
+    """
+    qwix_config = None
+    if quantization_config := vllm_config.additional_config.get(
+            "quantization"):
+        qwix_config = quantization_config.get("qwix").get("rules")
+    if not qwix_config:
+        return model_or_model_fn
+
+    logging_abstract_model_str = "abstract" if apply_to_abstract_model else "concrete"
+    logger.info(
+        f"Applying Qwix quantization on {logging_abstract_model_str} model")
+
+    block_size = vllm_config.cache_config.block_size
+    model_config = vllm_config.model_config
+
+    # Pad num_kv_heads to a multiple of the TP size
+    num_kv_heads = utils.get_padded_num_heads(
+        model_config.get_total_num_kv_heads(), mesh.shape["model"])
+
+    # Pad head_dim to a multiple of 128
+    head_size = model_config.get_head_size()
+    head_size = utils.get_padded_head_dim(head_size)
+
+    kv_cache_dtype = vllm_config.cache_config.cache_dtype
+
+    if not apply_to_abstract_model:
+        assert isinstance(model_or_model_fn, nnx.Module)
+        qwix_quantize_nnx_model_with_config = functools.partial(
+            qwix_quantize_nnx_model, qwix_config=qwix_config)
+        # NOTE: it's REALLY important that
+        # `qwix_quantize_nnx_model_with_config` is jitted, or else you'll run
+        # into hanging.
+        model_or_model_fn = nnx.jit(
+            qwix_quantize_nnx_model_with_config,
+            donate_argnums=(0, ),
+            static_argnames=(
+                "mesh",
+                "num_hidden_layers",
+                "kv_cache_block_size",
+                "kv_cache_num_kv_heads",
+                "kv_cache_head_size",
+                "kv_cache_dtype",
+            ))(model=model_or_model_fn,
+               rng=rng,
+               mesh=mesh,
+               num_hidden_layers=vllm_config.model_config.hf_config.
+               num_hidden_layers,
+               kv_cache_block_size=block_size,
+               kv_cache_num_kv_heads=num_kv_heads,
+               kv_cache_head_size=head_size,
+               kv_cache_dtype=kv_cache_dtype)
+
+        return model_or_model_fn
+
+    qwix_quantize_fn_for_eval = functools.partial(
+        qwix_quantize_nnx_model,
+        qwix_config=qwix_config,
+        mesh=mesh,
+        num_hidden_layers=vllm_config.model_config.hf_config.num_hidden_layers,
+        kv_cache_block_size=block_size,
+        kv_cache_num_kv_heads=num_kv_heads,
+        kv_cache_head_size=head_size,
+        kv_cache_dtype=kv_cache_dtype)
+
+    def create_and_quantize_model_factory() -> Callable:
+        """
+        Helper function to create and quantize the abstract model.
+        """
+        model = model_or_model_fn()
+        # Handle the DeepSeek case, where this needs to be called on the
+        # abstract model.
+        if hasattr(model, 'initialize_cache'):
+            model.initialize_cache()
+        return qwix_quantize_fn_for_eval(model=model, rng=rng)
+
+    return create_and_quantize_model_factory
+
+
+def apply_qwix_on_abstract_model(vllm_config: "VllmConfig") -> bool:
+    """
+    Determines whether to apply Qwix quantization on the abstract model (e.g.
+    for DeepSeek) or the concrete model. See `apply_qwix_quantization` for
+    more details on the differences between these two approaches.
+
+    Args:
+        vllm_config: the vLLM config
+
+    Returns:
+        whether to apply Qwix quantization on the abstract model
+    """
+    quantization_config = vllm_config.additional_config.get("quantization", {})
+    return quantization_config.get("qwix", {}).get("use_abstract_model", False)
+
+
+def get_default_qwix_quantization_config(
+        model_type: str, quant_method: str,
+        skip_quantization: bool) -> dict | None:
+    """
+    Some models are pre-quantized, and in those cases we want to return a
+    default set of Qwix quantization rules (instead of forcing the user to
+    pass in a quantization config each time).
+
+    Note that if a user passes in a quantization config (via
+    `additional_config`), we'll use that instead of this function.
+
+    Args:
+        model_type: the name of the model
+        quant_method: the quantization method
+        skip_quantization: whether to skip quantization. In this case, we'll
+            return None.
+
+    Returns:
+        a dictionary containing the default Qwix quantization rules
+    """
+    if skip_quantization:
+        return None
+    # TODO (jacobplatin): remove this so that we can support various
+    # quantization types
+    if model_type == "deepseek_v3" and quant_method == "fp8":
+        return DEFAULT_DEEPSEEK_FP8_CONFIG
+
+
+def update_vllm_config_for_qwix_quantization(vllm_config: "VllmConfig"):
+    """
+    Updates the vLLM config to unpack the Qwix quantization config if it
+    exists. By default, we'll check whether the checkpoint is quantized and
+    fall back to the default quantization config if one exists, but we'll
+    override this if the user passes in a quantization config via
+    `additional_config`.
+    """
+    # Automatically detect whether the checkpoint is quantized and update the
+    # Qwix quantization config accordingly.
+    # NOTE: if a Qwix config is provided (via `additional_config`), we'll use
+    # that instead.
+    model_type = vllm_config.model_config.hf_config.model_type.lower(
+    ) if hasattr(vllm_config.model_config.hf_config, "model_type") else None
+    quant_method = vllm_config.model_config.hf_config.quantization_config[
+        "quant_method"] if hasattr(vllm_config.model_config.hf_config,
+                                   "quantization_config") else None
+    default_quantization_config = get_default_qwix_quantization_config(
+        model_type, quant_method,
+        vllm_config.additional_config.get("skip_quantization", False))
+
+    maybe_existing_quantization_config = vllm_config.additional_config.get(
+        "quantization")
+    if maybe_existing_quantization_config:
+        logger.warning("Overwriting default Qwix quantization config with "
+                       "user provided quantization config.")
+    elif default_quantization_config is not None:
+        vllm_config.additional_config[
+            "quantization"] = default_quantization_config
+
+    # Validate additional config
+    if additional_config := vllm_config.additional_config:
+        # Try loading/parsing the quantization config so that we can fail fast
+        if quantization_config := additional_config.get("quantization"):
+            try:
+                # NOTE: Qwix quantization supports two paths:
+                # 1. a quantization config file (which we need to parse to a
+                #    dictionary)
+                # 2. a quantization config JSON
+                if isinstance(quantization_config, str):
+                    quantization_config = quantization_config_file_path_to_dict(
+                        quantization_config)
+                # NOTE: unpack the quantization config now so we don't need to
+                # keep doing this every time
+                vllm_config.additional_config[
+                    "quantization"] = quantization_config
+                parse_qwix_config_to_rules(
+                    quantization_config["qwix"]["rules"])
+            except Exception as e:
+                raise ValueError(
+                    "Invalid quantization config; please see the README for "
+                    f"details on the quantization config: {e}")
+
+
+def get_random_sharded_array(key: PRNGKey, mesh: Mesh, param: nnx.Param,
+                             param_shape: tuple, dtype: jnp.dtype,
+                             param_name: str) -> jax.Array:
+    """
+    Returns a random sharded array for the given parameter and shape.
+
+    Args:
+        key: The random key.
+        mesh: The mesh to use for sharding.
+        param: The parameter.
+        param_shape: The shape of the parameter.
+        dtype: The dtype of the parameter.
+        param_name: The name of the parameter.
+
+    Returns:
+        A random sharded array for the given parameter and shape.
+    """
+    is_int = jnp.issubdtype(dtype, jnp.integer)
+    if is_int:
+        # These need to be JAX arrays or else you'll run into an overflow
+        # error.
+        minval = jnp.array(jnp.iinfo(dtype).min, dtype=dtype)
+        maxval = jnp.array(jnp.iinfo(dtype).max, dtype=dtype)
+        weight = jax.random.randint(key, param_shape, minval, maxval, dtype)
+    else:
+        weight = jax.random.normal(key, param_shape, dtype)
+
+    def get_slice(index):
+        return weight[index]
+
+    try:
+        sharded_array = jax.make_array_from_callback(
+            param_shape, NamedSharding(mesh, P(*param.sharding)), get_slice)
+    except (ValueError, TypeError):
+        logger.warning(
+            f"Could not create sharded scale for {param_name} with shape "
+            f"{param_shape} and sharding {param.sharding}, skipping "
+            "sharding...")
+        sharded_array = jax.make_array_from_callback(param_shape,
+                                                     NamedSharding(mesh, P()),
+                                                     get_slice)
+
+    return sharded_array
+
+
+def load_random_weights_into_qwix_abstract_model(rng: PRNGKey,
+                                                 model: nnx.Module, mesh: Mesh,
+                                                 quantization_config: dict):
+    """
+    Loads random weights into an abstract, Qwix-quantized model.
+
+    Args:
+        rng: The random key.
+        model: The model.
+        mesh: The mesh.
+        quantization_config: The quantization config for the model.
+    """
+    logger.info("Initializing Qwix-quantized model with random weights...")
+    # TODO (jacobplatin): clean up this logic
+    scale_dtype = model.weight_loader.scale_dtype
+    scale_shape_map = model.weight_loader.scale_shap_map_for_random_weight_loading if hasattr(
+        model.weight_loader,
+        'scale_shap_map_for_random_weight_loading') else {}
+    quantization_block_sizes = quantization_config["weight_block_size"]
+    assert len(
+        quantization_block_sizes
+    ) == 2, f"Expected only 2 quantization block sizes but got {quantization_block_sizes}"
+    quantization_block_size_n, _ = quantization_block_sizes[
+        0], quantization_block_sizes[1]
+
+    # Iterate through all variables and initialize them.
+    prev_param_shape = None
+    for path, param in nnx.iter_graph(model):
+        if not isinstance(param, nnx.Variable):
+            continue
+        if path[0] == 'rng' and path[-1] == "key":
+            param.value = rng
+            continue
+        is_qwix_scale = (path[-1] == 'scale' and path[-2] == "array")
+        param_dtype = scale_dtype if is_qwix_scale else param.value.dtype
+        param_shape = param.value.shape
+        # TODO (jacobplatin): clean this up
+        if is_qwix_scale:
+            param_shape = scale_shape_map.get(
+                path[3],
+                tuple(dim // quantization_block_size_n
+                      for dim in prev_param_shape))
+        param.value = get_random_sharded_array(
+            rng, mesh, param, param_shape, param_dtype,
+            ".".join([str(x) for x in path]))
+        prev_param_shape = param_shape
+
+    # Handles the DeepSeek case, where this needs to be called to make the
+    # cache weights concrete.
+    if hasattr(model, 'initialize_cache'):
+        model.initialize_cache()
+    logger.info("Done initializing Qwix-quantized model with random weights")
+
+
+def manually_quantize_qwix_weight(weight: jax.Array, qtype: jnp.dtype,
+                                  channelwise_axes: List[int],
+                                  tiled_axes: dict,
+                                  calibration_method: str) -> QArray:
+    """
+    Manually quantizes a weight tensor using Qwix. Only needed for the
+    SparseMatmul DeepSeek case right now; otherwise, Qwix handles this
+    automatically (through our application of `qwix.quantize_model`).
+    """
+    # TODO (jacobplatin): clean this up; this is needed because of issues with
+    # Qwix quantizing the `shard_map` in SparseMatmul.
+    how_to_quantize = ptq.qarray.HowToQuantize(
+        qtype=qtype,
+        channelwise_axes=channelwise_axes,
+        tiled_axes=tiled_axes,
+        calibration_method=calibration_method)
+
+    return ptq.create_quantized_param(weight, how_to_quantize)
+
+
+def manually_quantize_qwix_activation(inputs: jax.Array, rule_name: str,
+                                      qtype: jnp.dtype,
+                                      channelwise_axes: List[int],
+                                      tiled_axes: dict,
+                                      calibration_method: str) -> QArray:
+    """
+    Manually quantizes an activation tensor using Qwix. Currently needed for
+    the SparseMatmul DeepSeek MoE case.
+
+    Args:
+        inputs: The activation tensor to quantize.
+        rule_name: The name of the quantization rule to use.
+        qtype: The quantization type.
+        channelwise_axes: The channelwise axes to quantize.
+        tiled_axes: The tiled axes to quantize.
+        calibration_method: The calibration method to use.
+
+    Returns:
+        The quantized activation tensor.
+    """
+    rule = qpl.get_current_rule(rule_name)
+    lhs_how = ptq.qarray.HowToQuantize(qtype=qtype,
+                                       channelwise_axes=channelwise_axes,
+                                       tiled_axes=tiled_axes,
+                                       calibration_method=calibration_method)
+    # This is needed because we aren't passing `act_name` right now.
+    assert not rule.act_static_scale, "Static scale not supported right now"
+
+    # channelwise_axes should be set to (a subset of) non-contraction axes,
+    # e.g. for ragged_dot [m, k] x [g, k, n], they are [0] and [0, 2].
+    # TODO (jacobplatin): add support for `act_name`
+    return ptq.quantize_act(inputs, lhs_how, rule, "")
+
+
+def get_quant_dtype_from_qwix_config(
+        vllm_config: "VllmConfig") -> tuple[jnp.dtype, jnp.dtype]:
+    """
+    Gets the quantization dtype from the Qwix config.
+
+    Args:
+        vllm_config: The VllmConfig object.
+
+    Returns:
+        A tuple of the scale dtype and quant dtype.
+    """
+    qwix_config = vllm_config.additional_config.get("quantization",
+                                                    {}).get("qwix", {})
+    scale_dtype = getattr(jnp, qwix_config.get("scale_dtype", "bfloat16"))
+    quant_dtype = None
+    # TODO (jacobplatin): this needs to be much more robust
+    for rule in qwix_config.get("rules", []):
+        if rule.get("module_path") == ".*":
+            quant_dtype_str = rule.get("weight_qtype", "")
+            assert quant_dtype_str, (
+                "Quantization dtype not found in Qwix config! We currently "
+                "expect your Qwix config to have a rule with module_path '.*' "
+                "and a weight_qtype.")
+            quant_dtype = getattr(jnp, quant_dtype_str)
+    return scale_dtype, quant_dtype
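
For context on how these helpers fit together: the Qwix settings are read from vLLM's `additional_config` under a `quantization` key, either as the name of a YAML file shipped in the package's `configs` directory (resolved by `quantization_config_file_path_to_dict`) or as an already-unpacked dictionary, and `update_vllm_config_for_qwix_quantization` validates the rules by round-tripping them through `parse_qwix_config_to_rules`. The sketch below is a hypothetical illustration of that dictionary shape, not code from the wheel; the import path is assumed from the file listing above, and the int8 rule values are placeholders.

```python
# Hypothetical usage sketch -- not part of the released wheel.
# Import path assumed from the file listing above; requires tpu_inference,
# qwix, and their JAX dependencies to be installed.
from tpu_inference.models.jax.utils.quantization.quantization_utils import (
    parse_qwix_config_to_rules)

# Same shape as DEFAULT_DEEPSEEK_FP8_CONFIG and the YAML format documented in
# quantization_config_file_path_to_dict: a "qwix" section with optional
# use_abstract_model / scale_dtype keys plus a list of rule dicts.
quantization_config = {
    "qwix": {
        "use_abstract_model": False,
        "scale_dtype": "bfloat16",
        "rules": [
            # A rule with weight_qtype=None leaves matched modules
            # unquantized, mirroring how the default DeepSeek config skips
            # its router.
            {"module_path": ".*router.*", "weight_qtype": None},
            # Quantize everything else to int8 weights and activations
            # (placeholder values for illustration).
            {"module_path": ".*", "weight_qtype": "int8", "act_qtype": "int8"},
        ],
    }
}

# This mirrors the fail-fast validation in
# update_vllm_config_for_qwix_quantization: every rule dict must be valid
# keyword arguments for qwix.QuantizationRule.
rules = parse_qwix_config_to_rules(quantization_config["qwix"]["rules"])
print(rules)

# At engine startup the same dict would be supplied through
# vllm_config.additional_config["quantization"], which
# apply_qwix_quantization then reads.
```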