tpu-inference 0.11.1.dev202511150811__py3-none-any.whl
This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in the public registry.
Potentially problematic release: this version of tpu-inference might be problematic.
- tests/__init__.py +0 -0
- tests/core/__init__.py +0 -0
- tests/core/test_core_tpu.py +513 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +53 -0
- tests/core/test_dp_scheduler.py +899 -0
- tests/core/test_init.py +49 -0
- tests/kernels/__init__.py +0 -0
- tests/kernels/fused_moe_v1_test.py +105 -0
- tests/kernels/mla_v1_test.py +396 -0
- tests/kernels/quantized_matmul_kernel_test.py +191 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
- tests/lora/__init__.py +0 -0
- tests/lora/conftest.py +32 -0
- tests/lora/test_bgmv.py +43 -0
- tests/lora/test_layers.py +654 -0
- tests/lora/test_lora.py +133 -0
- tests/lora/utils.py +96 -0
- tests/test_base.py +201 -0
- tests/test_envs.py +182 -0
- tests/test_quantization.py +836 -0
- tests/test_tpu_info.py +120 -0
- tests/test_utils.py +236 -0
- tpu_inference/__init__.py +34 -0
- tpu_inference/core/__init__.py +0 -0
- tpu_inference/core/core_tpu.py +786 -0
- tpu_inference/core/disagg_executor.py +118 -0
- tpu_inference/core/disagg_utils.py +51 -0
- tpu_inference/core/sched/__init__.py +0 -0
- tpu_inference/core/sched/dp_scheduler.py +523 -0
- tpu_inference/distributed/__init__.py +0 -0
- tpu_inference/distributed/jax_parallel_state.py +67 -0
- tpu_inference/distributed/tpu_connector.py +728 -0
- tpu_inference/distributed/utils.py +59 -0
- tpu_inference/env_override.py +9 -0
- tpu_inference/envs.py +107 -0
- tpu_inference/executors/__init__.py +0 -0
- tpu_inference/executors/ray_distributed_executor.py +362 -0
- tpu_inference/experimental/__init__.py +0 -0
- tpu_inference/experimental/llama3_jax_stashed.py +258 -0
- tpu_inference/kernels/__init__.py +0 -0
- tpu_inference/kernels/collectives/__init__.py +0 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +0 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/fused_moe/__init__.py +0 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +1035 -0
- tpu_inference/kernels/mla/__init__.py +0 -0
- tpu_inference/kernels/mla/v1/__init__.py +0 -0
- tpu_inference/kernels/mla/v1/kernel.py +1349 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1478 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
- tpu_inference/layers/__init__.py +0 -0
- tpu_inference/layers/common/__init__.py +0 -0
- tpu_inference/layers/common/attention_interface.py +390 -0
- tpu_inference/layers/common/attention_metadata.py +34 -0
- tpu_inference/layers/common/binary_search.py +295 -0
- tpu_inference/layers/common/quant_methods.py +8 -0
- tpu_inference/layers/common/sharding.py +582 -0
- tpu_inference/layers/jax/__init__.py +0 -0
- tpu_inference/layers/jax/attention/__init__.py +0 -0
- tpu_inference/layers/jax/attention/attention.py +255 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
- tpu_inference/layers/jax/base.py +151 -0
- tpu_inference/layers/jax/constants.py +88 -0
- tpu_inference/layers/jax/layers.py +301 -0
- tpu_inference/layers/jax/misc.py +16 -0
- tpu_inference/layers/jax/moe/__init__.py +0 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
- tpu_inference/layers/jax/moe/moe.py +209 -0
- tpu_inference/layers/jax/rope.py +280 -0
- tpu_inference/layers/jax/rope_interface.py +214 -0
- tpu_inference/layers/jax/sample/__init__.py +0 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
- tpu_inference/layers/jax/sample/sampling.py +96 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
- tpu_inference/layers/jax/transformer_block.py +107 -0
- tpu_inference/layers/vllm/__init__.py +0 -0
- tpu_inference/layers/vllm/attention.py +221 -0
- tpu_inference/layers/vllm/fused_moe.py +507 -0
- tpu_inference/layers/vllm/linear_common.py +186 -0
- tpu_inference/layers/vllm/quantization/__init__.py +39 -0
- tpu_inference/layers/vllm/quantization/awq.py +207 -0
- tpu_inference/layers/vllm/quantization/common.py +105 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +386 -0
- tpu_inference/layers/vllm/sharding.py +230 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +0 -0
- tpu_inference/lora/torch_lora_ops.py +103 -0
- tpu_inference/lora/torch_punica_tpu.py +311 -0
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1219 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/__init__.py +0 -0
- tpu_inference/models/common/__init__.py +0 -0
- tpu_inference/models/common/model_loader.py +444 -0
- tpu_inference/models/jax/__init__.py +0 -0
- tpu_inference/models/jax/deepseek_v3.py +868 -0
- tpu_inference/models/jax/gpt_oss.py +492 -0
- tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
- tpu_inference/models/jax/llama3.py +375 -0
- tpu_inference/models/jax/llama4.py +629 -0
- tpu_inference/models/jax/llama_eagle3.py +333 -0
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +375 -0
- tpu_inference/models/jax/qwen2_5_vl.py +1103 -0
- tpu_inference/models/jax/qwen3.py +302 -0
- tpu_inference/models/jax/utils/__init__.py +0 -0
- tpu_inference/models/jax/utils/file_utils.py +96 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +653 -0
- tpu_inference/models/jax/utils/weight_utils.py +529 -0
- tpu_inference/models/vllm/__init__.py +0 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +286 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
- tpu_inference/platforms/__init__.py +2 -0
- tpu_inference/platforms/tpu_platform.py +269 -0
- tpu_inference/runner/__init__.py +0 -0
- tpu_inference/runner/block_table.py +122 -0
- tpu_inference/runner/compilation_manager.py +780 -0
- tpu_inference/runner/input_batch.py +435 -0
- tpu_inference/runner/kv_cache.py +132 -0
- tpu_inference/runner/kv_cache_manager.py +479 -0
- tpu_inference/runner/lora_utils.py +92 -0
- tpu_inference/runner/multimodal_manager.py +217 -0
- tpu_inference/runner/persistent_batch_manager.py +244 -0
- tpu_inference/runner/speculative_decoding_manager.py +248 -0
- tpu_inference/runner/structured_decoding_manager.py +88 -0
- tpu_inference/runner/tpu_runner.py +1620 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +0 -0
- tpu_inference/spec_decode/jax/__init__.py +0 -0
- tpu_inference/spec_decode/jax/eagle3.py +367 -0
- tpu_inference/tpu_info.py +77 -0
- tpu_inference/utils.py +317 -0
- tpu_inference/worker/__init__.py +0 -0
- tpu_inference/worker/tpu_worker.py +321 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/METADATA +107 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/RECORD +179 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/WHEEL +5 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/top_level.txt +2 -0
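The hunk below is the full content of the new tests/test_quantization.py (836 added lines).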
@@ -0,0 +1,836 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import unittest
from unittest.mock import MagicMock, mock_open, patch

import jax
import jax.numpy as jnp
import qwix
from flax import nnx
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P
from qwix._src.providers import ptq

import tpu_inference.models.jax.utils.quantization.quantization_utils as quantize_qwix  # noqa: E402
from tpu_inference.models.common.model_loader import apply_qwix_quantization
from tpu_inference.models.jax.utils.quantization.quantization_utils import (
    DEFAULT_MAX_NUM_BLOCKS_PER_REQ, DEFAULT_MAX_NUM_SEQS_FOR_MODEL_INPUTS,
    DEFAULT_NUM_TOKENS_FOR_MODEL_INPUTS)

mock_nnx = MagicMock()
mock_jax = MagicMock()

module_mocks = {
    'flax': MagicMock(nnx=mock_nnx),
    'flax.nnx': mock_nnx,
    'jax': mock_jax,
    'jax.sharding': MagicMock(),
    'vllm': MagicMock(),
    'vllm.config': MagicMock(),
    'tpu_inference': MagicMock(),
    'tpu_inference.logger': MagicMock(init_logger=lambda name: MagicMock()),
    'tpu_inference.models.jax.utils.quantization.quantization_utils':
    MagicMock(),
}


class TestParseQwixConfigToRules(unittest.TestCase):
    """Tests for the parse_qwix_config_to_rules function."""

    def test_empty_config(self):
        """Test parsing an empty list of rules."""
        qwix_config = []
        rules = quantize_qwix.parse_qwix_config_to_rules(qwix_config)
        self.assertEqual(rules, [])

    def test_single_rule(self):
        """Test parsing a single quantization rule."""
        qwix_config = [{
            "module_path": ".*attn.*",
            "weight_qtype": "int8",
        }]
        rules = quantize_qwix.parse_qwix_config_to_rules(qwix_config)
        self.assertEqual(len(rules), 1)
        self.assertIsInstance(rules[0], qwix.QuantizationRule)
        self.assertEqual(rules[0].module_path, ".*attn.*")
        self.assertEqual(rules[0].weight_qtype, "int8")
        self.assertIsNone(rules[0].act_qtype)

    def test_multiple_rules(self):
        """Test parsing multiple quantization rules."""
        qwix_config = [
            {
                "module_path": ".*attn.*",
                "weight_qtype": "int8",
            },
            {
                "module_path": ".*mlp.*",
                "weight_qtype": "int4",
                "act_qtype": "int8",
            },
        ]
        rules = quantize_qwix.parse_qwix_config_to_rules(qwix_config)
        self.assertEqual(len(rules), 2)
        self.assertIsInstance(rules[0], qwix.QuantizationRule)
        self.assertIsInstance(rules[1], qwix.QuantizationRule)
        self.assertEqual(rules[0].module_path, ".*attn.*")
        self.assertEqual(rules[1].module_path, ".*mlp.*")
        self.assertEqual(rules[1].weight_qtype, "int4")
        self.assertEqual(rules[1].act_qtype, "int8")

    def test_invalid_rule_key_raises_error(self):
        """Test that an invalid key in a rule raises a TypeError."""
        qwix_config = [{
            "module_path": ".*attn.*",
            "invalid_key": "some_value",
        }]
        with self.assertRaises(TypeError):
            # qwix.QuantizationRule constructor will raise this error
            quantize_qwix.parse_qwix_config_to_rules(qwix_config)


# A simple NNX module for testing quantization
class SimpleModel(nnx.Module):

    def __init__(self, *, rngs: nnx.Rngs):
        self.linear = nnx.Linear(10, 20, rngs=rngs)

    def __call__(self, **kwargs):
        # A simplified call signature for testing purposes
        return self.linear(kwargs['input_ids'])


@patch('qwix.quantize_model', autospec=True)
class TestQwixQuantizeNnxModel(unittest.TestCase):
    """Tests for the qwix_quantize_nnx_model function."""

    def setUp(self):
        """Set up a mock environment for testing."""
        if not jax.devices():
            self.skipTest(
                "JAX device not found, skipping JAX-dependent tests.")
        self.mesh = Mesh(jax.devices(), ('model', ))
        self.rng = jax.random.PRNGKey(0)
        self.model = SimpleModel(rngs=nnx.Rngs(0))

        self.qwix_config = [
            {
                "module_path": ".*linear.*",
                "weight_qtype": "int8",
            },
        ]

        self.num_hidden_layers = 1
        self.kv_cache_block_size = 16
        self.kv_cache_num_kv_heads = 4
        self.kv_cache_head_size = 64

        self.mock_kv_caches = [MagicMock(), MagicMock()]

    def test_quantization_call_with_correct_args(self, mock_quantize_model):
        """Test that qwix.quantize_model is called with the correct arguments."""
        quantized_model_mock = MagicMock(spec=nnx.Module)
        mock_quantize_model.return_value = quantized_model_mock

        with patch(
                "tpu_inference.models.jax.utils.quantization.quantization_utils.init_logger",
                return_value=MagicMock()
        ), patch(
                "tpu_inference.utils.hbm_usage_gb",
                return_value=[(0.0, 0.0), (0.0, 0.0)]
        ), patch(
                "tpu_inference.models.jax.utils.quantization.quantization_utils.create_kv_caches",
                return_value=self.mock_kv_caches
        ), patch(
                "tpu_inference.models.jax.utils.quantization.quantization_utils.quantization_config_file_path_to_dict",
                return_value=self.qwix_config):
            returned_model = quantize_qwix.qwix_quantize_nnx_model(
                model=self.model,
                qwix_config=self.qwix_config,
                rng=self.rng,
                mesh=self.mesh,
                num_hidden_layers=self.num_hidden_layers,
                kv_cache_block_size=self.kv_cache_block_size,
                kv_cache_num_kv_heads=self.kv_cache_num_kv_heads,
                kv_cache_head_size=self.kv_cache_head_size,
                kv_cache_dtype="auto")

        self.assertIs(returned_model, quantized_model_mock)
        mock_quantize_model.assert_called_once()
        args, kwargs = mock_quantize_model.call_args

        # Assert positional arguments for qwix.quantize_model
        self.assertIs(args[0], self.model)
        self.assertIsInstance(args[1], qwix.PtqProvider)

        # Assert keyword arguments (model inputs for tracing)
        self.assertIn("kv_caches", kwargs)
        self.assertEqual(kwargs["kv_caches"], self.mock_kv_caches)
        self.assertIn("input_ids", kwargs)
        self.assertEqual(kwargs["input_ids"].shape, (512, ))
        self.assertIn("attention_metadata", kwargs)
        attention_metadata = kwargs["attention_metadata"]

        assert attention_metadata.input_positions.shape == (
            DEFAULT_NUM_TOKENS_FOR_MODEL_INPUTS, )
        assert attention_metadata.block_tables.shape == (
            DEFAULT_MAX_NUM_SEQS_FOR_MODEL_INPUTS *
            DEFAULT_MAX_NUM_BLOCKS_PER_REQ, )
        assert attention_metadata.seq_lens.shape == (
            DEFAULT_MAX_NUM_SEQS_FOR_MODEL_INPUTS, )
        assert attention_metadata.query_start_loc.shape == (
            DEFAULT_MAX_NUM_SEQS_FOR_MODEL_INPUTS + 1, )
        assert attention_metadata.request_distribution.shape == (3, )


@patch.dict('sys.modules', module_mocks)
class TestApplyQwixQuantization(unittest.TestCase):

    def setUp(self):
        """Set up common mock objects for all tests in this suite."""
        mock_nnx.reset_mock()
        mock_jax.reset_mock()

        self.mock_vllm_config = MagicMock()
        self.mock_vllm_config.additional_config = {}
        self.mock_vllm_config.cache_config.block_size = 16
        self.mock_vllm_config.model_config.get_head_size.return_value = 128
        self.mock_vllm_config.model_config.get_total_num_kv_heads.return_value = 8
        self.mock_vllm_config.model_config.hf_config.num_hidden_layers = 32

        self.mock_model = MagicMock(name="original_nnx_model",
                                    spec_set=nnx.Module)
        self.mock_rng = MagicMock(name="mock_rng")
        self.mock_mesh = MagicMock(name="mock_mesh")

    def test_no_quantization_config(self):
        """
        Test that the model is returned unchanged if no 'quantization' key exists.
        """
        result = apply_qwix_quantization(self.mock_vllm_config,
                                         self.mock_model,
                                         self.mock_rng,
                                         self.mock_mesh,
                                         apply_to_abstract_model=False)

        self.assertIs(result, self.mock_model,
                      "Model should be returned as-is.")
        mock_nnx.jit.assert_not_called()

    @patch('tpu_inference.models.common.model_loader.nnx.jit')
    def test_quantization_applied_from_dict(self, mock_jit):
        """
        Test that quantization is applied correctly when the config is a dictionary.
        """
        qwix_rules = {"weights": "int8", "activations": None}
        self.mock_vllm_config.additional_config = {
            "quantization": {
                "qwix": {
                    "rules": qwix_rules
                }
            }
        }

        with patch('tpu_inference.utils.get_padded_num_heads',
                   return_value=128):
            apply_qwix_quantization(self.mock_vllm_config,
                                    self.mock_model,
                                    self.mock_rng,
                                    self.mock_mesh,
                                    apply_to_abstract_model=False)
        mock_jit.assert_called_once()


class TestQuantizationConfigFileToDict(unittest.TestCase):
    """Tests for the quantization_config_file_path_to_dict function."""

    @patch("os.listdir")
    @patch("os.path.join")
    def test_file_not_found_raises_value_error(self, mock_join, mock_listdir):
        """Test that a ValueError is raised if the config file is not found."""
        mock_listdir.return_value = ["another_file.yaml", "config.txt"]
        config_file_path = "non_existent.yaml"

        with self.assertRaisesRegex(
                ValueError,
                f"Could not find quantization config file with name '{config_file_path}'"
        ):
            quantize_qwix.quantization_config_file_path_to_dict(
                config_file_path)
        mock_listdir.assert_called_once_with(
            quantize_qwix.QUANTIZATION_CONFIG_PATH)

    @patch("os.listdir")
    @patch("os.path.join")
    @patch("builtins.open",
           new_callable=mock_open,
           read_data="qwix:\n rules: []")
    def test_file_found_and_loaded_successfully(self, mock_file, mock_join,
                                                mock_listdir):
        """Test that the YAML file is correctly loaded when found."""
        config_filename = "my_quant_config.yaml"
        mock_listdir.return_value = ["another.yaml", config_filename]
        mock_join.return_value = f"/fake/path/{config_filename}"
        expected_dict = {"qwix": {"rules": []}}

        result = quantize_qwix.quantization_config_file_path_to_dict(
            config_filename)

        mock_listdir.assert_called_once_with(
            quantize_qwix.QUANTIZATION_CONFIG_PATH)
        mock_join.assert_called_once_with(
            quantize_qwix.QUANTIZATION_CONFIG_PATH, config_filename)
        mock_file.assert_called_once_with(f"/fake/path/{config_filename}", "r")
        self.assertEqual(result, expected_dict)


class TestApplyQwixQuantizationLogic(unittest.TestCase):
    """Tests the core logic of apply_qwix_quantization."""

    def setUp(self):
        self.mock_vllm_config = MagicMock()
        self.mock_vllm_config.additional_config = {}
        self.mock_vllm_config.cache_config.block_size = 16
        self.mock_vllm_config.model_config.get_head_size.return_value = 128
        self.mock_vllm_config.model_config.get_total_num_kv_heads.return_value = 8
        self.mock_vllm_config.model_config.hf_config.num_hidden_layers = 32
        self.mock_model = MagicMock(name="original_nnx_model")
        self.mock_rng = MagicMock(name="mock_rng")
        self.mock_mesh = MagicMock(name="mock_mesh", shape={"model": 1})

    def test_quantization_config_without_qwix_rules(self):
        """Test model is unchanged if the config lacks 'qwix' or 'rules'."""
        self.mock_vllm_config.additional_config = {"quantization": {}}
        result1 = quantize_qwix.apply_qwix_quantization(
            self.mock_vllm_config, self.mock_model, self.mock_rng,
            self.mock_mesh, False)
        self.assertIs(result1, self.mock_model)

        self.mock_vllm_config.additional_config = {
            "quantization": {
                "qwix": {}
            }
        }
        result2 = quantize_qwix.apply_qwix_quantization(
            self.mock_vllm_config, self.mock_model, self.mock_rng,
            self.mock_mesh, False)
        self.assertIs(result2, self.mock_model)

    @patch(
        'tpu_inference.models.jax.utils.quantization.quantization_utils.qwix_quantize_nnx_model'
    )
    @patch(
        'tpu_inference.models.jax.utils.quantization.quantization_utils.utils')
    def test_apply_to_abstract_model(self, mock_utils, mock_quantize_func):
        """Test quantization is correctly applied to an abstract model factory."""
        mock_utils.get_padded_num_heads.return_value = 8
        mock_utils.get_padded_head_dim.return_value = 128
        qwix_rules = [{"module_path": ".*", "weight_qtype": "int8"}]
        self.mock_vllm_config.additional_config = {
            "quantization": {
                "qwix": {
                    "rules": qwix_rules
                }
            }
        }
        mock_abstract_model = MagicMock(name="abstract_model")
        mock_model_fn = MagicMock(name="model_factory",
                                  return_value=mock_abstract_model)
        quantized_model = MagicMock(name="quantized_model")
        mock_quantize_func.return_value = quantized_model

        model_factory = quantize_qwix.apply_qwix_quantization(
            self.mock_vllm_config,
            mock_model_fn,
            self.mock_rng,
            self.mock_mesh,
            apply_to_abstract_model=True)

        self.assertTrue(callable(model_factory))
        result_model = model_factory()

        mock_model_fn.assert_called_once()
        mock_quantize_func.assert_called_once()
        call_kwargs = mock_quantize_func.call_args.kwargs
        self.assertIs(call_kwargs['model'], mock_abstract_model)
        self.assertIs(call_kwargs['rng'], self.mock_rng)
        self.assertIs(result_model, quantized_model)

    @patch(
        'tpu_inference.models.jax.utils.quantization.quantization_utils.qwix_quantize_nnx_model'
    )
    @patch(
        'tpu_inference.models.jax.utils.quantization.quantization_utils.utils')
    def test_apply_to_abstract_model_with_initialize_cache(
            self, mock_utils, mock_quantize_func):
        """Test abstract model quantization with 'initialize_cache' method."""
        mock_utils.get_padded_num_heads.return_value = 8
        mock_utils.get_padded_head_dim.return_value = 128
        qwix_rules = [{"module_path": ".*", "weight_qtype": "int8"}]
        self.mock_vllm_config.additional_config = {
            "quantization": {
                "qwix": {
                    "rules": qwix_rules
                }
            }
        }
        mock_abstract_model = MagicMock(name="abstract_model")
        mock_abstract_model.initialize_cache = MagicMock()
        mock_model_fn = MagicMock(name="model_factory",
                                  return_value=mock_abstract_model)

        model_factory = quantize_qwix.apply_qwix_quantization(
            self.mock_vllm_config,
            mock_model_fn,
            self.mock_rng,
            self.mock_mesh,
            apply_to_abstract_model=True)

        model_factory()

        mock_abstract_model.initialize_cache.assert_called_once()
        mock_quantize_func.assert_called_once()


class TestDetermineWhetherToApplyQwixOnAbstractModel(unittest.TestCase):
    """Tests for apply_qwix_on_abstract_model."""

    def setUp(self):
        self.mock_vllm_config = MagicMock()
        self.mock_vllm_config.additional_config = {
            "quantization": {
                "qwix": {
                    "use_abstract_model": True,
                    "rules": [{
                        "module_path": ".*",
                        "weight_qtype": "int8"
                    }]
                }
            }
        }

        self.mock_vllm_config_no_abstract_model = MagicMock()
        self.mock_vllm_config_no_abstract_model.additional_config = {
            "quantization": {
                "qwix": {
                    "rules": [{
                        "module_path": ".*",
                        "weight_qtype": "int8"
                    }]
                }
            }
        }

        self.mock_vllm_config_no_additional_config = MagicMock()
        self.mock_vllm_config_no_additional_config.additional_config = {}

    def test_returns_false_when_additional_config_is_missing(self):
        """Test it returns False when additional_config is missing."""
        result = quantize_qwix.apply_qwix_on_abstract_model(
            self.mock_vllm_config_no_additional_config)
        self.assertFalse(result)

    def test_returns_true_when_additional_config_is_present(self):
        """Test it returns True when use_abstract_model is configured."""
        result = quantize_qwix.apply_qwix_on_abstract_model(
            self.mock_vllm_config)
        self.assertTrue(result)

    def test_returns_false_when_use_abstract_model_is_false(self):
        """Test it returns False when use_abstract_model is False."""
        result = quantize_qwix.apply_qwix_on_abstract_model(
            self.mock_vllm_config_no_abstract_model)
        self.assertFalse(result)


class TestLoadRandomWeightsIntoQwixAbstractModel(unittest.TestCase):
    """Tests for the load_random_weights_into_qwix_abstract_model function."""

    def setUp(self):
        """Set up a mock environment for testing."""
        if not jax.devices():
            self.skipTest(
                "JAX device not found, skipping JAX-dependent tests.")

        self.rng = jax.random.PRNGKey(0)
        self.mesh = Mesh(jax.devices(), ('data', ))
        self.quantization_config = {
            "weight_block_size": [64, 1],
        }

        # Mock model structure
        self.model = MagicMock(spec=['weight_loader', 'initialize_cache'])
        self.model.weight_loader = MagicMock(
            spec=['scale_dtype', 'scale_shap_map_for_random_weight_loading'])
        self.model.weight_loader.scale_dtype = jnp.float16
        self.model.weight_loader.scale_shap_map_for_random_weight_loading = {}

    @patch(
        'tpu_inference.models.jax.utils.quantization.quantization_utils.nnx.iter_graph'
    )
    @patch(
        'tpu_inference.models.jax.utils.quantization.quantization_utils.get_random_sharded_array'
    )
    def test_successful_initialization(self, mock_get_random_array,
                                       mock_iter_graph):
        """Test that variables are correctly initialized."""
        # Setup mock graph elements
        mock_weight_param = nnx.Param(jnp.empty((128, 64), dtype=jnp.int8),
                                      sharding=P('data', None))
        mock_scale_var = nnx.Variable(jnp.empty((1, 1), dtype=jnp.float16))
        mock_rng_var = nnx.Variable(jax.random.PRNGKey(0))
        mock_random_array = jax.numpy.ones(1)
        mock_get_random_array.return_value = mock_random_array

        mock_iter_graph.return_value = [
            (('layers', '0', 'attention', 'wq', 'kernel'), mock_weight_param),
            (('layers', '0', 'attention', 'wq', 'array', 'scale'),
             mock_scale_var),
            (('rng', 'params', 'key'), mock_rng_var),
        ]

        quantize_qwix.load_random_weights_into_qwix_abstract_model(
            self.rng, self.model, self.mesh, self.quantization_config)

        # Assert weight is updated
        self.assertIs(mock_weight_param.value, mock_random_array)
        # Assert scale is updated
        self.assertIs(mock_scale_var.value, mock_random_array)
        # Assert RNG key is updated with the passed-in RNG
        self.assertIs(mock_rng_var.value, self.rng)
        # Assert initialize_cache is called
        self.model.initialize_cache.assert_called_once()

    def test_invalid_config_raises_assertion_error(self):
        """Test that an invalid quantization_block_sizes config raises an error."""
        invalid_config = {"weight_block_size": [64]}  # Length is 1, not 2
        with self.assertRaisesRegex(AssertionError,
                                    "Expected only 2 quantization block"):
            quantize_qwix.load_random_weights_into_qwix_abstract_model(
                self.rng, self.model, self.mesh, invalid_config)

    @patch(
        'tpu_inference.models.jax.utils.quantization.quantization_utils.nnx.iter_graph'
    )
    def test_param_shape_setting_no_scale_map(self, mock_iter_graph):
        """Test correct scale shape calculation when not in the map."""
        old_weight_param_val = jnp.empty((128, 64))
        mock_weight_param = nnx.Param(old_weight_param_val, dtype=jnp.int8)
        old_scale_var_val = jnp.empty((0, 0))
        mock_scale_var = nnx.Variable(old_scale_var_val)

        mock_iter_graph.return_value = [
            (('layers', '0', 'attention', 'wq', 'kernel'), mock_weight_param),
            (('layers', '0', 'attention', 'wq', 'array', 'scale'),
             mock_scale_var),
        ]

        quantize_qwix.load_random_weights_into_qwix_abstract_model(
            self.rng, self.model, self.mesh, self.quantization_config)

        new_weight_param_val = mock_weight_param.value
        new_scale_var_val = mock_scale_var.value

        expected_scale_shape = (128 // 64, 64 // 64)
        actual_scale_shape = new_scale_var_val.shape

        expected_weight_shape = (128, 64)
        actual_weight_shape = new_weight_param_val.shape

        self.assertEqual(expected_scale_shape, actual_scale_shape)
        self.assertEqual(expected_weight_shape, actual_weight_shape)
        self.assertNotEqual(old_scale_var_val.shape, new_scale_var_val.shape)
        assert jnp.not_equal(old_weight_param_val, new_weight_param_val).all()

    @patch(
        'tpu_inference.models.jax.utils.quantization.quantization_utils.nnx.iter_graph'
    )
    def test_param_shape_setting_with_scale_map(self, mock_iter_graph):
        """Test correct scale shape calculation when in the map."""
        old_weight_param_val = jnp.empty((128, 64))
        mock_weight_param = nnx.Param(old_weight_param_val, dtype=jnp.int8)
        old_scale_var_val = jnp.empty((0, 0))
        mock_scale_var = nnx.Variable(old_scale_var_val)

        expected_scale_shape = (55, 34)

        self.model.weight_loader.scale_shap_map_for_random_weight_loading = {
            'wq': expected_scale_shape
        }

        mock_iter_graph.return_value = [
            (('layers', '0', 'attention', 'wq', 'kernel'), mock_weight_param),
            (('layers', '0', 'attention', 'wq', 'array', 'scale'),
             mock_scale_var),
        ]

        quantize_qwix.load_random_weights_into_qwix_abstract_model(
            self.rng, self.model, self.mesh, self.quantization_config)

        new_weight_param_val = mock_weight_param.value
        new_scale_var_val = mock_scale_var.value

        actual_scale_shape = new_scale_var_val.shape

        expected_weight_shape = (128, 64)
        actual_weight_shape = new_weight_param_val.shape

        self.assertEqual(expected_scale_shape, actual_scale_shape)
        self.assertEqual(expected_weight_shape, actual_weight_shape)
        self.assertNotEqual(old_scale_var_val.shape, new_scale_var_val.shape)
        assert jnp.not_equal(old_weight_param_val, new_weight_param_val).all()

    @patch('jax.random.randint')
    @patch('jax.random.normal')
    @patch('jax.make_array_from_callback')
    def test_get_random_sharded_array_dtype_dispatch(self, mock_make_array,
                                                     mock_normal,
                                                     mock_randint):
        """Test that integer dtypes call randint and floats call normal."""
        # Test integer
        quantize_qwix.get_random_sharded_array(
            self.rng, self.mesh, nnx.Param(jnp.empty((2, 2)), sharding=P()),
            (2, 2), jnp.int8, "int_param")
        mock_randint.assert_called_once()
        mock_normal.assert_not_called()

        mock_randint.reset_mock()
        mock_normal.reset_mock()

        # Test float
        quantize_qwix.get_random_sharded_array(
            self.rng, self.mesh, nnx.Param(jnp.empty((2, 2)), sharding=P()),
            (2, 2), jnp.float32, "float_param")
        mock_randint.assert_not_called()
        mock_normal.assert_called_once()

    @patch(
        "tpu_inference.models.jax.utils.quantization.quantization_utils.logger.warning"
    )
    @patch("jax.make_array_from_callback")
    def test_get_random_sharded_array_sharding_fallback(
            self, mock_make_array, mock_logger_warning):
        """Test that sharding failure logs a warning and uses a fallback."""
        # First call raises an error, second call (fallback) succeeds
        mock_make_array.side_effect = [
            ValueError("Sharding failed"),
            MagicMock()
        ]

        param = nnx.Param(jnp.empty((2, 2)), sharding=P('data', None))
        quantize_qwix.get_random_sharded_array(self.rng, self.mesh, param,
                                               (2, 2), jnp.float32,
                                               "test_param")

        # Check that a warning was logged
        mock_logger_warning.assert_called_once()
        self.assertIn("Could not create sharded scale for test_param",
                      mock_logger_warning.call_args[0][0])

        # Check that the fallback was attempted with an empty PartitionSpec
        fallback_call_args = mock_make_array.call_args_list[1]
        fallback_sharding = fallback_call_args.args[1]
        self.assertEqual(fallback_sharding, NamedSharding(self.mesh, P()))


class TestManualQwixQuantization(unittest.TestCase):
    """Tests for manual Qwix quantization functions."""

    def setUp(self):
        if not jax.devices():
            self.skipTest(
                "JAX device not found, skipping JAX-dependent tests.")
        self.weight = jnp.ones((4, 4))
        self.inputs = jnp.ones((8, 4))
        self.qtype = jnp.int8
        self.channelwise_axes = [0]
        self.tiled_axes = {}
        self.calibration_method = 'max'

    @patch(
        'tpu_inference.models.jax.utils.quantization.quantization_utils.ptq.create_quantized_param'
    )
    def test_manually_quantize_qwix_weight(self, mock_create_param):
        """Test that manually_quantize_qwix_weight calls ptq.create_quantized_param correctly."""
        quantize_qwix.manually_quantize_qwix_weight(
            weight=self.weight,
            qtype=self.qtype,
            channelwise_axes=self.channelwise_axes,
            tiled_axes=self.tiled_axes,
            calibration_method=self.calibration_method)

        mock_create_param.assert_called_once()
        args, _ = mock_create_param.call_args
        passed_weight, passed_how_to_quantize = args

        self.assertTrue(jnp.array_equal(passed_weight, self.weight))
        self.assertIsInstance(passed_how_to_quantize, ptq.qarray.HowToQuantize)
        self.assertEqual(passed_how_to_quantize.qtype, self.qtype)
        self.assertEqual(passed_how_to_quantize.channelwise_axes,
                         self.channelwise_axes)
        self.assertEqual(passed_how_to_quantize.tiled_axes, self.tiled_axes)
        self.assertEqual(passed_how_to_quantize.calibration_method,
                         self.calibration_method)

    @patch(
        'tpu_inference.models.jax.utils.quantization.quantization_utils.ptq.quantize_act'
    )
    @patch('qwix.pallas.get_current_rule')
    def test_manually_quantize_qwix_activation(self, mock_get_rule,
                                               mock_quantize_act):
        """Test that manually_quantize_qwix_activation calls ptq.quantize_act correctly."""
        mock_rule = MagicMock()
        mock_rule.act_static_scale = False
        mock_get_rule.return_value = mock_rule
        rule_name = "test_rule"

        quantize_qwix.manually_quantize_qwix_activation(
            inputs=self.inputs,
            rule_name=rule_name,
            qtype=self.qtype,
            channelwise_axes=self.channelwise_axes,
            tiled_axes=self.tiled_axes,
            calibration_method=self.calibration_method)

        mock_get_rule.assert_called_once_with(rule_name)
        mock_quantize_act.assert_called_once()

        args, _ = mock_quantize_act.call_args
        passed_inputs, passed_how, passed_rule, passed_act_name = args

        self.assertTrue(jnp.array_equal(passed_inputs, self.inputs))
        self.assertIsInstance(passed_how, ptq.qarray.HowToQuantize)
        self.assertEqual(passed_how.qtype, self.qtype)
        self.assertEqual(passed_how.channelwise_axes, self.channelwise_axes)
        self.assertEqual(passed_how.tiled_axes, self.tiled_axes)
        self.assertEqual(passed_how.calibration_method,
                         self.calibration_method)
        self.assertIs(passed_rule, mock_rule)
        self.assertEqual(passed_act_name, "")  # act_name is hardcoded to ""

    @patch('qwix.pallas.get_current_rule')
    def test_manually_quantize_qwix_activation_static_scale_raises_error(
            self, mock_get_rule):
        """Test that an assertion is raised if the rule has static scale."""
        mock_rule = MagicMock()
        mock_rule.act_static_scale = True
        mock_get_rule.return_value = mock_rule

        with self.assertRaisesRegex(AssertionError,
                                    "Static scale not supported right now"):
            quantize_qwix.manually_quantize_qwix_activation(
                inputs=self.inputs,
                rule_name="any_rule",
                qtype=self.qtype,
                channelwise_axes=self.channelwise_axes,
                tiled_axes=self.tiled_axes,
                calibration_method=self.calibration_method)


class TestGetQuantDtypeFromQwixConfig(unittest.TestCase):
    """Tests for the get_quant_dtype_from_qwix_config function."""

    def setUp(self):
        self.mock_vllm_config = MagicMock()
        self.mock_vllm_config.additional_config = {}

    def test_get_quant_dtype_success(self):
        """Test successful extraction of dtypes from a valid config."""
        self.mock_vllm_config.additional_config = {
            "quantization": {
                "qwix": {
                    "scale_dtype":
                    "float16",
                    "rules": [
                        {
                            "module_path": ".*mlp.*",
                            "weight_qtype": "int4"
                        },
                        {
                            "module_path": ".*",
                            "weight_qtype": "int8"
                        },
                    ],
                }
            }
        }
        scale_dtype, quant_dtype = quantize_qwix.get_quant_dtype_from_qwix_config(
            self.mock_vllm_config)
        self.assertEqual(scale_dtype, jnp.float16)
        self.assertEqual(quant_dtype, jnp.int8)

    def test_get_quant_dtype_default_scale(self):
        """Test that scale_dtype defaults to bfloat16 when not specified."""
        self.mock_vllm_config.additional_config = {
            "quantization": {
                "qwix": {
                    "rules": [{
                        "module_path": ".*",
                        "weight_qtype": "int8"
                    }]
                }
            }
        }
        scale_dtype, quant_dtype = quantize_qwix.get_quant_dtype_from_qwix_config(
            self.mock_vllm_config)
        self.assertEqual(scale_dtype, jnp.bfloat16)
        self.assertEqual(quant_dtype, jnp.int8)

    def test_no_quantization_config_returns_defaults(self):
        """Test that default dtypes are returned when config is missing."""
        self.mock_vllm_config.additional_config = {}
        scale_dtype, quant_dtype = quantize_qwix.get_quant_dtype_from_qwix_config(
            self.mock_vllm_config)
        self.assertEqual(scale_dtype, jnp.bfloat16)
        self.assertIsNone(quant_dtype)

    def test_get_quant_dtype_no_wildcard_rule_returns_none(self):
        """Test that quant_dtype is None if no wildcard rule is found."""
        self.mock_vllm_config.additional_config = {
            "quantization": {
                "qwix": {
                    "rules": [{
                        "module_path": ".*mlp.*",
                        "weight_qtype": "int4"
                    }]
                }
            }
        }
        scale_dtype, quant_dtype = quantize_qwix.get_quant_dtype_from_qwix_config(
            self.mock_vllm_config)
        self.assertEqual(scale_dtype, jnp.bfloat16)
        self.assertIsNone(quant_dtype)

    def test_get_quant_dtype_wildcard_rule_missing_qtype_raises_error(self):
        """Test that an assertion is raised if the wildcard rule is missing weight_qtype."""
        self.mock_vllm_config.additional_config = {
            "quantization": {
                "qwix": {
                    "rules": [{
                        "module_path": ".*"
                    }]
                }
            }
        }
        with self.assertRaisesRegex(AssertionError,
                                    "Quantization dtype not found"):
            quantize_qwix.get_quant_dtype_from_qwix_config(
                self.mock_vllm_config)

    def test_get_quant_dtype_no_rules_key_returns_none(self):
        """Test that quant_dtype is None if 'rules' key is missing."""
        self.mock_vllm_config.additional_config = {
            "quantization": {
                "qwix": {
                    "scale_dtype": "float16",
                }
            }
        }
        scale_dtype, quant_dtype = quantize_qwix.get_quant_dtype_from_qwix_config(
            self.mock_vllm_config)
        self.assertEqual(scale_dtype, jnp.float16)
        self.assertIsNone(quant_dtype)


if __name__ == '__main__':
    unittest.main()
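For reference, the suites above all drive a qwix quantization configuration passed through vLLM's additional_config. The sketch below is assembled from the test fixtures only; the surrounding vLLM/engine plumbing is not shown, and treating these as the complete set of supported keys is an assumption.

    # Minimal sketch of the config shape exercised by the tests above.
    additional_config = {
        "quantization": {
            "qwix": {
                "use_abstract_model": True,   # optional; quantize an abstract model with random weights
                "scale_dtype": "float16",     # optional; the tests fall back to bfloat16
                "rules": [
                    {"module_path": ".*mlp.*", "weight_qtype": "int4"},
                    {"module_path": ".*", "weight_qtype": "int8"},
                ],
            }
        }
    }

To run just this suite locally (assuming the tests/ package ships alongside the sources and jax, flax, and qwix are installed), `python -m unittest tests.test_quantization -v` should work; the JAX-dependent cases skip themselves when no JAX device is available.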