tpu-inference 0.11.1.dev202511150811__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release. This version of tpu-inference might be problematic.

Files changed (179)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +53 -0
  6. tests/core/test_dp_scheduler.py +899 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/fused_moe_v1_test.py +105 -0
  10. tests/kernels/mla_v1_test.py +396 -0
  11. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  12. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  13. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  14. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
  15. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  16. tests/lora/__init__.py +0 -0
  17. tests/lora/conftest.py +32 -0
  18. tests/lora/test_bgmv.py +43 -0
  19. tests/lora/test_layers.py +654 -0
  20. tests/lora/test_lora.py +133 -0
  21. tests/lora/utils.py +96 -0
  22. tests/test_base.py +201 -0
  23. tests/test_envs.py +182 -0
  24. tests/test_quantization.py +836 -0
  25. tests/test_tpu_info.py +120 -0
  26. tests/test_utils.py +236 -0
  27. tpu_inference/__init__.py +34 -0
  28. tpu_inference/core/__init__.py +0 -0
  29. tpu_inference/core/core_tpu.py +786 -0
  30. tpu_inference/core/disagg_executor.py +118 -0
  31. tpu_inference/core/disagg_utils.py +51 -0
  32. tpu_inference/core/sched/__init__.py +0 -0
  33. tpu_inference/core/sched/dp_scheduler.py +523 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/jax_parallel_state.py +67 -0
  36. tpu_inference/distributed/tpu_connector.py +728 -0
  37. tpu_inference/distributed/utils.py +59 -0
  38. tpu_inference/env_override.py +9 -0
  39. tpu_inference/envs.py +107 -0
  40. tpu_inference/executors/__init__.py +0 -0
  41. tpu_inference/executors/ray_distributed_executor.py +362 -0
  42. tpu_inference/experimental/__init__.py +0 -0
  43. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  44. tpu_inference/kernels/__init__.py +0 -0
  45. tpu_inference/kernels/collectives/__init__.py +0 -0
  46. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  47. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  48. tpu_inference/kernels/collectives/util.py +47 -0
  49. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  50. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  51. tpu_inference/kernels/fused_moe/__init__.py +0 -0
  52. tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
  53. tpu_inference/kernels/fused_moe/v1/kernel.py +1035 -0
  54. tpu_inference/kernels/mla/__init__.py +0 -0
  55. tpu_inference/kernels/mla/v1/__init__.py +0 -0
  56. tpu_inference/kernels/mla/v1/kernel.py +1349 -0
  57. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  58. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  59. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  60. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  61. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  62. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  66. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1478 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1482 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
  71. tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
  72. tpu_inference/layers/__init__.py +0 -0
  73. tpu_inference/layers/common/__init__.py +0 -0
  74. tpu_inference/layers/common/attention_interface.py +390 -0
  75. tpu_inference/layers/common/attention_metadata.py +34 -0
  76. tpu_inference/layers/common/binary_search.py +295 -0
  77. tpu_inference/layers/common/quant_methods.py +8 -0
  78. tpu_inference/layers/common/sharding.py +582 -0
  79. tpu_inference/layers/jax/__init__.py +0 -0
  80. tpu_inference/layers/jax/attention/__init__.py +0 -0
  81. tpu_inference/layers/jax/attention/attention.py +255 -0
  82. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  83. tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
  84. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  85. tpu_inference/layers/jax/base.py +151 -0
  86. tpu_inference/layers/jax/constants.py +88 -0
  87. tpu_inference/layers/jax/layers.py +301 -0
  88. tpu_inference/layers/jax/misc.py +16 -0
  89. tpu_inference/layers/jax/moe/__init__.py +0 -0
  90. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  91. tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
  92. tpu_inference/layers/jax/moe/moe.py +209 -0
  93. tpu_inference/layers/jax/rope.py +280 -0
  94. tpu_inference/layers/jax/rope_interface.py +214 -0
  95. tpu_inference/layers/jax/sample/__init__.py +0 -0
  96. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  97. tpu_inference/layers/jax/sample/sampling.py +96 -0
  98. tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
  99. tpu_inference/layers/jax/transformer_block.py +107 -0
  100. tpu_inference/layers/vllm/__init__.py +0 -0
  101. tpu_inference/layers/vllm/attention.py +221 -0
  102. tpu_inference/layers/vllm/fused_moe.py +507 -0
  103. tpu_inference/layers/vllm/linear_common.py +186 -0
  104. tpu_inference/layers/vllm/quantization/__init__.py +39 -0
  105. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  106. tpu_inference/layers/vllm/quantization/common.py +105 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  108. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
  109. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
  110. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  111. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  112. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  113. tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
  114. tpu_inference/layers/vllm/quantization/unquantized.py +386 -0
  115. tpu_inference/layers/vllm/sharding.py +230 -0
  116. tpu_inference/logger.py +10 -0
  117. tpu_inference/lora/__init__.py +0 -0
  118. tpu_inference/lora/torch_lora_ops.py +103 -0
  119. tpu_inference/lora/torch_punica_tpu.py +311 -0
  120. tpu_inference/mock/__init__.py +0 -0
  121. tpu_inference/mock/vllm_config_utils.py +28 -0
  122. tpu_inference/mock/vllm_envs.py +1219 -0
  123. tpu_inference/mock/vllm_logger.py +212 -0
  124. tpu_inference/mock/vllm_logging_utils.py +15 -0
  125. tpu_inference/models/__init__.py +0 -0
  126. tpu_inference/models/common/__init__.py +0 -0
  127. tpu_inference/models/common/model_loader.py +444 -0
  128. tpu_inference/models/jax/__init__.py +0 -0
  129. tpu_inference/models/jax/deepseek_v3.py +868 -0
  130. tpu_inference/models/jax/gpt_oss.py +492 -0
  131. tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
  132. tpu_inference/models/jax/llama3.py +375 -0
  133. tpu_inference/models/jax/llama4.py +629 -0
  134. tpu_inference/models/jax/llama_eagle3.py +333 -0
  135. tpu_inference/models/jax/phi3.py +376 -0
  136. tpu_inference/models/jax/qwen2.py +375 -0
  137. tpu_inference/models/jax/qwen2_5_vl.py +1103 -0
  138. tpu_inference/models/jax/qwen3.py +302 -0
  139. tpu_inference/models/jax/utils/__init__.py +0 -0
  140. tpu_inference/models/jax/utils/file_utils.py +96 -0
  141. tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
  142. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  143. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
  144. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
  145. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
  146. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
  147. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
  148. tpu_inference/models/jax/utils/quantization/quantization_utils.py +653 -0
  149. tpu_inference/models/jax/utils/weight_utils.py +529 -0
  150. tpu_inference/models/vllm/__init__.py +0 -0
  151. tpu_inference/models/vllm/vllm_model_wrapper.py +286 -0
  152. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  153. tpu_inference/platforms/__init__.py +2 -0
  154. tpu_inference/platforms/tpu_platform.py +269 -0
  155. tpu_inference/runner/__init__.py +0 -0
  156. tpu_inference/runner/block_table.py +122 -0
  157. tpu_inference/runner/compilation_manager.py +780 -0
  158. tpu_inference/runner/input_batch.py +435 -0
  159. tpu_inference/runner/kv_cache.py +132 -0
  160. tpu_inference/runner/kv_cache_manager.py +479 -0
  161. tpu_inference/runner/lora_utils.py +92 -0
  162. tpu_inference/runner/multimodal_manager.py +217 -0
  163. tpu_inference/runner/persistent_batch_manager.py +244 -0
  164. tpu_inference/runner/speculative_decoding_manager.py +248 -0
  165. tpu_inference/runner/structured_decoding_manager.py +88 -0
  166. tpu_inference/runner/tpu_runner.py +1620 -0
  167. tpu_inference/runner/utils.py +426 -0
  168. tpu_inference/spec_decode/__init__.py +0 -0
  169. tpu_inference/spec_decode/jax/__init__.py +0 -0
  170. tpu_inference/spec_decode/jax/eagle3.py +367 -0
  171. tpu_inference/tpu_info.py +77 -0
  172. tpu_inference/utils.py +317 -0
  173. tpu_inference/worker/__init__.py +0 -0
  174. tpu_inference/worker/tpu_worker.py +321 -0
  175. tpu_inference-0.11.1.dev202511150811.dist-info/METADATA +107 -0
  176. tpu_inference-0.11.1.dev202511150811.dist-info/RECORD +179 -0
  177. tpu_inference-0.11.1.dev202511150811.dist-info/WHEEL +5 -0
  178. tpu_inference-0.11.1.dev202511150811.dist-info/licenses/LICENSE +201 -0
  179. tpu_inference-0.11.1.dev202511150811.dist-info/top_level.txt +2 -0
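
The remainder of this diff shows the new tests/test_quantization.py, which exercises the Qwix quantization helpers in tpu_inference.models.jax.utils.quantization.quantization_utils (rule parsing, config-file loading, abstract-model quantization, and dtype extraction). As a rough sketch of the configuration shape those tests construct (field names are inferred from the test cases below, not from a documented schema), the quantization settings reach the helpers through vLLM's additional_config:

    # Sketch only: structure inferred from the tests in this diff, not an official schema.
    additional_config = {
        "quantization": {
            "qwix": {
                "scale_dtype": "bfloat16",    # optional; the tests treat bfloat16 as the default
                "use_abstract_model": True,   # optional; checked by apply_qwix_on_abstract_model
                "rules": [
                    # Each entry becomes a qwix.QuantizationRule via parse_qwix_config_to_rules.
                    {"module_path": ".*attn.*", "weight_qtype": "int8"},
                    {"module_path": ".*mlp.*", "weight_qtype": "int4", "act_qtype": "int8"},
                ],
            }
        }
    }
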
tests/test_quantization.py
@@ -0,0 +1,836 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import unittest
+from unittest.mock import MagicMock, mock_open, patch
+
+import jax
+import jax.numpy as jnp
+import qwix
+from flax import nnx
+from jax.sharding import Mesh, NamedSharding
+from jax.sharding import PartitionSpec as P
+from qwix._src.providers import ptq
+
+import tpu_inference.models.jax.utils.quantization.quantization_utils as quantize_qwix  # noqa: E402
+from tpu_inference.models.common.model_loader import apply_qwix_quantization
+from tpu_inference.models.jax.utils.quantization.quantization_utils import (
+    DEFAULT_MAX_NUM_BLOCKS_PER_REQ, DEFAULT_MAX_NUM_SEQS_FOR_MODEL_INPUTS,
+    DEFAULT_NUM_TOKENS_FOR_MODEL_INPUTS)
+
+mock_nnx = MagicMock()
+mock_jax = MagicMock()
+
+module_mocks = {
+    'flax': MagicMock(nnx=mock_nnx),
+    'flax.nnx': mock_nnx,
+    'jax': mock_jax,
+    'jax.sharding': MagicMock(),
+    'vllm': MagicMock(),
+    'vllm.config': MagicMock(),
+    'tpu_inference': MagicMock(),
+    'tpu_inference.logger': MagicMock(init_logger=lambda name: MagicMock()),
+    'tpu_inference.models.jax.utils.quantization.quantization_utils':
+    MagicMock(),
+}
+
+
+class TestParseQwixConfigToRules(unittest.TestCase):
+    """Tests for the parse_qwix_config_to_rules function."""
+
+    def test_empty_config(self):
+        """Test parsing an empty list of rules."""
+        qwix_config = []
+        rules = quantize_qwix.parse_qwix_config_to_rules(qwix_config)
+        self.assertEqual(rules, [])
+
+    def test_single_rule(self):
+        """Test parsing a single quantization rule."""
+        qwix_config = [{
+            "module_path": ".*attn.*",
+            "weight_qtype": "int8",
+        }]
+        rules = quantize_qwix.parse_qwix_config_to_rules(qwix_config)
+        self.assertEqual(len(rules), 1)
+        self.assertIsInstance(rules[0], qwix.QuantizationRule)
+        self.assertEqual(rules[0].module_path, ".*attn.*")
+        self.assertEqual(rules[0].weight_qtype, "int8")
+        self.assertIsNone(rules[0].act_qtype)
+
+    def test_multiple_rules(self):
+        """Test parsing multiple quantization rules."""
+        qwix_config = [
+            {
+                "module_path": ".*attn.*",
+                "weight_qtype": "int8",
+            },
+            {
+                "module_path": ".*mlp.*",
+                "weight_qtype": "int4",
+                "act_qtype": "int8",
+            },
+        ]
+        rules = quantize_qwix.parse_qwix_config_to_rules(qwix_config)
+        self.assertEqual(len(rules), 2)
+        self.assertIsInstance(rules[0], qwix.QuantizationRule)
+        self.assertIsInstance(rules[1], qwix.QuantizationRule)
+        self.assertEqual(rules[0].module_path, ".*attn.*")
+        self.assertEqual(rules[1].module_path, ".*mlp.*")
+        self.assertEqual(rules[1].weight_qtype, "int4")
+        self.assertEqual(rules[1].act_qtype, "int8")
+
+    def test_invalid_rule_key_raises_error(self):
+        """Test that an invalid key in a rule raises a TypeError."""
+        qwix_config = [{
+            "module_path": ".*attn.*",
+            "invalid_key": "some_value",
+        }]
+        with self.assertRaises(TypeError):
+            # qwix.QuantizationRule constructor will raise this error
+            quantize_qwix.parse_qwix_config_to_rules(qwix_config)
+
+
+# A simple NNX module for testing quantization
+class SimpleModel(nnx.Module):
+
+    def __init__(self, *, rngs: nnx.Rngs):
+        self.linear = nnx.Linear(10, 20, rngs=rngs)
+
+    def __call__(self, **kwargs):
+        # A simplified call signature for testing purposes
+        return self.linear(kwargs['input_ids'])
+
+
+@patch('qwix.quantize_model', autospec=True)
+class TestQwixQuantizeNnxModel(unittest.TestCase):
+    """Tests for the qwix_quantize_nnx_model function."""
+
+    def setUp(self):
+        """Set up a mock environment for testing."""
+        if not jax.devices():
+            self.skipTest(
+                "JAX device not found, skipping JAX-dependent tests.")
+        self.mesh = Mesh(jax.devices(), ('model', ))
+        self.rng = jax.random.PRNGKey(0)
+        self.model = SimpleModel(rngs=nnx.Rngs(0))
+
+        self.qwix_config = [
+            {
+                "module_path": ".*linear.*",
+                "weight_qtype": "int8",
+            },
+        ]
+
+        self.num_hidden_layers = 1
+        self.kv_cache_block_size = 16
+        self.kv_cache_num_kv_heads = 4
+        self.kv_cache_head_size = 64
+
+        self.mock_kv_caches = [MagicMock(), MagicMock()]
+
+    def test_quantization_call_with_correct_args(self, mock_quantize_model):
+        """Test that qwix.quantize_model is called with the correct arguments."""
+        quantized_model_mock = MagicMock(spec=nnx.Module)
+        mock_quantize_model.return_value = quantized_model_mock
+
+        with patch(
+                "tpu_inference.models.jax.utils.quantization.quantization_utils.init_logger",
+                return_value=MagicMock()
+        ), patch(
+                "tpu_inference.utils.hbm_usage_gb",
+                return_value=[(0.0, 0.0), (0.0, 0.0)]
+        ), patch(
+                "tpu_inference.models.jax.utils.quantization.quantization_utils.create_kv_caches",
+                return_value=self.mock_kv_caches
+        ), patch(
+                "tpu_inference.models.jax.utils.quantization.quantization_utils.quantization_config_file_path_to_dict",
+                return_value=self.qwix_config):
+            returned_model = quantize_qwix.qwix_quantize_nnx_model(
+                model=self.model,
+                qwix_config=self.qwix_config,
+                rng=self.rng,
+                mesh=self.mesh,
+                num_hidden_layers=self.num_hidden_layers,
+                kv_cache_block_size=self.kv_cache_block_size,
+                kv_cache_num_kv_heads=self.kv_cache_num_kv_heads,
+                kv_cache_head_size=self.kv_cache_head_size,
+                kv_cache_dtype="auto")
+
+        self.assertIs(returned_model, quantized_model_mock)
+        mock_quantize_model.assert_called_once()
+        args, kwargs = mock_quantize_model.call_args
+
+        # Assert positional arguments for qwix.quantize_model
+        self.assertIs(args[0], self.model)
+        self.assertIsInstance(args[1], qwix.PtqProvider)
+
+        # Assert keyword arguments (model inputs for tracing)
+        self.assertIn("kv_caches", kwargs)
+        self.assertEqual(kwargs["kv_caches"], self.mock_kv_caches)
+        self.assertIn("input_ids", kwargs)
+        self.assertEqual(kwargs["input_ids"].shape, (512, ))
+        self.assertIn("attention_metadata", kwargs)
+        attention_metadata = kwargs["attention_metadata"]
+
+        assert attention_metadata.input_positions.shape == (
+            DEFAULT_NUM_TOKENS_FOR_MODEL_INPUTS, )
+        assert attention_metadata.block_tables.shape == (
+            DEFAULT_MAX_NUM_SEQS_FOR_MODEL_INPUTS *
+            DEFAULT_MAX_NUM_BLOCKS_PER_REQ, )
+        assert attention_metadata.seq_lens.shape == (
+            DEFAULT_MAX_NUM_SEQS_FOR_MODEL_INPUTS, )
+        assert attention_metadata.query_start_loc.shape == (
+            DEFAULT_MAX_NUM_SEQS_FOR_MODEL_INPUTS + 1, )
+        assert attention_metadata.request_distribution.shape == (3, )
+
+
+@patch.dict('sys.modules', module_mocks)
+class TestApplyQwixQuantization(unittest.TestCase):
+
+    def setUp(self):
+        """Set up common mock objects for all tests in this suite."""
+        mock_nnx.reset_mock()
+        mock_jax.reset_mock()
+
+        self.mock_vllm_config = MagicMock()
+        self.mock_vllm_config.additional_config = {}
+        self.mock_vllm_config.cache_config.block_size = 16
+        self.mock_vllm_config.model_config.get_head_size.return_value = 128
+        self.mock_vllm_config.model_config.get_total_num_kv_heads.return_value = 8
+        self.mock_vllm_config.model_config.hf_config.num_hidden_layers = 32
+
+        self.mock_model = MagicMock(name="original_nnx_model",
+                                    spec_set=nnx.Module)
+        self.mock_rng = MagicMock(name="mock_rng")
+        self.mock_mesh = MagicMock(name="mock_mesh")
+
+    def test_no_quantization_config(self):
+        """
+        Test that the model is returned unchanged if no 'quantization' key exists.
+        """
+        result = apply_qwix_quantization(self.mock_vllm_config,
+                                         self.mock_model,
+                                         self.mock_rng,
+                                         self.mock_mesh,
+                                         apply_to_abstract_model=False)
+
+        self.assertIs(result, self.mock_model,
+                      "Model should be returned as-is.")
+        mock_nnx.jit.assert_not_called()
+
+    @patch('tpu_inference.models.common.model_loader.nnx.jit')
+    def test_quantization_applied_from_dict(self, mock_jit):
+        """
+        Test that quantization is applied correctly when the config is a dictionary.
+        """
+        qwix_rules = {"weights": "int8", "activations": None}
+        self.mock_vllm_config.additional_config = {
+            "quantization": {
+                "qwix": {
+                    "rules": qwix_rules
+                }
+            }
+        }
+
+        with patch('tpu_inference.utils.get_padded_num_heads',
+                   return_value=128):
+            apply_qwix_quantization(self.mock_vllm_config,
+                                    self.mock_model,
+                                    self.mock_rng,
+                                    self.mock_mesh,
+                                    apply_to_abstract_model=False)
+        mock_jit.assert_called_once()
+
+
+class TestQuantizationConfigFileToDict(unittest.TestCase):
+    """Tests for the quantization_config_file_path_to_dict function."""
+
+    @patch("os.listdir")
+    @patch("os.path.join")
+    def test_file_not_found_raises_value_error(self, mock_join, mock_listdir):
+        """Test that a ValueError is raised if the config file is not found."""
+        mock_listdir.return_value = ["another_file.yaml", "config.txt"]
+        config_file_path = "non_existent.yaml"
+
+        with self.assertRaisesRegex(
+                ValueError,
+                f"Could not find quantization config file with name '{config_file_path}'"
+        ):
+            quantize_qwix.quantization_config_file_path_to_dict(
+                config_file_path)
+        mock_listdir.assert_called_once_with(
+            quantize_qwix.QUANTIZATION_CONFIG_PATH)
+
+    @patch("os.listdir")
+    @patch("os.path.join")
+    @patch("builtins.open",
+           new_callable=mock_open,
+           read_data="qwix:\n rules: []")
+    def test_file_found_and_loaded_successfully(self, mock_file, mock_join,
+                                                mock_listdir):
+        """Test that the YAML file is correctly loaded when found."""
+        config_filename = "my_quant_config.yaml"
+        mock_listdir.return_value = ["another.yaml", config_filename]
+        mock_join.return_value = f"/fake/path/{config_filename}"
+        expected_dict = {"qwix": {"rules": []}}
+
+        result = quantize_qwix.quantization_config_file_path_to_dict(
+            config_filename)
+
+        mock_listdir.assert_called_once_with(
+            quantize_qwix.QUANTIZATION_CONFIG_PATH)
+        mock_join.assert_called_once_with(
+            quantize_qwix.QUANTIZATION_CONFIG_PATH, config_filename)
+        mock_file.assert_called_once_with(f"/fake/path/{config_filename}", "r")
+        self.assertEqual(result, expected_dict)
+
+
+class TestApplyQwixQuantizationLogic(unittest.TestCase):
+    """Tests the core logic of apply_qwix_quantization."""
+
+    def setUp(self):
+        self.mock_vllm_config = MagicMock()
+        self.mock_vllm_config.additional_config = {}
+        self.mock_vllm_config.cache_config.block_size = 16
+        self.mock_vllm_config.model_config.get_head_size.return_value = 128
+        self.mock_vllm_config.model_config.get_total_num_kv_heads.return_value = 8
+        self.mock_vllm_config.model_config.hf_config.num_hidden_layers = 32
+        self.mock_model = MagicMock(name="original_nnx_model")
+        self.mock_rng = MagicMock(name="mock_rng")
+        self.mock_mesh = MagicMock(name="mock_mesh", shape={"model": 1})
+
+    def test_quantization_config_without_qwix_rules(self):
+        """Test model is unchanged if the config lacks 'qwix' or 'rules'."""
+        self.mock_vllm_config.additional_config = {"quantization": {}}
+        result1 = quantize_qwix.apply_qwix_quantization(
+            self.mock_vllm_config, self.mock_model, self.mock_rng,
+            self.mock_mesh, False)
+        self.assertIs(result1, self.mock_model)
+
+        self.mock_vllm_config.additional_config = {
+            "quantization": {
+                "qwix": {}
+            }
+        }
+        result2 = quantize_qwix.apply_qwix_quantization(
+            self.mock_vllm_config, self.mock_model, self.mock_rng,
+            self.mock_mesh, False)
+        self.assertIs(result2, self.mock_model)
+
+    @patch(
+        'tpu_inference.models.jax.utils.quantization.quantization_utils.qwix_quantize_nnx_model'
+    )
+    @patch(
+        'tpu_inference.models.jax.utils.quantization.quantization_utils.utils')
+    def test_apply_to_abstract_model(self, mock_utils, mock_quantize_func):
+        """Test quantization is correctly applied to an abstract model factory."""
+        mock_utils.get_padded_num_heads.return_value = 8
+        mock_utils.get_padded_head_dim.return_value = 128
+        qwix_rules = [{"module_path": ".*", "weight_qtype": "int8"}]
+        self.mock_vllm_config.additional_config = {
+            "quantization": {
+                "qwix": {
+                    "rules": qwix_rules
+                }
+            }
+        }
+        mock_abstract_model = MagicMock(name="abstract_model")
+        mock_model_fn = MagicMock(name="model_factory",
+                                  return_value=mock_abstract_model)
+        quantized_model = MagicMock(name="quantized_model")
+        mock_quantize_func.return_value = quantized_model
+
+        model_factory = quantize_qwix.apply_qwix_quantization(
+            self.mock_vllm_config,
+            mock_model_fn,
+            self.mock_rng,
+            self.mock_mesh,
+            apply_to_abstract_model=True)
+
+        self.assertTrue(callable(model_factory))
+        result_model = model_factory()
+
+        mock_model_fn.assert_called_once()
+        mock_quantize_func.assert_called_once()
+        call_kwargs = mock_quantize_func.call_args.kwargs
+        self.assertIs(call_kwargs['model'], mock_abstract_model)
+        self.assertIs(call_kwargs['rng'], self.mock_rng)
+        self.assertIs(result_model, quantized_model)
+
+    @patch(
+        'tpu_inference.models.jax.utils.quantization.quantization_utils.qwix_quantize_nnx_model'
+    )
+    @patch(
+        'tpu_inference.models.jax.utils.quantization.quantization_utils.utils')
+    def test_apply_to_abstract_model_with_initialize_cache(
+            self, mock_utils, mock_quantize_func):
+        """Test abstract model quantization with 'initialize_cache' method."""
+        mock_utils.get_padded_num_heads.return_value = 8
+        mock_utils.get_padded_head_dim.return_value = 128
+        qwix_rules = [{"module_path": ".*", "weight_qtype": "int8"}]
+        self.mock_vllm_config.additional_config = {
+            "quantization": {
+                "qwix": {
+                    "rules": qwix_rules
+                }
+            }
+        }
+        mock_abstract_model = MagicMock(name="abstract_model")
+        mock_abstract_model.initialize_cache = MagicMock()
+        mock_model_fn = MagicMock(name="model_factory",
+                                  return_value=mock_abstract_model)
+
+        model_factory = quantize_qwix.apply_qwix_quantization(
+            self.mock_vllm_config,
+            mock_model_fn,
+            self.mock_rng,
+            self.mock_mesh,
+            apply_to_abstract_model=True)
+
+        model_factory()
+
+        mock_abstract_model.initialize_cache.assert_called_once()
+        mock_quantize_func.assert_called_once()
+
+
+class TestDetermineWhetherToApplyQwixOnAbstractModel(unittest.TestCase):
+    """Tests for apply_qwix_on_abstract_model."""
+
+    def setUp(self):
+        self.mock_vllm_config = MagicMock()
+        self.mock_vllm_config.additional_config = {
+            "quantization": {
+                "qwix": {
+                    "use_abstract_model": True,
+                    "rules": [{
+                        "module_path": ".*",
+                        "weight_qtype": "int8"
+                    }]
+                }
+            }
+        }
+
+        self.mock_vllm_config_no_abstract_model = MagicMock()
+        self.mock_vllm_config_no_abstract_model.additional_config = {
+            "quantization": {
+                "qwix": {
+                    "rules": [{
+                        "module_path": ".*",
+                        "weight_qtype": "int8"
+                    }]
+                }
+            }
+        }
+
+        self.mock_vllm_config_no_additional_config = MagicMock()
+        self.mock_vllm_config_no_additional_config.additional_config = {}
+
+    def test_returns_false_when_additional_config_is_missing(self):
+        """Test it returns False when additional_config is missing."""
+        result = quantize_qwix.apply_qwix_on_abstract_model(
+            self.mock_vllm_config_no_additional_config)
+        self.assertFalse(result)
+
+    def test_returns_true_when_additional_config_is_present(self):
+        """Test it returns True when additional_config is present."""
+        result = quantize_qwix.apply_qwix_on_abstract_model(
+            self.mock_vllm_config)
+        self.assertTrue(result)
+
+    def test_returns_false_when_use_abstract_model_is_false(self):
+        """Test it returns False when use_abstract_model is False."""
+        result = quantize_qwix.apply_qwix_on_abstract_model(
+            self.mock_vllm_config_no_abstract_model)
+        self.assertFalse(result)
+
+
+class TestLoadRandomWeightsIntoQwixAbstractModel(unittest.TestCase):
+    """Tests for the load_random_weights_into_qwix_abstract_model function."""
+
+    def setUp(self):
+        """Set up a mock environment for testing."""
+        if not jax.devices():
+            self.skipTest(
+                "JAX device not found, skipping JAX-dependent tests.")
+
+        self.rng = jax.random.PRNGKey(0)
+        self.mesh = Mesh(jax.devices(), ('data', ))
+        self.quantization_config = {
+            "weight_block_size": [64, 1],
+        }
+
+        # Mock model structure
+        self.model = MagicMock(spec=['weight_loader', 'initialize_cache'])
+        self.model.weight_loader = MagicMock(
+            spec=['scale_dtype', 'scale_shap_map_for_random_weight_loading'])
+        self.model.weight_loader.scale_dtype = jnp.float16
+        self.model.weight_loader.scale_shap_map_for_random_weight_loading = {}
+
+    @patch(
+        'tpu_inference.models.jax.utils.quantization.quantization_utils.nnx.iter_graph'
+    )
+    @patch(
+        'tpu_inference.models.jax.utils.quantization.quantization_utils.get_random_sharded_array'
+    )
+    def test_successful_initialization(self, mock_get_random_array,
+                                       mock_iter_graph):
+        """Test that variables are correctly initialized."""
+        # Setup mock graph elements
+        mock_weight_param = nnx.Param(jnp.empty((128, 64), dtype=jnp.int8),
+                                      sharding=P('data', None))
+        mock_scale_var = nnx.Variable(jnp.empty((1, 1), dtype=jnp.float16))
+        mock_rng_var = nnx.Variable(jax.random.PRNGKey(0))
+        mock_random_array = jax.numpy.ones(1)
+        mock_get_random_array.return_value = mock_random_array
+
+        mock_iter_graph.return_value = [
+            (('layers', '0', 'attention', 'wq', 'kernel'), mock_weight_param),
+            (('layers', '0', 'attention', 'wq', 'array', 'scale'),
+             mock_scale_var),
+            (('rng', 'params', 'key'), mock_rng_var),
+        ]
+
+        quantize_qwix.load_random_weights_into_qwix_abstract_model(
+            self.rng, self.model, self.mesh, self.quantization_config)
+
+        # Assert weight is updated
+        self.assertIs(mock_weight_param.value, mock_random_array)
+        # Assert scale is updated
+        self.assertIs(mock_scale_var.value, mock_random_array)
+        # Assert RNG key is updated with the passed-in RNG
+        self.assertIs(mock_rng_var.value, self.rng)
+        # Assert initialize_cache is called
+        self.model.initialize_cache.assert_called_once()
+
+    def test_invalid_config_raises_assertion_error(self):
+        """Test that an invalid quantization_block_sizes config raises an error."""
+        invalid_config = {"weight_block_size": [64]}  # Length is 1, not 2
+        with self.assertRaisesRegex(AssertionError,
+                                    "Expected only 2 quantization block"):
+            quantize_qwix.load_random_weights_into_qwix_abstract_model(
+                self.rng, self.model, self.mesh, invalid_config)
+
+    @patch(
+        'tpu_inference.models.jax.utils.quantization.quantization_utils.nnx.iter_graph'
+    )
+    def test_param_shape_setting_no_scale_map(self, mock_iter_graph):
+        """Test correct scale shape calculation when not in the map."""
+        old_weight_param_val = jnp.empty((128, 64))
+        mock_weight_param = nnx.Param(old_weight_param_val, dtype=jnp.int8)
+        old_scale_var_val = jnp.empty((0, 0))
+        mock_scale_var = nnx.Variable(old_scale_var_val)
+
+        mock_iter_graph.return_value = [
+            (('layers', '0', 'attention', 'wq', 'kernel'), mock_weight_param),
+            (('layers', '0', 'attention', 'wq', 'array', 'scale'),
+             mock_scale_var),
+        ]
+
+        quantize_qwix.load_random_weights_into_qwix_abstract_model(
+            self.rng, self.model, self.mesh, self.quantization_config)
+
+        new_weight_param_val = mock_weight_param.value
+        new_scale_var_val = mock_scale_var.value
+
+        expected_scale_shape = (128 // 64, 64 // 64)
+        actual_scale_shape = new_scale_var_val.shape
+
+        expected_weight_shape = (128, 64)
+        actual_weight_shape = new_weight_param_val.shape
+
+        self.assertEqual(expected_scale_shape, actual_scale_shape)
+        self.assertEqual(expected_weight_shape, actual_weight_shape)
+        self.assertNotEqual(old_scale_var_val.shape, new_scale_var_val.shape)
+        assert jnp.not_equal(old_weight_param_val, new_weight_param_val).all()
+
+    @patch(
+        'tpu_inference.models.jax.utils.quantization.quantization_utils.nnx.iter_graph'
+    )
+    def test_param_shape_setting_with_scale_map(self, mock_iter_graph):
+        """Test correct scale shape calculation when in the map."""
+        old_weight_param_val = jnp.empty((128, 64))
+        mock_weight_param = nnx.Param(old_weight_param_val, dtype=jnp.int8)
+        old_scale_var_val = jnp.empty((0, 0))
+        mock_scale_var = nnx.Variable(old_scale_var_val)
+
+        expected_scale_shape = (55, 34)
+
+        self.model.weight_loader.scale_shap_map_for_random_weight_loading = {
+            'wq': expected_scale_shape
+        }
+
+        mock_iter_graph.return_value = [
+            (('layers', '0', 'attention', 'wq', 'kernel'), mock_weight_param),
+            (('layers', '0', 'attention', 'wq', 'array', 'scale'),
+             mock_scale_var),
+        ]
+
+        quantize_qwix.load_random_weights_into_qwix_abstract_model(
+            self.rng, self.model, self.mesh, self.quantization_config)
+
+        new_weight_param_val = mock_weight_param.value
+        new_scale_var_val = mock_scale_var.value
+
+        actual_scale_shape = new_scale_var_val.shape
+
+        expected_weight_shape = (128, 64)
+        actual_weight_shape = new_weight_param_val.shape
+
+        self.assertEqual(expected_scale_shape, actual_scale_shape)
+        self.assertEqual(expected_weight_shape, actual_weight_shape)
+        self.assertNotEqual(old_scale_var_val.shape, new_scale_var_val.shape)
+        assert jnp.not_equal(old_weight_param_val, new_weight_param_val).all()
+
+    @patch('jax.random.randint')
+    @patch('jax.random.normal')
+    @patch('jax.make_array_from_callback')
+    def test_get_random_sharded_array_dtype_dispatch(self, mock_make_array,
+                                                     mock_normal,
+                                                     mock_randint):
+        """Test that integer dtypes call randint and floats call normal."""
+        # Test integer
+        quantize_qwix.get_random_sharded_array(
+            self.rng, self.mesh, nnx.Param(jnp.empty((2, 2)), sharding=P()),
+            (2, 2), jnp.int8, "int_param")
+        mock_randint.assert_called_once()
+        mock_normal.assert_not_called()
+
+        mock_randint.reset_mock()
+        mock_normal.reset_mock()
+
+        # Test float
+        quantize_qwix.get_random_sharded_array(
+            self.rng, self.mesh, nnx.Param(jnp.empty((2, 2)), sharding=P()),
+            (2, 2), jnp.float32, "float_param")
+        mock_randint.assert_not_called()
+        mock_normal.assert_called_once()
+
+    @patch(
+        "tpu_inference.models.jax.utils.quantization.quantization_utils.logger.warning"
+    )
+    @patch("jax.make_array_from_callback")
+    def test_get_random_sharded_array_sharding_fallback(
+            self, mock_make_array, mock_logger_warning):
+        """Test that sharding failure logs a warning and uses a fallback."""
+        # First call raises an error, second call (fallback) succeeds
+        mock_make_array.side_effect = [
+            ValueError("Sharding failed"),
+            MagicMock()
+        ]
+
+        param = nnx.Param(jnp.empty((2, 2)), sharding=P('data', None))
+        quantize_qwix.get_random_sharded_array(self.rng, self.mesh, param,
+                                               (2, 2), jnp.float32,
+                                               "test_param")
+
+        # Check that a warning was logged
+        mock_logger_warning.assert_called_once()
+        self.assertIn("Could not create sharded scale for test_param",
+                      mock_logger_warning.call_args[0][0])
+
+        # Check that the fallback was attempted with an empty PartitionSpec
+        fallback_call_args = mock_make_array.call_args_list[1]
+        fallback_sharding = fallback_call_args.args[1]
+        self.assertEqual(fallback_sharding, NamedSharding(self.mesh, P()))
+
+
+class TestManualQwixQuantization(unittest.TestCase):
+    """Tests for manual Qwix quantization functions."""
+
+    def setUp(self):
+        if not jax.devices():
+            self.skipTest(
+                "JAX device not found, skipping JAX-dependent tests.")
+        self.weight = jnp.ones((4, 4))
+        self.inputs = jnp.ones((8, 4))
+        self.qtype = jnp.int8
+        self.channelwise_axes = [0]
+        self.tiled_axes = {}
+        self.calibration_method = 'max'
+
+    @patch(
+        'tpu_inference.models.jax.utils.quantization.quantization_utils.ptq.create_quantized_param'
+    )
+    def test_manually_quantize_qwix_weight(self, mock_create_param):
+        """Test that manually_quantize_qwix_weight calls ptq.create_quantized_param correctly."""
+        quantize_qwix.manually_quantize_qwix_weight(
+            weight=self.weight,
+            qtype=self.qtype,
+            channelwise_axes=self.channelwise_axes,
+            tiled_axes=self.tiled_axes,
+            calibration_method=self.calibration_method)
+
+        mock_create_param.assert_called_once()
+        args, _ = mock_create_param.call_args
+        passed_weight, passed_how_to_quantize = args
+
+        self.assertTrue(jnp.array_equal(passed_weight, self.weight))
+        self.assertIsInstance(passed_how_to_quantize, ptq.qarray.HowToQuantize)
+        self.assertEqual(passed_how_to_quantize.qtype, self.qtype)
+        self.assertEqual(passed_how_to_quantize.channelwise_axes,
+                         self.channelwise_axes)
+        self.assertEqual(passed_how_to_quantize.tiled_axes, self.tiled_axes)
+        self.assertEqual(passed_how_to_quantize.calibration_method,
+                         self.calibration_method)
+
+    @patch(
+        'tpu_inference.models.jax.utils.quantization.quantization_utils.ptq.quantize_act'
+    )
+    @patch('qwix.pallas.get_current_rule')
+    def test_manually_quantize_qwix_activation(self, mock_get_rule,
+                                               mock_quantize_act):
+        """Test that manually_quantize_qwix_activation calls ptq.quantize_act correctly."""
+        mock_rule = MagicMock()
+        mock_rule.act_static_scale = False
+        mock_get_rule.return_value = mock_rule
+        rule_name = "test_rule"
+
+        quantize_qwix.manually_quantize_qwix_activation(
+            inputs=self.inputs,
+            rule_name=rule_name,
+            qtype=self.qtype,
+            channelwise_axes=self.channelwise_axes,
+            tiled_axes=self.tiled_axes,
+            calibration_method=self.calibration_method)
+
+        mock_get_rule.assert_called_once_with(rule_name)
+        mock_quantize_act.assert_called_once()
+
+        args, _ = mock_quantize_act.call_args
+        passed_inputs, passed_how, passed_rule, passed_act_name = args
+
+        self.assertTrue(jnp.array_equal(passed_inputs, self.inputs))
+        self.assertIsInstance(passed_how, ptq.qarray.HowToQuantize)
+        self.assertEqual(passed_how.qtype, self.qtype)
+        self.assertEqual(passed_how.channelwise_axes, self.channelwise_axes)
+        self.assertEqual(passed_how.tiled_axes, self.tiled_axes)
+        self.assertEqual(passed_how.calibration_method,
+                         self.calibration_method)
+        self.assertIs(passed_rule, mock_rule)
+        self.assertEqual(passed_act_name, "")  # act_name is hardcoded to ""
+
+    @patch('qwix.pallas.get_current_rule')
+    def test_manually_quantize_qwix_activation_static_scale_raises_error(
+            self, mock_get_rule):
+        """Test that an assertion is raised if the rule has static scale."""
+        mock_rule = MagicMock()
+        mock_rule.act_static_scale = True
+        mock_get_rule.return_value = mock_rule
+
+        with self.assertRaisesRegex(AssertionError,
+                                    "Static scale not supported right now"):
+            quantize_qwix.manually_quantize_qwix_activation(
+                inputs=self.inputs,
+                rule_name="any_rule",
+                qtype=self.qtype,
+                channelwise_axes=self.channelwise_axes,
+                tiled_axes=self.tiled_axes,
+                calibration_method=self.calibration_method)
+
+
+class TestGetQuantDtypeFromQwixConfig(unittest.TestCase):
+    """Tests for the get_quant_dtype_from_qwix_config function."""
+
+    def setUp(self):
+        self.mock_vllm_config = MagicMock()
+        self.mock_vllm_config.additional_config = {}
+
+    def test_get_quant_dtype_success(self):
+        """Test successful extraction of dtypes from a valid config."""
+        self.mock_vllm_config.additional_config = {
+            "quantization": {
+                "qwix": {
+                    "scale_dtype": "float16",
+                    "rules": [
+                        {
+                            "module_path": ".*mlp.*",
+                            "weight_qtype": "int4"
+                        },
+                        {
+                            "module_path": ".*",
+                            "weight_qtype": "int8"
+                        },
+                    ],
+                }
+            }
+        }
+        scale_dtype, quant_dtype = quantize_qwix.get_quant_dtype_from_qwix_config(
+            self.mock_vllm_config)
+        self.assertEqual(scale_dtype, jnp.float16)
+        self.assertEqual(quant_dtype, jnp.int8)
+
+    def test_get_quant_dtype_default_scale(self):
+        """Test that scale_dtype defaults to bfloat16 when not specified."""
+        self.mock_vllm_config.additional_config = {
+            "quantization": {
+                "qwix": {
+                    "rules": [{
+                        "module_path": ".*",
+                        "weight_qtype": "int8"
+                    }]
+                }
+            }
+        }
+        scale_dtype, quant_dtype = quantize_qwix.get_quant_dtype_from_qwix_config(
+            self.mock_vllm_config)
+        self.assertEqual(scale_dtype, jnp.bfloat16)
+        self.assertEqual(quant_dtype, jnp.int8)
+
+    def test_no_quantization_config_returns_defaults(self):
+        """Test that default dtypes are returned when config is missing."""
+        self.mock_vllm_config.additional_config = {}
+        scale_dtype, quant_dtype = quantize_qwix.get_quant_dtype_from_qwix_config(
+            self.mock_vllm_config)
+        self.assertEqual(scale_dtype, jnp.bfloat16)
+        self.assertIsNone(quant_dtype)
+
+    def test_get_quant_dtype_no_wildcard_rule_returns_none(self):
+        """Test that quant_dtype is None if no wildcard rule is found."""
+        self.mock_vllm_config.additional_config = {
+            "quantization": {
+                "qwix": {
+                    "rules": [{
+                        "module_path": ".*mlp.*",
+                        "weight_qtype": "int4"
+                    }]
+                }
+            }
+        }
+        scale_dtype, quant_dtype = quantize_qwix.get_quant_dtype_from_qwix_config(
+            self.mock_vllm_config)
+        self.assertEqual(scale_dtype, jnp.bfloat16)
+        self.assertIsNone(quant_dtype)
+
+    def test_get_quant_dtype_wildcard_rule_missing_qtype_raises_error(self):
+        """Test that an assertion is raised if the wildcard rule is missing weight_qtype."""
+        self.mock_vllm_config.additional_config = {
+            "quantization": {
+                "qwix": {
+                    "rules": [{
+                        "module_path": ".*"
+                    }]
+                }
+            }
+        }
+        with self.assertRaisesRegex(AssertionError,
+                                    "Quantization dtype not found"):
+            quantize_qwix.get_quant_dtype_from_qwix_config(
+                self.mock_vllm_config)
+
+    def test_get_quant_dtype_no_rules_key_returns_none(self):
+        """Test that quant_dtype is None if 'rules' key is missing."""
+        self.mock_vllm_config.additional_config = {
+            "quantization": {
+                "qwix": {
+                    "scale_dtype": "float16",
+                }
+            }
+        }
+        scale_dtype, quant_dtype = quantize_qwix.get_quant_dtype_from_qwix_config(
+            self.mock_vllm_config)
+        self.assertEqual(scale_dtype, jnp.float16)
+        self.assertIsNone(quant_dtype)
+
+
+if __name__ == '__main__':
+    unittest.main()