tpu-inference 0.12.0.dev20251222__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (260)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +67 -0
  6. tests/core/test_dp_scheduler.py +724 -0
  7. tests/core/test_init.py +63 -0
  8. tests/distributed/__init__.py +13 -0
  9. tests/distributed/test_distributed_utils.py +120 -0
  10. tests/distributed/test_tpu_connector.py +478 -0
  11. tests/e2e/__init__.py +13 -0
  12. tests/e2e/test_async_scheduler.py +211 -0
  13. tests/e2e/test_data_parallel.py +393 -0
  14. tests/e2e/test_local_disagg.py +257 -0
  15. tests/e2e/test_model_loader.py +268 -0
  16. tests/e2e/test_multi_modal_inference.py +111 -0
  17. tests/e2e/test_pipeline_parallel.py +265 -0
  18. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  19. tests/e2e/test_sampling_params.py +269 -0
  20. tests/e2e/test_speculative_decoding.py +291 -0
  21. tests/e2e/test_structured_decoding.py +46 -0
  22. tests/executors/__init__.py +13 -0
  23. tests/executors/test_ray_distributed_executor.py +199 -0
  24. tests/experimental/__init__.py +13 -0
  25. tests/experimental/test_llama3_jax_stashed.py +208 -0
  26. tests/kernels/__init__.py +13 -0
  27. tests/kernels/collectives/__init__.py +13 -0
  28. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  29. tests/kernels/fused_moe_v1_test.py +388 -0
  30. tests/kernels/gmm_test.py +205 -0
  31. tests/kernels/mla_v1_test.py +498 -0
  32. tests/kernels/quantized_matmul_kernel_test.py +159 -0
  33. tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
  34. tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
  35. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
  36. tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
  37. tests/layers/__init__.py +13 -0
  38. tests/layers/common/__init__.py +13 -0
  39. tests/layers/common/test_attention_interface.py +156 -0
  40. tests/layers/common/test_quantization.py +149 -0
  41. tests/layers/jax/__init__.py +13 -0
  42. tests/layers/jax/attention/__init__.py +13 -0
  43. tests/layers/jax/attention/test_common_attention.py +103 -0
  44. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  45. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  46. tests/layers/jax/moe/__init__.py +13 -0
  47. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  48. tests/layers/jax/sample/__init__.py +13 -0
  49. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  50. tests/layers/jax/sample/test_sampling.py +115 -0
  51. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  52. tests/layers/jax/test_layers.py +155 -0
  53. tests/layers/jax/test_qwix.py +969 -0
  54. tests/layers/jax/test_rope.py +93 -0
  55. tests/layers/jax/test_sharding.py +159 -0
  56. tests/layers/jax/test_transformer_block.py +152 -0
  57. tests/layers/vllm/__init__.py +13 -0
  58. tests/layers/vllm/test_attention.py +363 -0
  59. tests/layers/vllm/test_awq.py +405 -0
  60. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
  62. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
  63. tests/layers/vllm/test_fp8.py +17 -0
  64. tests/layers/vllm/test_mxfp4.py +297 -0
  65. tests/layers/vllm/test_unquantized.py +621 -0
  66. tests/layers/vllm/utils.py +72 -0
  67. tests/lora/__init__.py +13 -0
  68. tests/lora/conftest.py +46 -0
  69. tests/lora/test_bgmv.py +57 -0
  70. tests/lora/test_layers.py +666 -0
  71. tests/lora/test_lora.py +147 -0
  72. tests/lora/test_lora_perf.py +67 -0
  73. tests/lora/utils.py +88 -0
  74. tests/models/__init__.py +13 -0
  75. tests/models/common/__init__.py +13 -0
  76. tests/models/common/test_model_loader.py +455 -0
  77. tests/models/jax/__init__.py +13 -0
  78. tests/models/jax/test_deepseek_v3.py +401 -0
  79. tests/models/jax/test_llama3.py +184 -0
  80. tests/models/jax/test_llama4.py +298 -0
  81. tests/models/jax/test_llama_eagle3.py +197 -0
  82. tests/models/jax/test_llama_guard_4.py +242 -0
  83. tests/models/jax/test_qwen2.py +172 -0
  84. tests/models/jax/test_qwen2_5_vl.py +606 -0
  85. tests/models/jax/test_qwen3.py +169 -0
  86. tests/models/jax/test_weight_loading.py +180 -0
  87. tests/models/jax/utils/__init__.py +13 -0
  88. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  89. tests/platforms/__init__.py +13 -0
  90. tests/platforms/test_tpu_platform.py +54 -0
  91. tests/runner/__init__.py +13 -0
  92. tests/runner/test_block_table.py +395 -0
  93. tests/runner/test_input_batch.py +226 -0
  94. tests/runner/test_kv_cache.py +220 -0
  95. tests/runner/test_kv_cache_manager.py +498 -0
  96. tests/runner/test_multimodal_manager.py +429 -0
  97. tests/runner/test_persistent_batch_manager.py +84 -0
  98. tests/runner/test_speculative_decoding_manager.py +368 -0
  99. tests/runner/test_structured_decoding_manager.py +220 -0
  100. tests/runner/test_tpu_runner.py +202 -0
  101. tests/runner/test_tpu_runner_dp.py +1033 -0
  102. tests/runner/test_tpu_runner_mesh.py +200 -0
  103. tests/runner/test_utils.py +411 -0
  104. tests/spec_decode/__init__.py +13 -0
  105. tests/spec_decode/test_eagle3.py +311 -0
  106. tests/test_base.py +215 -0
  107. tests/test_envs.py +280 -0
  108. tests/test_tpu_info.py +134 -0
  109. tests/test_utils.py +193 -0
  110. tests/worker/__init__.py +13 -0
  111. tests/worker/tpu_worker_test.py +414 -0
  112. tpu_inference/__init__.py +67 -0
  113. tpu_inference/core/__init__.py +13 -0
  114. tpu_inference/core/core_tpu.py +786 -0
  115. tpu_inference/core/disagg_executor.py +118 -0
  116. tpu_inference/core/disagg_utils.py +49 -0
  117. tpu_inference/core/sched/__init__.py +13 -0
  118. tpu_inference/core/sched/dp_scheduler.py +814 -0
  119. tpu_inference/distributed/__init__.py +13 -0
  120. tpu_inference/distributed/jax_parallel_state.py +81 -0
  121. tpu_inference/distributed/tpu_connector.py +732 -0
  122. tpu_inference/distributed/utils.py +112 -0
  123. tpu_inference/env_override.py +9 -0
  124. tpu_inference/envs.py +191 -0
  125. tpu_inference/executors/__init__.py +13 -0
  126. tpu_inference/executors/ray_distributed_executor.py +399 -0
  127. tpu_inference/experimental/__init__.py +13 -0
  128. tpu_inference/experimental/llama3_jax_stashed.py +272 -0
  129. tpu_inference/kernels/__init__.py +13 -0
  130. tpu_inference/kernels/collectives/__init__.py +13 -0
  131. tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
  132. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
  133. tpu_inference/kernels/collectives/util.py +47 -0
  134. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  135. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  136. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  137. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  138. tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
  139. tpu_inference/kernels/megablox/__init__.py +13 -0
  140. tpu_inference/kernels/megablox/common.py +54 -0
  141. tpu_inference/kernels/megablox/gmm.py +646 -0
  142. tpu_inference/kernels/mla/__init__.py +13 -0
  143. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  144. tpu_inference/kernels/mla/v1/kernel.py +1340 -0
  145. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  146. tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
  147. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  148. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  149. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  150. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  151. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
  152. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
  153. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  154. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  155. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
  156. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
  157. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
  158. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
  159. tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
  160. tpu_inference/layers/__init__.py +13 -0
  161. tpu_inference/layers/common/__init__.py +13 -0
  162. tpu_inference/layers/common/attention_interface.py +403 -0
  163. tpu_inference/layers/common/attention_metadata.py +48 -0
  164. tpu_inference/layers/common/binary_search.py +295 -0
  165. tpu_inference/layers/common/quant_methods.py +23 -0
  166. tpu_inference/layers/common/quantization.py +270 -0
  167. tpu_inference/layers/common/sharding.py +600 -0
  168. tpu_inference/layers/jax/__init__.py +13 -0
  169. tpu_inference/layers/jax/attention/__init__.py +13 -0
  170. tpu_inference/layers/jax/attention/attention.py +268 -0
  171. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
  172. tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
  173. tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
  174. tpu_inference/layers/jax/base.py +165 -0
  175. tpu_inference/layers/jax/constants.py +101 -0
  176. tpu_inference/layers/jax/layers.py +315 -0
  177. tpu_inference/layers/jax/misc.py +30 -0
  178. tpu_inference/layers/jax/moe/__init__.py +13 -0
  179. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
  180. tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
  181. tpu_inference/layers/jax/moe/moe.py +249 -0
  182. tpu_inference/layers/jax/pp_utils.py +53 -0
  183. tpu_inference/layers/jax/rope.py +294 -0
  184. tpu_inference/layers/jax/rope_interface.py +228 -0
  185. tpu_inference/layers/jax/sample/__init__.py +13 -0
  186. tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
  187. tpu_inference/layers/jax/sample/sampling.py +110 -0
  188. tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
  189. tpu_inference/layers/jax/transformer_block.py +121 -0
  190. tpu_inference/layers/vllm/__init__.py +13 -0
  191. tpu_inference/layers/vllm/attention.py +221 -0
  192. tpu_inference/layers/vllm/fused_moe.py +502 -0
  193. tpu_inference/layers/vllm/linear_common.py +221 -0
  194. tpu_inference/layers/vllm/quantization/__init__.py +55 -0
  195. tpu_inference/layers/vllm/quantization/awq.py +221 -0
  196. tpu_inference/layers/vllm/quantization/common.py +124 -0
  197. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  198. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
  199. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
  200. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  201. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
  202. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
  203. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  204. tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
  205. tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
  206. tpu_inference/layers/vllm/sharding.py +244 -0
  207. tpu_inference/logger.py +10 -0
  208. tpu_inference/lora/__init__.py +13 -0
  209. tpu_inference/lora/torch_lora_ops.py +98 -0
  210. tpu_inference/lora/torch_punica_tpu.py +310 -0
  211. tpu_inference/models/__init__.py +13 -0
  212. tpu_inference/models/common/__init__.py +13 -0
  213. tpu_inference/models/common/model_loader.py +520 -0
  214. tpu_inference/models/jax/__init__.py +13 -0
  215. tpu_inference/models/jax/deepseek_v3.py +978 -0
  216. tpu_inference/models/jax/gpt_oss.py +508 -0
  217. tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
  218. tpu_inference/models/jax/llama3.py +436 -0
  219. tpu_inference/models/jax/llama4.py +643 -0
  220. tpu_inference/models/jax/llama_eagle3.py +350 -0
  221. tpu_inference/models/jax/llama_guard_4.py +375 -0
  222. tpu_inference/models/jax/qwen2.py +390 -0
  223. tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
  224. tpu_inference/models/jax/qwen3.py +318 -0
  225. tpu_inference/models/jax/utils/__init__.py +13 -0
  226. tpu_inference/models/jax/utils/file_utils.py +110 -0
  227. tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
  228. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  229. tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
  230. tpu_inference/models/jax/utils/weight_utils.py +621 -0
  231. tpu_inference/models/vllm/__init__.py +13 -0
  232. tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
  233. tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
  234. tpu_inference/platforms/__init__.py +16 -0
  235. tpu_inference/platforms/tpu_platform.py +258 -0
  236. tpu_inference/runner/__init__.py +13 -0
  237. tpu_inference/runner/block_table.py +122 -0
  238. tpu_inference/runner/compilation_manager.py +890 -0
  239. tpu_inference/runner/input_batch.py +435 -0
  240. tpu_inference/runner/kv_cache.py +166 -0
  241. tpu_inference/runner/kv_cache_manager.py +508 -0
  242. tpu_inference/runner/lora_utils.py +106 -0
  243. tpu_inference/runner/multimodal_manager.py +231 -0
  244. tpu_inference/runner/persistent_batch_manager.py +296 -0
  245. tpu_inference/runner/speculative_decoding_manager.py +262 -0
  246. tpu_inference/runner/structured_decoding_manager.py +101 -0
  247. tpu_inference/runner/tpu_runner.py +1768 -0
  248. tpu_inference/runner/utils.py +426 -0
  249. tpu_inference/spec_decode/__init__.py +13 -0
  250. tpu_inference/spec_decode/jax/__init__.py +13 -0
  251. tpu_inference/spec_decode/jax/eagle3.py +430 -0
  252. tpu_inference/tpu_info.py +92 -0
  253. tpu_inference/utils.py +345 -0
  254. tpu_inference/worker/__init__.py +13 -0
  255. tpu_inference/worker/tpu_worker.py +468 -0
  256. tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
  257. tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
  258. tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
  259. tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
  260. tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
@@ -0,0 +1,713 @@
+ # SPDX-License-Identifier: Apache-2.0
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+ import copy
+ import functools
+ import os
+ from typing import TYPE_CHECKING, Callable, List
+
+ import jax
+ import jax.numpy as jnp
+ import qwix
+ import qwix.pallas as qpl
+ import yaml
+ from flax import nnx
+ from flax.typing import PRNGKey
+ from jax.sharding import Mesh, NamedSharding
+ from jax.sharding import PartitionSpec as P
+ from qwix._src.core.qarray import QArray
+ from qwix._src.providers import ptq
+
+ if TYPE_CHECKING:
+     from vllm.config import VllmConfig
+
+ from tpu_inference import utils
+ from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+ from tpu_inference.logger import init_logger
+ from tpu_inference.runner.kv_cache import (DEFAULT_KV_CACHE_DTYPE,
+                                            create_kv_caches)
+ from tpu_inference.utils import device_array
+
+ logger = init_logger(__name__)
+
+ QUANTIZATION_CONFIG_PATH = os.path.join(os.path.dirname(__file__), "configs")
+ DEFAULT_NUM_BLOCKS_FOR_JIT_KV_CACHE = 2000
+ DEFAULT_NUM_TOKENS_FOR_MODEL_INPUTS = 512
+ DEFAULT_MAX_NUM_SEQS_FOR_MODEL_INPUTS = 256
+ DEFAULT_MAX_NUM_BLOCKS_PER_REQ = 16
+
+ DEFAULT_DEEPSEEK_FP8_CONFIG = {
+     "qwix": {
+         "use_abstract_model": True,
+         "scale_dtype": "bfloat16",
+         "rules": [
+             # Exclude the router from quantization.
+             {
+                 "module_path": ".*.custom_module.router.*",
+                 "weight_qtype": None,
+             },
+             # Avoid the combine-experts ops.
+             {
+                 "module_path": ".*combine_experts.*",
+                 "weight_qtype": None,
+             },
+             # Attention layers: keep FP8 for weights and activations.
+             {
+                 "module_path": ".*.attn.*",
+                 "weight_qtype": "float8_e4m3fn",
+                 "act_qtype": "float8_e4m3fn",
+             },
+             # MoE experts: use FP4 for the expert weights.
+             {
+                 "module_path": ".*.custom_module.*",
+                 "weight_qtype": "float4_e2m1fn",
+                 "act_qtype": "float8_e4m3fn",
+                 "tile_size": 256,
+             },
+             # Shared experts: also FP4.
+             {
+                 "module_path": ".*.shared_experts.*",
+                 "weight_qtype": "float4_e2m1fn",
+                 "act_qtype": "float8_e4m3fn",
+                 "tile_size": 256,
+             },
+             {
+                 "module_path": ".*",
+                 "weight_qtype": "float8_e4m3fn",
+                 "act_qtype": "float8_e4m3fn",
+             },
+         ],
+     }
+ }
+
+ DEFAULT_LLAMA4_FP8_CONFIG = {
+     "qwix": {
+         "use_abstract_model": True,
+         "scale_dtype": "bfloat16",
+         "rules": [
+             {
+                 "module_path": "layers.*.moe_ffw",
+                 "op_names": "einsum",
+                 "weight_qtype": "float8_e4m3fn",
+                 "act_qtype": "float8_e4m3fn",
+             },
+         ],
+     }
+ }
+
+ # Default Qwix config for GPT-OSS MXFP4 checkpoints.
+ # Notes:
+ # - Only the MoE expert weights are quantized by default (the router stays in
+ #   BF16).
+ # - We use Qwix's abstract-model path so weights can be set directly into
+ #   QArray fields during weight loading (similar to DeepSeek's flow).
+ # - Activation quantization is not set, but Qwix would pick up the MoE sum if
+ #   it were enabled.
+ DEFAULT_GPT_OSS_FP4_CONFIG = {
+     "qwix": {
+         "use_abstract_model": True,
+         "scale_dtype": "bfloat16",
+         "rules": [
+             {
+                 "module_path": ".*custom_module",
+                 "weight_qtype": "float4_e2m1fn",
+                 "act_qtype": None,
+                 "tile_size": 32,
+             },
+         ],
+     }
+ }
+
+
+ def parse_qwix_config_to_rules(
+         qwix_config: List[dict]) -> List[qwix.QuantizationRule]:
+     """
+     Parses a list of dictionaries of Qwix quantization rules into a list of
+     QuantizationRule objects.
+
+     Args:
+         qwix_config: a list of dictionaries, one per Qwix quantization rule
+
+     Returns:
+         a list of QuantizationRule objects
+     """
+     rules = []
+     for rule in qwix_config:
+         rules.append(qwix.QuantizationRule(**rule))
+
+     return rules
+
+
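The rule dictionaries above are consumed verbatim as keyword arguments to `qwix.QuantizationRule`. A minimal sketch of what that looks like for the GPT-OSS default config defined earlier (illustrative only, not part of the packaged file):

from tpu_inference.models.jax.utils.qwix.qwix_utils import (
    DEFAULT_GPT_OSS_FP4_CONFIG, parse_qwix_config_to_rules)

# The single GPT-OSS rule targets modules matching ".*custom_module" and asks
# for float4_e2m1fn weights with a tile size of 32, leaving activations alone.
rules = parse_qwix_config_to_rules(DEFAULT_GPT_OSS_FP4_CONFIG["qwix"]["rules"])
assert len(rules) == 1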
+ def qwix_quantize_nnx_model(model: nnx.Module, qwix_config: List[dict],
+                             rng: jax.Array, mesh: Mesh, num_hidden_layers: int,
+                             kv_cache_block_size: int,
+                             kv_cache_num_kv_heads: int,
+                             kv_cache_head_size: int,
+                             kv_cache_dtype: str) -> nnx.Module:
+     """
+     Quantizes a Flax NNX model using Qwix.
+
+     Args:
+         model: the model to quantize
+         qwix_config: a list of dictionaries, where each dictionary corresponds
+             to a Qwix quantization rule. For example:
+             [
+                 {
+                     "module_path": ".*attn.*",
+                     "weight_qtype": "int8",
+                 },
+                 {
+                     "module_path": ".*mlp.*",
+                     "weight_qtype": "int8",
+                     "act_qtype": "int8",
+                     "tile_size": None,
+                 },
+             ]
+         rng: the random number generator to use
+         mesh: the mesh to use
+         num_hidden_layers: the number of hidden layers in the model
+         kv_cache_block_size: the block size of the kv cache
+         kv_cache_num_kv_heads: the number of kv heads
+         kv_cache_head_size: the head size of the kv cache
+         kv_cache_dtype: the dtype of the kv cache
+
+     Returns:
+         model: the quantized model
+     """
+     qwix_rules = parse_qwix_config_to_rules(qwix_config)
+     logger.info(f"Qwix rules: {qwix_rules}")
+     logger.info(f"Memory usage before applying quantization of params: "
+                 f"hbm={utils.hbm_usage_gb(jax.local_devices())}Gb")
+
+     if kv_cache_dtype != "auto":
+         kv_cache_jnp_dtype = utils.to_jax_dtype(kv_cache_dtype)
+     else:
+         kv_cache_jnp_dtype = DEFAULT_KV_CACHE_DTYPE
+
+     kv_caches = create_kv_caches(
+         num_blocks=DEFAULT_NUM_BLOCKS_FOR_JIT_KV_CACHE,
+         block_size=kv_cache_block_size,
+         num_kv_heads=kv_cache_num_kv_heads,
+         head_size=kv_cache_head_size,
+         mesh=mesh,
+         layer_names=[f"layer.{i}" for i in range(num_hidden_layers)],
+         cache_dtype=kv_cache_jnp_dtype,
+         use_mla=model.vllm_config.model_config.use_mla,
+     )
+
+     dp_size = model.vllm_config.sharding_config.total_dp_size
+
+     # NOTE: the inputs don't need to match the actual ones, as long as the
+     # consumed weights are the same.
+     input_ids = jax.random.randint(rng,
+                                    (DEFAULT_NUM_TOKENS_FOR_MODEL_INPUTS, ),
+                                    0, 100, dtype=jnp.int32)
+     positions = jax.random.randint(rng,
+                                    (DEFAULT_NUM_TOKENS_FOR_MODEL_INPUTS, ),
+                                    0, 100, dtype=jnp.int32)
+     block_tables = jax.random.randint(rng,
+                                       (DEFAULT_MAX_NUM_SEQS_FOR_MODEL_INPUTS *
+                                        DEFAULT_MAX_NUM_BLOCKS_PER_REQ, ),
+                                       0, 100, dtype=jnp.int32)
+     query_start_loc = jax.random.randint(
+         rng, (DEFAULT_MAX_NUM_SEQS_FOR_MODEL_INPUTS + dp_size, ),
+         0, 100, dtype=jnp.int32)
+     seq_lens = jax.random.randint(rng,
+                                   (DEFAULT_MAX_NUM_SEQS_FOR_MODEL_INPUTS, ),
+                                   0, 100, dtype=jnp.int32)
+     num_seqs = jax.random.randint(rng, (1, ), 0, 100, dtype=jnp.int32)
+     request_distribution = jnp.array([0, 0, num_seqs[0]] * dp_size,
+                                      dtype=jnp.int32)
+
+     (input_ids, positions, block_tables,
+      query_start_loc, seq_lens, request_distribution) = device_array(
+          mesh, (input_ids, positions, block_tables, query_start_loc, seq_lens,
+                 request_distribution))
+
+     model_input = {
+         "kv_caches": kv_caches,
+         "input_ids": input_ids,
+         "attention_metadata": AttentionMetadata(
+             input_positions=positions,
+             block_tables=block_tables,
+             seq_lens=seq_lens,
+             query_start_loc=query_start_loc,
+             request_distribution=request_distribution),
+     }
+     model = qwix.quantize_model(model, qwix.PtqProvider(qwix_rules),
+                                 **model_input)
+     return model
+
+
+ def quantization_config_file_path_to_dict(
+         quantization_config_file_path: str) -> dict:
+     """
+     Converts a quantization config YAML file path to a dictionary.
+
+     The expected format of the quantization config YAML file is as follows:
+     ```yaml
+     qwix:
+       # optional, defaults to False if not specified
+       use_abstract_model: True
+       rules:
+         # NOTE: each entry corresponds to a qwix.QuantizationRule
+         - module_path: '.*attn.*'
+           weight_qtype: 'int8'
+         - module_path: '.*'
+           weight_qtype: 'int8'
+           act_qtype: 'int8'
+     ```
+
+     Args:
+         quantization_config_file_path: the path to the quantization config YAML
+             file
+
+     Returns:
+         a dictionary containing the quantization config
+     """
+     all_entries = os.listdir(QUANTIZATION_CONFIG_PATH)
+     for filename in all_entries:
+         if filename == quantization_config_file_path:
+             path = os.path.join(QUANTIZATION_CONFIG_PATH, filename)
+             with open(path, "r") as f:
+                 return yaml.safe_load(f)
+     raise ValueError(
+         "Could not find quantization config file with name "
+         f"'{quantization_config_file_path}' in "
+         "'tpu_inference/models/jax/utils/qwix/configs'.")
+
+
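Whether it arrives as a YAML file name or as an inline dict, the value ends up under `additional_config["quantization"]` and is unpacked by `update_vllm_config_for_qwix_quantization` further down in this file. A hedged sketch of the two accepted shapes (the file name is hypothetical, and how `additional_config` is surfaced on the vLLM side depends on the vLLM version):

# 1) The bare name of a YAML file shipped in this package's configs/ directory;
#    the string is resolved by quantization_config_file_path_to_dict().
quantization_as_file_name = "int8_default.yaml"  # hypothetical file name

# 2) An inline dict mirroring the YAML layout shown in the docstring above.
quantization_as_dict = {
    "qwix": {
        "use_abstract_model": False,
        "scale_dtype": "bfloat16",
        "rules": [
            {"module_path": ".*attn.*", "weight_qtype": "int8"},
            {"module_path": ".*", "weight_qtype": "int8", "act_qtype": "int8"},
        ],
    }
}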
+ def apply_qwix_quantization(
+         vllm_config: "VllmConfig", model_or_model_fn: Callable | nnx.Module,
+         rng: jax.Array, mesh: Mesh,
+         apply_to_abstract_model: bool) -> nnx.Module | Callable:
+     """
+     Applies quantization if a valid quantization config with Qwix rules is
+     provided. See the README for more details on Qwix.
+
+     Note that we currently support different methods for applying Qwix
+     quantization. The typical approach is to apply quantization to the concrete
+     model, which already has the weights loaded in. However, for models like
+     DeepSeek, which are already quantized, we need to first create the abstract
+     model, then apply Qwix quantization to the abstract model, and finally load
+     the weights in. To use the latter approach, you will need to modify the
+     model weight loading code appropriately (see deepseek_v3.py for an example)
+     and pass `use_abstract_model=True` in the quantization config.
+
+     Args:
+         vllm_config: the base vLLM config
+         model_or_model_fn: if `apply_to_abstract_model` is True, this is a
+             Callable that returns the abstract model (e.g.
+             _create_abstract_model). Otherwise, this is the concrete model
+             (nnx.Module).
+         rng: JAX RNG
+         mesh: model Mesh
+         apply_to_abstract_model: if True, Qwix quantization is applied to the
+             abstract model, which assumes that, during weight loading, the
+             caller will override the QArray weights (see deepseek_v3.py for an
+             example). Otherwise, Qwix quantization is applied to the concrete
+             model, which already has the weights loaded in.
+
+     Returns:
+         Either the concrete model (nnx.Module) or, if `apply_to_abstract_model`
+         is True, a Callable that builds the quantized abstract model.
+     """
+     qwix_config = None
+     if quantization_config := vllm_config.additional_config.get(
+             "quantization"):
+         qwix_config = quantization_config.get("qwix").get("rules")
+     if not qwix_config:
+         return model_or_model_fn
+
+     logging_abstract_model_str = "abstract" if apply_to_abstract_model else "concrete"
+     logger.info(
+         f"Applying Qwix quantization on {logging_abstract_model_str} model")
+
+     block_size = vllm_config.cache_config.block_size
+     model_config = vllm_config.model_config
+
+     # Pad num_kv_heads to a multiple of the TP size.
+     num_kv_heads = utils.get_padded_num_heads(
+         model_config.get_total_num_kv_heads(), mesh.shape["model"])
+
+     # Pad head_dim to a multiple of 128.
+     head_size = model_config.get_head_size()
+     head_size = utils.get_padded_head_dim(head_size)
+
+     kv_cache_dtype = vllm_config.cache_config.cache_dtype
+
+     if not apply_to_abstract_model:
+         assert isinstance(model_or_model_fn, nnx.Module)
+         qwix_quantize_nnx_model_with_config = functools.partial(
+             qwix_quantize_nnx_model, qwix_config=qwix_config)
+         # NOTE: it's REALLY important that `qwix_quantize_nnx_model_with_config`
+         # is jitted, or else you'll run into hangs.
+         model_or_model_fn = nnx.jit(
+             qwix_quantize_nnx_model_with_config,
+             donate_argnums=(0, ),
+             static_argnames=(
+                 "mesh",
+                 "num_hidden_layers",
+                 "kv_cache_block_size",
+                 "kv_cache_num_kv_heads",
+                 "kv_cache_head_size",
+                 "kv_cache_dtype",
+             ))(model=model_or_model_fn,
+                rng=rng,
+                mesh=mesh,
+                num_hidden_layers=vllm_config.model_config.hf_config.
+                num_hidden_layers,
+                kv_cache_block_size=block_size,
+                kv_cache_num_kv_heads=num_kv_heads,
+                kv_cache_head_size=head_size,
+                kv_cache_dtype=kv_cache_dtype)
+
+         return model_or_model_fn
+
+     hf_config = vllm_config.model_config.hf_config
+     if hasattr(hf_config, "text_config") and hasattr(hf_config.text_config,
+                                                      "num_hidden_layers"):
+         num_hidden_layers = hf_config.text_config.num_hidden_layers
+         logger.info(
+             f"Using num_hidden_layers from hf_config.text_config: {num_hidden_layers}"
+         )
+     elif hasattr(hf_config, "num_hidden_layers"):
+         num_hidden_layers = hf_config.num_hidden_layers
+         logger.info(
+             f"Using num_hidden_layers directly from hf_config: {num_hidden_layers}"
+         )
+     else:
+         raise AttributeError(
+             "Could not find 'num_hidden_layers' in hf_config or hf_config.text_config."
+         )
+
+     qwix_quantize_fn_for_eval = functools.partial(
+         qwix_quantize_nnx_model,
+         qwix_config=qwix_config,
+         mesh=mesh,
+         num_hidden_layers=num_hidden_layers,
+         kv_cache_block_size=block_size,
+         kv_cache_num_kv_heads=num_kv_heads,
+         kv_cache_head_size=head_size,
+         kv_cache_dtype=kv_cache_dtype)
+
+     def create_and_quantize_model_factory() -> nnx.Module:
+         """
+         Helper function to create and quantize the abstract model.
+         """
+         model = model_or_model_fn()
+         # Handle the DeepSeek case, where this needs to be called on the
+         # abstract model.
+         if hasattr(model, 'initialize_cache'):
+             model.initialize_cache()
+         return qwix_quantize_fn_for_eval(model=model, rng=rng)
+
+     return create_and_quantize_model_factory
+
+
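A hedged sketch of how a caller might tie the two paths together with `apply_qwix_on_abstract_model`, defined next (the wrapper function and argument names here are illustrative stand-ins, not the package's actual loader API):

from flax import nnx

from tpu_inference.models.jax.utils.qwix.qwix_utils import (
    apply_qwix_on_abstract_model, apply_qwix_quantization)


def build_quantized_model(vllm_config, model_or_factory, rng, mesh) -> nnx.Module:
    """Illustrative wrapper around the two quantization paths above."""
    if apply_qwix_on_abstract_model(vllm_config):
        # Pre-quantized checkpoints (e.g. DeepSeek): quantize the abstract
        # model first; the weight loader then fills in the QArray fields.
        quantized_factory = apply_qwix_quantization(
            vllm_config, model_or_factory, rng, mesh,
            apply_to_abstract_model=True)
        return quantized_factory()
    # Typical path: the concrete, weight-loaded model is quantized directly.
    return apply_qwix_quantization(vllm_config, model_or_factory, rng, mesh,
                                   apply_to_abstract_model=False)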
+ def apply_qwix_on_abstract_model(vllm_config: "VllmConfig") -> bool:
+     """
+     Determines whether to apply Qwix quantization on the abstract model (e.g.
+     for DeepSeek) or on the concrete model. See `apply_qwix_quantization` for
+     more details on the differences between these two approaches.
+
+     Args:
+         vllm_config: the vLLM config
+
+     Returns:
+         whether to apply Qwix quantization on the abstract model
+     """
+     quantization_config = vllm_config.additional_config.get("quantization", {})
+     return quantization_config.get("qwix", {}).get("use_abstract_model", False)
+
+
+ def get_default_qwix_quantization_config(
+         hf_config: dict, skip_quantization: bool) -> dict | None:
+     """
+     Some models are pre-quantized, and in those cases we want to return a
+     default set of Qwix quantization rules (instead of forcing the user to pass
+     in a quantization config each time).
+
+     Note that if a user passes in a quantization config (via
+     `additional_config`), we'll use that instead of this function.
+
+     Args:
+         hf_config: the HuggingFace model config; the model type and
+             quantization method are read from it
+         skip_quantization: whether to skip quantization. In this case, we'll
+             return None.
+
+     Returns:
+         a dictionary containing the default Qwix quantization rules, or None if
+         there is no default for this model
+     """
+     if skip_quantization:
+         return None
+     model_type = hf_config.model_type.lower() if hasattr(
+         hf_config, "model_type") else None
+     quant_method = hf_config.quantization_config["quant_method"] if hasattr(
+         hf_config, "quantization_config") else None
+     # TODO (jacobplatin): remove this so that we can support various
+     # quantization types and make this more flexible.
+     # NOTE (jacobplatin): we'll default to mixed FP8 (attention) + FP4 (MoE
+     # experts) for DeepSeek.
+     if model_type == "deepseek_v3" and quant_method == "fp8":
+         config = copy.deepcopy(DEFAULT_DEEPSEEK_FP8_CONFIG)
+
+         # Dynamically fetch the block size from the HF config if available.
+         # Config fmt: 'weight_block_size': [1, 512] -> we want the 2nd dim for
+         # tile_size.
+         # NOTE: if the checkpoint is not 1D sub-channel, we will throw an error.
+         hf_quant_config = hf_config.quantization_config
+         assert "weight_block_size" in hf_quant_config, "Expected weight_block_size in quantization_config"
+         block_size = hf_quant_config["weight_block_size"]
+         if isinstance(block_size, (list, tuple)) and len(block_size) == 2:
+             assert block_size[
+                 0] == 1, f"Expected first dimension to be 1 (unchanneled), but got {block_size[0]}!"
+             tile_size = block_size[1]
+             assert tile_size > 1, f"Expected tile_size > 1 for DeepSeek, but got {tile_size}"
+             logger.info(
+                 f"Detected DeepSeek tile_size from config: {tile_size}")
+
+             # Update tile_size in the rules, since we might not always use a
+             # 1D sub-channel size of 256.
+             for rule in config["qwix"]["rules"]:
+                 if "tile_size" in rule:
+                     rule["tile_size"] = tile_size
+         else:
+             raise ValueError(
+                 f"Invalid weight_block_size config: {block_size}, expected a list/tuple of length 2"
+             )
+
+         return config
+     elif model_type == "llama4" and quant_method == "compressed-tensors":
+         return DEFAULT_LLAMA4_FP8_CONFIG
+     # MXFP4 (GPT-OSS): provide a default configuration to quantize MoE experts
+     # via Qwix.
+     elif model_type == "gpt_oss" and quant_method == "mxfp4":
+         return DEFAULT_GPT_OSS_FP4_CONFIG
+
+
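To make the tile-size handling concrete, a small sketch with a hypothetical DeepSeek V3 FP8 checkpoint config (`SimpleNamespace` stands in for the real HF config object):

from types import SimpleNamespace

from tpu_inference.models.jax.utils.qwix.qwix_utils import \
    get_default_qwix_quantization_config

# Hypothetical HF config: FP8 checkpoint with 1-D sub-channel blocks of 512.
hf_config = SimpleNamespace(
    model_type="deepseek_v3",
    quantization_config={"quant_method": "fp8", "weight_block_size": [1, 512]},
)
config = get_default_qwix_quantization_config(hf_config, skip_quantization=False)
# The FP4 expert rules inherit the checkpoint's sub-channel size.
assert all(rule["tile_size"] == 512
           for rule in config["qwix"]["rules"] if "tile_size" in rule)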
+ def update_vllm_config_for_qwix_quantization(vllm_config: "VllmConfig"):
+     """
+     Updates the vLLM config to unpack the Qwix quantization config if it exists.
+     By default, we check whether the checkpoint is quantized and, if so, fill
+     in the default Qwix quantization config; a quantization config passed by
+     the user via `additional_config` takes precedence over the default.
+     """
+     # Automatically detect whether the checkpoint is quantized and update the
+     # Qwix quantization config accordingly.
+     # NOTE: if a Qwix config is provided (via `additional_config`), we'll use
+     # that instead.
+     hf_config = vllm_config.model_config.hf_config
+     default_quantization_config = get_default_qwix_quantization_config(
+         hf_config, vllm_config.additional_config.get("skip_quantization",
+                                                      False))
+
+     maybe_existing_quantization_config = vllm_config.additional_config.get(
+         "quantization")
+     if maybe_existing_quantization_config:
+         logger.warning("Overwriting default Qwix quantization config with "
+                        "user provided quantization config.")
+     elif default_quantization_config is not None:
+         vllm_config.additional_config[
+             "quantization"] = default_quantization_config
+
+     # Validate the additional config.
+     if additional_config := vllm_config.additional_config:
+         # Try loading/parsing the quantization config so that we can fail fast.
+         if quantization_config := additional_config.get("quantization"):
+             try:
+                 # NOTE: Qwix quantization supports two paths:
+                 # 1. a quantization config file name (which we need to resolve
+                 #    and parse into a dictionary)
+                 # 2. an inline quantization config dict (JSON)
+                 if isinstance(quantization_config, str):
+                     quantization_config = quantization_config_file_path_to_dict(
+                         quantization_config)
+                 # NOTE: unpack the quantization config now so we don't need to
+                 # keep doing this every time.
+                 vllm_config.additional_config[
+                     "quantization"] = quantization_config
+                 parse_qwix_config_to_rules(
+                     quantization_config["qwix"]["rules"])
+             except Exception as e:
+                 raise ValueError(
+                     "Invalid quantization config; please see the README for "
+                     f"details on the quantization config format: {e}")
+
+
+ def get_random_sharded_array(key: PRNGKey, mesh: Mesh, param: nnx.Param,
+                              param_shape: tuple, dtype: jnp.dtype,
+                              param_name: str) -> jax.Array:
+     """
+     Returns a random sharded array of the given shape for the given parameter.
+
+     Args:
+         key: The random key.
+         mesh: The mesh to use for sharding.
+         param: The parameter.
+         param_shape: The shape of the parameter.
+         dtype: The dtype of the parameter.
+         param_name: The name of the parameter.
+
+     Returns:
+         A random sharded array of the given shape for the given parameter.
+     """
+     is_int = jnp.issubdtype(dtype, jnp.integer)
+     if is_int:
+         # These need to be JAX arrays or else you'll run into an overflow error.
+         minval = jnp.array(jnp.iinfo(dtype).min, dtype=dtype)
+         maxval = jnp.array(jnp.iinfo(dtype).max, dtype=dtype)
+         weight = jax.random.randint(key, param_shape, minval, maxval, dtype)
+     else:
+         # NOTE: _uniform() in random.py does not accept float4_e2m1fn.
+         # Error: "TypeError: uniform only accepts 8-, 16-, 32-, or 64-bit
+         # dtypes, got float4_e2m1fn."
+         # Workaround: call the function with dtype jnp.float8_e4m3fn and cast
+         # back to float4_e2m1fn.
+         if dtype != "float4_e2m1fn":
+             weight = jax.random.normal(key, param_shape, dtype)
+         else:
+             weight = jax.random.normal(key, param_shape,
+                                        jnp.float8_e4m3fn).astype(dtype)
+
+     def get_slice(index):
+         return weight[index]
+
+     try:
+         sharded_array = jax.make_array_from_callback(
+             param_shape, NamedSharding(mesh, P(*param.sharding)), get_slice)
+     except (ValueError, TypeError):
+         logger.warning(
+             f"Could not create sharded scale for {param_name} with shape "
+             f"{param_shape} and sharding {param.sharding}, skipping sharding..."
+         )
+         sharded_array = jax.make_array_from_callback(param_shape,
+                                                      NamedSharding(mesh, P()),
+                                                      get_slice)
+
+     return sharded_array
+
+
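The float4 workaround above can be exercised on its own; a minimal sketch, assuming a JAX build that exposes `jnp.float4_e2m1fn` (which this module already relies on elsewhere):

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
# jax.random.normal rejects 4-bit dtypes, so sample in float8_e4m3fn and cast.
weight = jax.random.normal(key, (128, 256),
                           jnp.float8_e4m3fn).astype(jnp.float4_e2m1fn)
assert weight.dtype == jnp.float4_e2m1fn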
+ def load_random_weights_into_qwix_abstract_model(rng: PRNGKey,
+                                                  model: nnx.Module, mesh: Mesh,
+                                                  quantization_config: dict):
+     """
+     Loads random weights into an abstract, Qwix-quantized model.
+
+     Args:
+         rng: The random key.
+         model: The model.
+         mesh: The mesh.
+         quantization_config: The quantization config for the model.
+     """
+     logger.info("Initializing Qwix-quantized model with random weights...")
+     # TODO (jacobplatin): clean up this logic.
+     scale_dtype = model.weight_loader.scale_dtype
+     scale_shape_map = model.weight_loader.scale_shape_map_for_random_weight_loading if hasattr(
+         model.weight_loader,
+         'scale_shape_map_for_random_weight_loading') else {}
+     quantization_block_sizes = quantization_config["weight_block_size"]
+     assert len(
+         quantization_block_sizes
+     ) == 2, f"Expected only 2 quantization block sizes but got {quantization_block_sizes}"
+
+     # Iterate through all variables and initialize them.
+     for path, param in nnx.iter_graph(model):
+         if not isinstance(param, nnx.Variable):
+             continue
+         if path[0] == 'rng' and path[-1] == "key":
+             param.value = rng
+             continue
+         is_qwix_scale = (path[-1] == 'scale' and path[-2] == "array")
+         param_dtype = scale_dtype if is_qwix_scale else param.value.dtype
+         param_shape = param.value.shape
+         if is_qwix_scale:
+             key = f"{path[2]}.{path[3]}"
+
+             if key in scale_shape_map:
+                 param_shape = scale_shape_map[key]
+             else:
+                 raise ValueError(
+                     f"Scale shape for {key} not found in scale_shape_map.")
+         param.value = get_random_sharded_array(
+             rng, mesh, param, param_shape, param_dtype,
+             ".".join([str(x) for x in path]))
+
+     # Handles the DeepSeek case, where this needs to be called to make the
+     # cache weights concrete.
+     if hasattr(model, 'initialize_cache'):
+         model.initialize_cache()
+     logger.info("Done initializing Qwix-quantized model with random weights")
+
+
+ def manually_quantize_qwix_weight(weight: jax.Array, qtype: jnp.dtype,
+                                   channelwise_axes: List[int],
+                                   tiled_axes: dict,
+                                   calibration_method: str) -> QArray:
+     """
+     Manually quantizes a weight tensor using Qwix. Only needed for the
+     SparseMatmul DeepSeek case right now; otherwise Qwix handles this
+     automatically (through our application of `qwix.quantize_model`).
+     """
+     # TODO (jacobplatin): clean this up; this is needed because of issues with
+     # Qwix quantizing the `shard_map` in SparseMatmul.
+     how_to_quantize = ptq.qarray.HowToQuantize(
+         qtype=qtype,
+         channelwise_axes=channelwise_axes,
+         tiled_axes=tiled_axes,
+         calibration_method=calibration_method)
+
+     return ptq.create_quantized_param(weight, how_to_quantize)
+
+
+ def manually_quantize_qwix_activation(inputs: jax.Array, rule_name: str,
+                                       qtype: jnp.dtype,
+                                       channelwise_axes: List[int],
+                                       tiled_axes: dict,
+                                       calibration_method: str) -> QArray:
+     """
+     Manually quantizes an activation tensor using Qwix. Currently needed for
+     the SparseMatmul DeepSeek MoE case.
+
+     Args:
+         inputs: The activation tensor to quantize.
+         rule_name: The name of the quantization rule to use.
+         qtype: The quantization type.
+         channelwise_axes: The channelwise axes to quantize.
+         tiled_axes: The tiled axes to quantize.
+         calibration_method: The calibration method to use.
+
+     Returns:
+         The quantized activation tensor.
+     """
+     rule = qpl.get_current_rule(rule_name)
+     lhs_how = ptq.qarray.HowToQuantize(qtype=qtype,
+                                        channelwise_axes=channelwise_axes,
+                                        tiled_axes=tiled_axes,
+                                        calibration_method=calibration_method)
+     # This is needed because we aren't passing `act_name` right now.
+     assert not rule.act_static_scale, "Static scale not supported right now"
+
+     # channelwise_axes should be set to (a subset of) the non-contraction axes,
+     # e.g. for ragged_dot [m, k] x [g, k, n], they are [0] and [0, 2].
+     # TODO (jacobplatin): add support for `act_name`.
+     return ptq.quantize_act(inputs, lhs_how, rule, "")
+
+
+ def get_quant_dtype_from_qwix_config(
+         vllm_config: "VllmConfig") -> tuple[jnp.dtype, jnp.dtype]:
+     """
+     Gets the quantization dtype from the Qwix config.
+
+     Args:
+         vllm_config: The VllmConfig object.
+
+     Returns:
+         A tuple of the scale dtype and quant dtype.
+     """
+     qwix_config = vllm_config.additional_config.get("quantization",
+                                                     {}).get("qwix", {})
+     scale_dtype = getattr(jnp, qwix_config.get("scale_dtype", "bfloat16"))
+     quant_dtype = None
+     # TODO (jacobplatin): this needs to be much more robust.
+     for rule in qwix_config.get("rules", []):
+         if rule.get("module_path") == ".*":
+             quant_dtype_str = rule.get("weight_qtype", "")
+             assert quant_dtype_str, (
+                 "Quantization dtype not found in Qwix config! We currently "
+                 "expect your Qwix config to have a rule with module_path '.*' "
+                 "and a weight_qtype.")
+             quant_dtype = getattr(jnp, quant_dtype_str)
+     return scale_dtype, quant_dtype
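As a closing illustration, the dtype lookup above keys off the catch-all `module_path: '.*'` rule; with the default DeepSeek config it resolves as sketched below (`SimpleNamespace` is a stand-in for `VllmConfig`, since only `additional_config` is read):

from types import SimpleNamespace

import jax.numpy as jnp

from tpu_inference.models.jax.utils.qwix.qwix_utils import (
    DEFAULT_DEEPSEEK_FP8_CONFIG, get_quant_dtype_from_qwix_config)

fake_vllm_config = SimpleNamespace(
    additional_config={"quantization": DEFAULT_DEEPSEEK_FP8_CONFIG})
scale_dtype, quant_dtype = get_quant_dtype_from_qwix_config(fake_vllm_config)
assert scale_dtype == jnp.bfloat16
assert quant_dtype == jnp.float8_e4m3fn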