tpu-inference 0.11.1.dev202511220812-py3-none-any.whl → 0.13.2.dev20251230-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry, and is provided for informational purposes only.
Potentially problematic release: the registry flags this version of tpu-inference.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +317 -34
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +406 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +320 -0
- tests/layers/vllm/test_unquantized.py +662 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +26 -6
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +110 -12
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +2 -45
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +15 -10
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +92 -8
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +25 -4
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +807 -230
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +218 -137
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +25 -12
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/fused_moe_gmm.py +506 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +282 -0
- tpu_inference/layers/common/sharding.py +32 -9
- tpu_inference/layers/common/utils.py +94 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +101 -494
- tpu_inference/layers/vllm/linear.py +64 -0
- tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
- tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
- tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
- tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
- tpu_inference/layers/vllm/quantization/__init__.py +19 -3
- tpu_inference/layers/vllm/quantization/awq.py +96 -82
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +23 -8
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +172 -176
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
- tpu_inference/layers/vllm/quantization/{common.py → configs.py} +42 -25
- tpu_inference/layers/vllm/quantization/fp8.py +119 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +137 -178
- tpu_inference/layers/vllm/quantization/unquantized.py +157 -233
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +112 -35
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +18 -5
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +179 -51
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +92 -32
- tpu_inference/models/jax/utils/weight_utils.py +234 -155
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +32 -8
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +51 -72
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +180 -80
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +55 -33
- tpu_inference/runner/lora_utils.py +16 -1
- tpu_inference/runner/multimodal_manager.py +16 -2
- tpu_inference/runner/persistent_batch_manager.py +54 -2
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +16 -3
- tpu_inference/runner/tpu_runner.py +124 -61
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +84 -22
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +72 -44
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +66 -52
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +8 -9
- tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
- tpu_inference/layers/vllm/linear_common.py +0 -186
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202511220812.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
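
A file listing like the one above can be regenerated locally with the Python standard library alone, since a wheel is a zip archive. A minimal sketch (the wheel filenames are assumed local download paths, not part of any registry tooling):

```python
# Hypothetical local reproduction of the wheel file listing above; the
# wheel paths are assumptions (download both versions first, e.g. with pip).
import zipfile

OLD = "tpu_inference-0.11.1.dev202511220812-py3-none-any.whl"
NEW = "tpu_inference-0.13.2.dev20251230-py3-none-any.whl"

def wheel_names(path: str) -> set[str]:
    """Return the set of archive member names in a wheel (a zip file)."""
    with zipfile.ZipFile(path) as zf:
        return set(zf.namelist())

old_names, new_names = wheel_names(OLD), wheel_names(NEW)
for name in sorted(new_names - old_names):
    print("added:  ", name)
for name in sorted(old_names - new_names):
    print("removed:", name)
```

Selected hunks follow for three of the test files.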
--- tests/test_quantization.py
+++ tests/layers/jax/test_qwix.py
@@ -11,9 +11,9 @@ from jax.sharding import Mesh, NamedSharding
 from jax.sharding import PartitionSpec as P
 from qwix._src.providers import ptq
 
-import tpu_inference.models.jax.utils.quantization.quantization_utils as quantize_qwix  # noqa: E402
+import tpu_inference.models.jax.utils.qwix.qwix_utils as quantize_qwix  # noqa: E402
 from tpu_inference.models.common.model_loader import apply_qwix_quantization
-from tpu_inference.models.jax.utils.quantization.quantization_utils import (
+from tpu_inference.models.jax.utils.qwix.qwix_utils import (
     DEFAULT_MAX_NUM_BLOCKS_PER_REQ, DEFAULT_MAX_NUM_SEQS_FOR_MODEL_INPUTS,
     DEFAULT_NUM_TOKENS_FOR_MODEL_INPUTS)
 
@@ -29,8 +29,7 @@ module_mocks = {
     'vllm.config': MagicMock(),
     'tpu_inference': MagicMock(),
     'tpu_inference.logger': MagicMock(init_logger=lambda name: MagicMock()),
-    'tpu_inference.models.jax.utils.quantization.quantization_utils':
-    MagicMock(),
+    'tpu_inference.models.jax.utils.qwix.qwix_utils': MagicMock(),
 }
 
 
@@ -112,6 +111,8 @@ class TestQwixQuantizeNnxModel(unittest.TestCase):
         self.mesh = Mesh(jax.devices(), ('model', ))
         self.rng = jax.random.PRNGKey(0)
         self.model = SimpleModel(rngs=nnx.Rngs(0))
+        self.model.vllm_config = MagicMock()
+        self.model.vllm_config.model_config.use_mla = False
 
         self.qwix_config = [
             {
@@ -131,18 +132,19 @@ class TestQwixQuantizeNnxModel(unittest.TestCase):
         """Test that qwix.quantize_model is called with the correct arguments."""
         quantized_model_mock = MagicMock(spec=nnx.Module)
         mock_quantize_model.return_value = quantized_model_mock
+        self.model.vllm_config.sharding_config.total_dp_size = 1
 
         with patch(
-                "tpu_inference.models.jax.utils.quantization.quantization_utils.init_logger",
+                "tpu_inference.models.jax.utils.qwix.qwix_utils.init_logger",
                 return_value=MagicMock()
         ), patch(
                 "tpu_inference.utils.hbm_usage_gb",
                 return_value=[(0.0, 0.0), (0.0, 0.0)]
         ), patch(
-                "tpu_inference.models.jax.utils.quantization.quantization_utils.create_kv_caches",
+                "tpu_inference.models.jax.utils.qwix.qwix_utils.create_kv_caches",
                 return_value=self.mock_kv_caches
         ), patch(
-                "tpu_inference.models.jax.utils.quantization.quantization_utils.quantization_config_file_path_to_dict",
+                "tpu_inference.models.jax.utils.qwix.qwix_utils.quantization_config_file_path_to_dict",
                 return_value=self.qwix_config):
             returned_model = quantize_qwix.qwix_quantize_nnx_model(
                 model=self.model,
@@ -317,10 +319,9 @@ class TestApplyQwixQuantizationLogic(unittest.TestCase):
         self.assertIs(result2, self.mock_model)
 
     @patch(
-        'tpu_inference.models.jax.utils.quantization.quantization_utils.qwix_quantize_nnx_model'
+        'tpu_inference.models.jax.utils.qwix.qwix_utils.qwix_quantize_nnx_model'
     )
-    @patch(
-        'tpu_inference.models.jax.utils.quantization.quantization_utils.utils')
+    @patch('tpu_inference.models.jax.utils.qwix.qwix_utils.utils')
     def test_apply_to_abstract_model(self, mock_utils, mock_quantize_func):
         """Test quantization is correctly applied to an abstract model factory."""
         mock_utils.get_padded_num_heads.return_value = 8
@@ -357,10 +358,9 @@ class TestLoadRandomWeightsIntoQwixAbstractModel(unittest.TestCase):
         self.assertIs(result_model, quantized_model)
 
     @patch(
-        'tpu_inference.models.jax.utils.quantization.quantization_utils.qwix_quantize_nnx_model'
+        'tpu_inference.models.jax.utils.qwix.qwix_utils.qwix_quantize_nnx_model'
     )
-    @patch(
-        'tpu_inference.models.jax.utils.quantization.quantization_utils.utils')
+    @patch('tpu_inference.models.jax.utils.qwix.qwix_utils.utils')
     def test_apply_to_abstract_model_with_initialize_cache(
             self, mock_utils, mock_quantize_func):
         """Test abstract model quantization with 'initialize_cache' method."""
@@ -461,15 +461,13 @@ class TestLoadRandomWeightsIntoQwixAbstractModel(unittest.TestCase):
         # Mock model structure
         self.model = MagicMock(spec=['weight_loader', 'initialize_cache'])
         self.model.weight_loader = MagicMock(
-            spec=['scale_dtype', '
+            spec=['scale_dtype', 'scale_shape_map_for_random_weight_loading'])
         self.model.weight_loader.scale_dtype = jnp.float16
-        self.model.weight_loader.
+        self.model.weight_loader.scale_shape_map_for_random_weight_loading = {}
 
+    @patch('tpu_inference.models.jax.utils.qwix.qwix_utils.nnx.iter_graph')
     @patch(
-        'tpu_inference.models.jax.utils.quantization.quantization_utils.nnx.iter_graph'
-    )
-    @patch(
-        'tpu_inference.models.jax.utils.quantization.quantization_utils.get_random_sharded_array'
+        'tpu_inference.models.jax.utils.qwix.qwix_utils.get_random_sharded_array'
     )
     def test_successful_initialization(self, mock_get_random_array,
                                        mock_iter_graph):
@@ -482,6 +480,10 @@ class TestLoadRandomWeightsIntoQwixAbstractModel(unittest.TestCase):
         mock_random_array = jax.numpy.ones(1)
         mock_get_random_array.return_value = mock_random_array
 
+        self.model.weight_loader.scale_shape_map_for_random_weight_loading = {
+            'attention.wq': (1, 1)
+        }
+
         mock_iter_graph.return_value = [
             (('layers', '0', 'attention', 'wq', 'kernel'), mock_weight_param),
             (('layers', '0', 'attention', 'wq', 'array', 'scale'),
@@ -509,9 +511,7 @@ class TestLoadRandomWeightsIntoQwixAbstractModel(unittest.TestCase):
         quantize_qwix.load_random_weights_into_qwix_abstract_model(
             self.rng, self.model, self.mesh, invalid_config)
 
-    @patch(
-        'tpu_inference.models.jax.utils.quantization.quantization_utils.nnx.iter_graph'
-    )
+    @patch('tpu_inference.models.jax.utils.qwix.qwix_utils.nnx.iter_graph')
     def test_param_shape_setting_no_scale_map(self, mock_iter_graph):
         """Test correct scale shape calculation when not in the map."""
         old_weight_param_val = jnp.empty((128, 64))
@@ -525,26 +525,11 @@ class TestLoadRandomWeightsIntoQwixAbstractModel(unittest.TestCase):
             mock_scale_var),
         ]
 
-
-
-
-        new_weight_param_val = mock_weight_param.value
-        new_scale_var_val = mock_scale_var.value
-
-        expected_scale_shape = (128 // 64, 64 // 64)
-        actual_scale_shape = new_scale_var_val.shape
-
-        expected_weight_shape = (128, 64)
-        actual_weight_shape = new_weight_param_val.shape
-
-        self.assertEqual(expected_scale_shape, actual_scale_shape)
-        self.assertEqual(expected_weight_shape, actual_weight_shape)
-        self.assertNotEqual(old_scale_var_val.shape, new_scale_var_val.shape)
-        assert jnp.not_equal(old_weight_param_val, new_weight_param_val).all()
+        with self.assertRaises(ValueError):
+            quantize_qwix.load_random_weights_into_qwix_abstract_model(
+                self.rng, self.model, self.mesh, self.quantization_config)
 
-    @patch(
-        'tpu_inference.models.jax.utils.quantization.quantization_utils.nnx.iter_graph'
-    )
+    @patch('tpu_inference.models.jax.utils.qwix.qwix_utils.nnx.iter_graph')
     def test_param_shape_setting_with_scale_map(self, mock_iter_graph):
         """Test correct scale shape calculation when in the map."""
         old_weight_param_val = jnp.empty((128, 64))
@@ -554,8 +539,8 @@ class TestLoadRandomWeightsIntoQwixAbstractModel(unittest.TestCase):
 
         expected_scale_shape = (55, 34)
 
-        self.model.weight_loader.
-            'wq': expected_scale_shape
+        self.model.weight_loader.scale_shape_map_for_random_weight_loading = {
+            'attention.wq': expected_scale_shape
         }
 
         mock_iter_graph.return_value = [
@@ -604,9 +589,7 @@ class TestLoadRandomWeightsIntoQwixAbstractModel(unittest.TestCase):
         mock_randint.assert_not_called()
         mock_normal.assert_called_once()
 
-    @patch(
-        "tpu_inference.models.jax.utils.quantization.quantization_utils.logger.warning"
-    )
+    @patch("tpu_inference.models.jax.utils.qwix.qwix_utils.logger.warning")
     @patch("jax.make_array_from_callback")
     def test_get_random_sharded_array_sharding_fallback(
            self, mock_make_array, mock_logger_warning):
@@ -648,7 +631,7 @@ class TestManualQwixQuantization(unittest.TestCase):
         self.calibration_method = 'max'
 
     @patch(
-        'tpu_inference.models.jax.utils.quantization.quantization_utils.ptq.create_quantized_param'
+        'tpu_inference.models.jax.utils.qwix.qwix_utils.ptq.create_quantized_param'
     )
     def test_manually_quantize_qwix_weight(self, mock_create_param):
         """Test that manually_quantize_qwix_weight calls ptq.create_quantized_param correctly."""
@@ -672,9 +655,7 @@ class TestManualQwixQuantization(unittest.TestCase):
         self.assertEqual(passed_how_to_quantize.calibration_method,
                          self.calibration_method)
 
-    @patch(
-        'tpu_inference.models.jax.utils.quantization.quantization_utils.ptq.quantize_act'
-    )
+    @patch('tpu_inference.models.jax.utils.qwix.qwix_utils.ptq.quantize_act')
     @patch('qwix.pallas.get_current_rule')
     def test_manually_quantize_qwix_activation(self, mock_get_rule,
                                                mock_quantize_act):
@@ -832,5 +813,157 @@ class TestGetQuantDtypeFromQwixConfig(unittest.TestCase):
         self.assertIsNone(quant_dtype)
 
 
+class TestGetDefaultQwixQuantizationConfig(unittest.TestCase):
+    """Tests for the get_default_qwix_quantization_config function."""
+
+    def setUp(self):
+        # Mocking the default configs that the function expects to find in the module
+        self.mock_deepseek_config = {
+            "qwix": {
+                "rules": [{
+                    "module_path": ".*",
+                    "tile_size": 0
+                }]
+            }
+        }
+        self.mock_llama_config = {"qwix": {"rules": [{"name": "llama_rule"}]}}
+        self.mock_gpt_oss_config = {"qwix": {"rules": [{"name": "gpt_rule"}]}}
+
+        # Patch the constants in the module where the function resides
+        self.patchers = [
+            patch(
+                "tpu_inference.models.jax.utils.qwix.qwix_utils.DEFAULT_DEEPSEEK_FP4_MLP_MOE_FP8_ATTN_CONFIG",
+                self.mock_deepseek_config),
+            patch(
+                "tpu_inference.models.jax.utils.qwix.qwix_utils.DEFAULT_LLAMA4_FP8_CONFIG",
+                self.mock_llama_config),
+            patch(
+                "tpu_inference.models.jax.utils.qwix.qwix_utils.DEFAULT_GPT_OSS_FP4_CONFIG",
+                self.mock_gpt_oss_config),
+            patch("tpu_inference.models.jax.utils.qwix.qwix_utils.logger",
+                  MagicMock())
+        ]
+        for p in self.patchers:
+            p.start()
+
+    def tearDown(self):
+        for p in self.patchers:
+            p.stop()
+
+    def test_skip_quantization_returns_none(self):
+        """Test that skip_quantization=True returns None immediately."""
+        result = quantize_qwix.get_default_qwix_quantization_config(
+            MagicMock(), True)
+        self.assertIsNone(result)
+
+    def test_unsupported_model_returns_none(self):
+        """Test that an unknown model type returns None."""
+        hf_config = MagicMock()
+        hf_config.model_type = "unknown_model"
+        result = quantize_qwix.get_default_qwix_quantization_config(
+            hf_config, False)
+        self.assertIsNone(result)
+
+    def test_deepseek_v3_success(self):
+        """Test DeepSeek V3 config with valid weight_block_size."""
+        hf_config = MagicMock()
+        hf_config.model_type = "DeepSeek_V3"
+        hf_config.quantization_config = {
+            "quant_method": "fp8",
+            "weight_block_size": [1, 128]
+        }
+
+        result = quantize_qwix.get_default_qwix_quantization_config(
+            hf_config, False)
+
+        # Check if tile_size was updated from 0 to 128
+        self.assertEqual(result["qwix"]["rules"][0]["tile_size"], 128)
+        # Ensure it's a deep copy (original mock shouldn't change)
+        self.assertEqual(
+            self.mock_deepseek_config["qwix"]["rules"][0]["tile_size"], 0)
+
+    def test_deepseek_v3_invalid_block_size(self):
+        """Test DeepSeek V3 raises ValueError on invalid block size format."""
+        hf_config = MagicMock()
+        hf_config.model_type = "deepseek_v3"
+        hf_config.quantization_config = {
+            "quant_method": "fp8",
+            "weight_block_size": [128]
+        }
+
+        with self.assertRaisesRegex(ValueError, "Invalid weight_block_size"):
+            quantize_qwix.get_default_qwix_quantization_config(
+                hf_config, False)
+
+    def test_deepseek_v3_invalid_block_size_2d_subchannel(self):
+        """Test DeepSeek V3 raises ValueError on invalid block size format."""
+        hf_config = MagicMock()
+        hf_config.model_type = "deepseek_v3"
+        hf_config.quantization_config = {
+            "quant_method": "fp8",
+            "weight_block_size": [512, 512]
+        }
+
+        with self.assertRaisesRegex(AssertionError,
+                                    "Expected first dimension to be 1"):
+            quantize_qwix.get_default_qwix_quantization_config(
+                hf_config, False)
+
+    def test_deepseek_v3_no_weight_block_size(self):
+        """Test DeepSeek V3 config with valid weight_block_size."""
+        hf_config = MagicMock()
+        hf_config.model_type = "DeepSeek_V3"
+        hf_config.quantization_config = {
+            "quant_method": "fp8",
+        }
+
+        with self.assertRaisesRegex(
+                AssertionError,
+                "Expected weight_block_size in quantization_config"):
+
+            quantize_qwix.get_default_qwix_quantization_config(
+                hf_config, False)
+
+    def test_deepseek_v3_tile_size_assertion(self):
+        """Test DeepSeek V3 raises AssertionError if tile_size is <= 1."""
+        hf_config = MagicMock()
+        hf_config.model_type = "deepseek_v3"
+        hf_config.quantization_config = {
+            "quant_method": "fp8",
+            "weight_block_size": [1, 1]
+        }
+
+        with self.assertRaises(AssertionError):
+            quantize_qwix.get_default_qwix_quantization_config(
+                hf_config, False)
+
+    def test_llama4_success(self):
+        """Test Llama 4 default config path."""
+        hf_config = MagicMock()
+        hf_config.model_type = "llama4"
+        hf_config.quantization_config = {"quant_method": "compressed-tensors"}
+
+        result = quantize_qwix.get_default_qwix_quantization_config(
+            hf_config, False)
+        self.assertEqual(result, self.mock_llama_config)
+
+    def test_gpt_oss_success(self):
+        """Test GPT-OSS default config path."""
+        hf_config = MagicMock()
+        hf_config.model_type = "gpt_oss"
+        hf_config.quantization_config = {"quant_method": "mxfp4"}
+
+        result = quantize_qwix.get_default_qwix_quantization_config(
+            hf_config, False)
+        self.assertEqual(result, self.mock_gpt_oss_config)
+
+    def test_missing_attributes_handled(self):
+        """Test that function handles hf_config objects missing model_type safely."""
+        hf_config = object()  # No attributes
+        result = quantize_qwix.get_default_qwix_quantization_config(
+            hf_config, False)
+        self.assertIsNone(result)
+
+
 if __name__ == '__main__':
     unittest.main()
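
The dominant change above is mechanical: every `quantization.quantization_utils` path becomes `qwix.qwix_utils`, matching the module rename in the file list. The new `TestGetDefaultQwixQuantizationConfig` class also pins down a contract for `get_default_qwix_quantization_config`. The following is a behavioral sketch of the DeepSeek branch reconstructed purely from those tests; it is not the tpu_inference implementation, and `deepseek_default` is a stand-in parameter for the patched `DEFAULT_DEEPSEEK_FP4_MLP_MOE_FP8_ATTN_CONFIG` constant:

```python
# Behavioral sketch inferred from TestGetDefaultQwixQuantizationConfig above.
# NOT the tpu_inference implementation: only the DeepSeek branch is modeled,
# and `deepseek_default` stands in for the patched default-config constant.
import copy

def get_default_qwix_quantization_config_sketch(hf_config, skip_quantization,
                                                deepseek_default):
    if skip_quantization:
        return None  # test_skip_quantization_returns_none
    model_type = str(getattr(hf_config, "model_type", "")).lower()
    if model_type != "deepseek_v3":  # matches both "DeepSeek_V3" spellings
        return None  # unknown model types (and bare object()) fall through
    quant_cfg = hf_config.quantization_config
    assert "weight_block_size" in quant_cfg, (
        "Expected weight_block_size in quantization_config")
    block = quant_cfg["weight_block_size"]
    if len(block) != 2:
        raise ValueError(f"Invalid weight_block_size: {block}")
    assert block[0] == 1, "Expected first dimension to be 1"
    tile_size = block[1]
    assert tile_size > 1  # [1, 1] is rejected per the tile_size test
    config = copy.deepcopy(deepseek_default)  # deep copy: default unchanged
    config["qwix"]["rules"][0]["tile_size"] = tile_size
    return config
```

Per the remaining tests, `llama4` with `quant_method: compressed-tensors` and `gpt_oss` with `mxfp4` return their respective default-config constants unchanged.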
--- /dev/null
+++ tests/layers/jax/test_rope.py
@@ -0,0 +1,93 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import jax
+from jax import numpy as jnp
+from jax._src import test_util as jtu
+from jax.sharding import Mesh
+
+from tpu_inference.layers.jax.rope import (DeepseekScalingRotaryEmbedding,
+                                           RotaryEmbedding)
+
+
+class RotaryEmbeddingTest(jtu.JaxTestCase):
+
+    def test_apply_rope(self):
+        head_dim = 2
+        rope_theta = 10000
+        original_max_position_embeddings = 2
+        rope = RotaryEmbedding(
+            rotary_dim=head_dim,
+            rope_theta=rope_theta,
+            original_max_position_embeddings=original_max_position_embeddings,
+            dtype=jnp.float32)
+        rope.initialize_cache()
+        self.assertTrue(
+            rope.sin_cos_cache.shape == (original_max_position_embeddings,
+                                         head_dim))
+        expected_sin_cos = jnp.array([[1, 0], [0.5403023, 0.841471]],
+                                     dtype=jnp.float32)
+        self.assertArraysAllClose(rope.sin_cos_cache, expected_sin_cos)
+
+        num_tokens = 2
+        num_heads = 1
+        positions = jnp.arange(num_tokens)
+        x = jnp.ones((num_tokens, num_heads, head_dim))
+        x_rope = rope.apply_rope(positions, x)
+        expected_x_rope = jnp.array([[[1, 1]], [[-0.30116874, 1.3817732]]],
+                                    dtype=jnp.float32)
+        self.assertTrue(x_rope.shape == x.shape)
+        self.assertArraysAllClose(x_rope, expected_x_rope)
+
+
+class DeepseekScalingRotaryEmbeddingTest(jtu.JaxTestCase):
+
+    def test_apply_rope(self):
+        head_dim = 2
+        rope_theta = 10000
+        original_max_position_embeddings = 1
+        scaling_factor = 2
+        devices = jax.devices()
+        mesh = Mesh(devices, ('data', ))
+
+        rope = DeepseekScalingRotaryEmbedding(
+            rotary_dim=head_dim,
+            rope_theta=rope_theta,
+            original_max_position_embeddings=original_max_position_embeddings,
+            scaling_factor=scaling_factor,
+            dtype=jnp.float32)
+        rope.initialize_cache(mesh)
+        expected_padded_dim = 128
+        self.assertTrue(
+            rope.sin_cos_cache.shape == (scaling_factor *
+                                         original_max_position_embeddings,
+                                         expected_padded_dim))
+
+        valid_cache_slice = rope.sin_cos_cache[:, :head_dim]
+
+        expected_sin_cos = jnp.array([[1.0693147, 0], [0.5777532, 0.8997973]],
+                                     dtype=jnp.float32)
+
+        self.assertArraysAllClose(valid_cache_slice, expected_sin_cos)
+
+        num_tokens = 2
+        num_heads = 1
+        positions = jnp.arange(num_tokens)
+        x = jnp.ones((num_tokens, num_heads, head_dim))
+        x_rope = rope.apply_rope(positions, x)
+        expected_x_rope = jnp.array(
+            [[[1.0693147, 1.0693147]], [[-0.32204413, 1.4775505]]],
+            dtype=jnp.float32)
+        self.assertTrue(x_rope.shape == x.shape)
+        self.assertArraysAllClose(x_rope, expected_x_rope)
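
The expected constants in test_rope.py above follow from standard RoPE arithmetic; with `head_dim=2` there is a single frequency pair and `inv_freq = 10000**0 = 1.0`. A quick verification sketch (NumPy here is only for the arithmetic; this is not tpu_inference code):

```python
import numpy as np

# head_dim=2 gives one frequency pair: inv_freq = 10000**(-0/2) = 1.0, so the
# cache rows are simply (cos(p), sin(p)) for positions p = 0, 1.
print(np.cos([0.0, 1.0]))  # [1.         0.5403023 ]
print(np.sin([0.0, 1.0]))  # [0.         0.84147098]

# Rotating x = (1, 1) at position 1 reproduces expected_x_rope:
c, s = np.cos(1.0), np.sin(1.0)
print(c - s, c + s)  # -0.30116868  1.38177329

# The Deepseek variant additionally scales the cache by a YaRN-style factor
# mscale = 1 + 0.1 * ln(scaling_factor); with scaling_factor = 2:
mscale = 1.0 + 0.1 * np.log(2.0)
print(mscale)                  # 1.0693147  -> first cache row [1.0693147, 0]
print(mscale * c, mscale * s)  # 0.5777532  0.8997973 -> second cache row
```

The 128-wide cache in the Deepseek test suggests the last dimension is padded, presumably to the TPU lane width of 128; the test accordingly compares only the first `head_dim` columns.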
--- /dev/null
+++ tests/layers/jax/test_sharding.py
@@ -0,0 +1,159 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+from unittest.mock import MagicMock
+
+import jax
+
+from tpu_inference.layers.common.sharding import (Sharding, ShardingConfig,
+                                                  ShardingRulesConfig,
+                                                  ShardingStrategy)
+
+
+class TestSharding(unittest.TestCase):
+    """Unit test suite for the sharding configuration logic."""
+
+    def setUp(self):
+        """Sets up the testing environment before each test."""
+
+        self.mock_devices = [MagicMock(coords=i) for i in range(8)]
+        self.original_jax_devices = jax.devices
+        jax.devices = lambda: self.mock_devices
+
+    def tearDown(self):
+        """Restores the original jax.devices function after tests."""
+        jax.devices = self.original_jax_devices
+
+    def test_sharding_strategy_init(self):
+        """Tests the initialization of the ShardingStrategy."""
+        strategy = ShardingStrategy(
+            tensor_parallelism=2,
+            expert_parallelism=4,
+            data_parallelism=1,
+            sequence_parallelism=1,
+        )
+        self.assertEqual(strategy.tensor_parallelism, 2)
+        self.assertEqual(strategy.expert_parallelism, 4)
+
+    def test_sharding_config_init(self):
+        """Tests the initialization of ShardingConfig."""
+        config = ShardingConfig()
+        self.assertIsInstance(config.prefill_rules, ShardingRulesConfig)
+        self.assertIsInstance(config.generate_rules, ShardingRulesConfig)
+
+        custom_rules = ShardingRulesConfig(activation_ffw_td=("model", None))
+        config_with_rules = ShardingConfig(prefill_rules=custom_rules)
+        self.assertEqual(config_with_rules.prefill_rules.activation_ffw_td,
+                         ("model", None))
+
+    def test_apply_overrides(self):
+        """Tests the _apply_overrides method for valid and invalid keys."""
+        sharding = Sharding(
+            prefill_rules={},
+            generate_rules={},
+        )
+        config_obj = ShardingRulesConfig()
+
+        valid_overrides = {"activation_ffw_td": ("model", None)}
+        sharding._apply_overrides(config_obj, valid_overrides)
+        self.assertEqual(config_obj.activation_ffw_td, ("model", None))
+
+        invalid_overrides = {"non_existent_attribute": (None, "model")}
+        with self.assertRaises(AttributeError):
+            sharding._apply_overrides(config_obj, invalid_overrides)
+
+    def test_default_sharding_config(self):
+        """Tests that default sharding rules are created correctly."""
+        sharding = Sharding(
+            prefill_rules={},
+            generate_rules={},
+        )
+
+        sharding_cfg = sharding.get_sharding_cfg()
+        generate_rules = sharding_cfg.generate_rules
+
+        self.assertEqual(generate_rules.ffw_weight_df, (None, "model"))
+        self.assertEqual(generate_rules.moe_router_de, (None, "model"))
+        self.assertEqual(generate_rules.attn_q_weight_dnh,
+                         (None, "model", None))
+
+    def test_sharding_init_with_overrides(self):
+        """Tests Sharding initialization with programmatic overrides."""
+        generate_overrides = {"logits_tv": ("data", "model")}
+
+        sharding = Sharding(
+            generate_rules=generate_overrides,
+            prefill_rules={},
+        )
+
+        sharding_cfg = sharding.get_sharding_cfg()
+        self.assertNotEqual(sharding_cfg.generate_rules.logits_tv,
+                            (None, "model"))
+        self.assertEqual(sharding_cfg.generate_rules.logits_tv,
+                         ("data", "model"))
+
+    def test_get_overrides_from_vllm_config(self):
+        """Tests fetching sharding overrides from a mock VllmConfig."""
+
+        mock_vllm_config_prefill = MagicMock()
+        mock_vllm_config_prefill.additional_config = {
+            "sharding": {
+                "logical_rules": {
+                    "all": {
+                        "norm_scale": ("model", )
+                    },
+                    "prefill": {
+                        "activation_ffw_td": ("data", "model")
+                    },
+                }
+            }
+        }
+        sharding_prefill = Sharding(
+            vllm_config=mock_vllm_config_prefill,
+            prefill_rules={},
+            generate_rules={},
+        )
+        prefill_overrides = sharding_prefill._get_overrides("prefill")
+
+        self.assertEqual(prefill_overrides["norm_scale"], ("model", ))
+        self.assertEqual(prefill_overrides["activation_ffw_td"],
+                         ("data", "model"))
+
+        mock_vllm_config_generate = MagicMock()
+        mock_vllm_config_generate.additional_config = {
+            "sharding": {
+                "logical_rules": {
+                    "all": {
+                        "norm_scale": ("model", )
+                    },
+                    "prefill": {
+                        "activation_ffw_td": ("data", "model")
+                    },
+                }
+            }
+        }
+        sharding_generate = Sharding(
+            vllm_config=mock_vllm_config_generate,
+            prefill_rules={},
+            generate_rules={},
+        )
+        generate_overrides = sharding_generate._get_overrides("generate")
+
+        self.assertEqual(generate_overrides["norm_scale"], ("model", ))
+        self.assertNotIn("activation_ffw_td", generate_overrides)
+
+
+if __name__ == "__main__":
+    unittest.main()
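
`test_get_overrides_from_vllm_config` above implies a simple merge order for `logical_rules`: the `all` section applies to every mode, and the mode-specific section is layered on top. A sketch of that merge, assumed from the test rather than taken from the actual `Sharding._get_overrides`:

```python
# Merge semantics assumed from the test above; NOT the actual
# Sharding._get_overrides implementation.
def get_overrides_sketch(additional_config: dict, mode: str) -> dict:
    rules = additional_config.get("sharding", {}).get("logical_rules", {})
    merged = dict(rules.get("all", {}))   # rules shared by every mode
    merged.update(rules.get(mode, {}))    # mode-specific rules win on conflict
    return merged

cfg = {"sharding": {"logical_rules": {
    "all": {"norm_scale": ("model", )},
    "prefill": {"activation_ffw_td": ("data", "model")},
}}}
assert get_overrides_sketch(cfg, "prefill") == {
    "norm_scale": ("model", ),
    "activation_ffw_td": ("data", "model"),
}
assert get_overrides_sketch(cfg, "generate") == {"norm_scale": ("model", )}
```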