tpu-inference 0.12.0.dev20251213__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +14 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +14 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +14 -0
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +180 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +406 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +320 -0
- tests/layers/vllm/test_unquantized.py +662 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +25 -8
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +14 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +1 -43
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +14 -9
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +20 -3
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +171 -163
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +20 -26
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +112 -69
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +85 -65
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +374 -194
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +13 -0
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +26 -19
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/fused_moe_gmm.py +506 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +282 -0
- tpu_inference/layers/common/sharding.py +22 -3
- tpu_inference/layers/common/utils.py +94 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +52 -27
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +19 -6
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +100 -455
- tpu_inference/layers/vllm/linear.py +64 -0
- tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
- tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
- tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
- tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
- tpu_inference/layers/vllm/quantization/__init__.py +19 -3
- tpu_inference/layers/vllm/quantization/awq.py +96 -82
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +19 -5
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +119 -132
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
- tpu_inference/layers/vllm/quantization/{common.py → configs.py} +38 -26
- tpu_inference/layers/vllm/quantization/fp8.py +119 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +133 -220
- tpu_inference/layers/vllm/quantization/unquantized.py +154 -253
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +37 -16
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +113 -124
- tpu_inference/models/jax/gpt_oss.py +23 -7
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +14 -0
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +18 -4
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +85 -24
- tpu_inference/models/jax/utils/weight_utils.py +32 -1
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +22 -4
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +27 -29
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +69 -35
- tpu_inference/runner/kv_cache.py +14 -0
- tpu_inference/runner/kv_cache_manager.py +15 -2
- tpu_inference/runner/lora_utils.py +16 -1
- tpu_inference/runner/multimodal_manager.py +16 -2
- tpu_inference/runner/persistent_batch_manager.py +14 -0
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +14 -0
- tpu_inference/runner/tpu_runner.py +30 -10
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +13 -0
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +31 -30
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +23 -7
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +1 -1
- tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
- tpu_inference/layers/vllm/linear_common.py +0 -208
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.12.0.dev20251213.dist-info/RECORD +0 -175
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
tests/{test_quantization.py → layers/jax/test_qwix.py}

@@ -11,9 +11,9 @@ from jax.sharding import Mesh, NamedSharding
 from jax.sharding import PartitionSpec as P
 from qwix._src.providers import ptq
 
-import tpu_inference.models.jax.utils.quantization.quantization_utils as quantize_qwix  # noqa: E402
+import tpu_inference.models.jax.utils.qwix.qwix_utils as quantize_qwix  # noqa: E402
 from tpu_inference.models.common.model_loader import apply_qwix_quantization
-from tpu_inference.models.jax.utils.quantization.quantization_utils import (
+from tpu_inference.models.jax.utils.qwix.qwix_utils import (
     DEFAULT_MAX_NUM_BLOCKS_PER_REQ, DEFAULT_MAX_NUM_SEQS_FOR_MODEL_INPUTS,
     DEFAULT_NUM_TOKENS_FOR_MODEL_INPUTS)
 
@@ -29,8 +29,7 @@ module_mocks = {
     'vllm.config': MagicMock(),
     'tpu_inference': MagicMock(),
     'tpu_inference.logger': MagicMock(init_logger=lambda name: MagicMock()),
-    'tpu_inference.models.jax.utils.quantization.quantization_utils':
-    MagicMock(),
+    'tpu_inference.models.jax.utils.qwix.qwix_utils': MagicMock(),
 }
 
 
@@ -136,16 +135,16 @@ class TestQwixQuantizeNnxModel(unittest.TestCase):
         self.model.vllm_config.sharding_config.total_dp_size = 1
 
         with patch(
-                "tpu_inference.models.jax.utils.quantization.quantization_utils.init_logger",
+                "tpu_inference.models.jax.utils.qwix.qwix_utils.init_logger",
                 return_value=MagicMock()
         ), patch(
                 "tpu_inference.utils.hbm_usage_gb",
                 return_value=[(0.0, 0.0), (0.0, 0.0)]
         ), patch(
-                "tpu_inference.models.jax.utils.quantization.quantization_utils.create_kv_caches",
+                "tpu_inference.models.jax.utils.qwix.qwix_utils.create_kv_caches",
                 return_value=self.mock_kv_caches
         ), patch(
-                "tpu_inference.models.jax.utils.quantization.quantization_utils.quantization_config_file_path_to_dict",
+                "tpu_inference.models.jax.utils.qwix.qwix_utils.quantization_config_file_path_to_dict",
                 return_value=self.qwix_config):
             returned_model = quantize_qwix.qwix_quantize_nnx_model(
                 model=self.model,
@@ -320,10 +319,9 @@ class TestApplyQwixQuantizationLogic(unittest.TestCase):
         self.assertIs(result2, self.mock_model)
 
     @patch(
-        'tpu_inference.models.jax.utils.quantization.quantization_utils.qwix_quantize_nnx_model'
+        'tpu_inference.models.jax.utils.qwix.qwix_utils.qwix_quantize_nnx_model'
     )
-    @patch(
-        'tpu_inference.models.jax.utils.quantization.quantization_utils.utils')
+    @patch('tpu_inference.models.jax.utils.qwix.qwix_utils.utils')
     def test_apply_to_abstract_model(self, mock_utils, mock_quantize_func):
         """Test quantization is correctly applied to an abstract model factory."""
         mock_utils.get_padded_num_heads.return_value = 8
@@ -360,10 +358,9 @@ class TestApplyQwixQuantizationLogic(unittest.TestCase):
         self.assertIs(result_model, quantized_model)
 
     @patch(
-        'tpu_inference.models.jax.utils.quantization.quantization_utils.qwix_quantize_nnx_model'
+        'tpu_inference.models.jax.utils.qwix.qwix_utils.qwix_quantize_nnx_model'
     )
-    @patch(
-        'tpu_inference.models.jax.utils.quantization.quantization_utils.utils')
+    @patch('tpu_inference.models.jax.utils.qwix.qwix_utils.utils')
     def test_apply_to_abstract_model_with_initialize_cache(
             self, mock_utils, mock_quantize_func):
         """Test abstract model quantization with 'initialize_cache' method."""
@@ -464,15 +461,13 @@ class TestLoadRandomWeightsIntoQwixAbstractModel(unittest.TestCase):
         # Mock model structure
         self.model = MagicMock(spec=['weight_loader', 'initialize_cache'])
         self.model.weight_loader = MagicMock(
-            spec=['scale_dtype', '
+            spec=['scale_dtype', 'scale_shape_map_for_random_weight_loading'])
         self.model.weight_loader.scale_dtype = jnp.float16
-        self.model.weight_loader.
+        self.model.weight_loader.scale_shape_map_for_random_weight_loading = {}
 
+    @patch('tpu_inference.models.jax.utils.qwix.qwix_utils.nnx.iter_graph')
     @patch(
-        'tpu_inference.models.jax.utils.quantization.quantization_utils.nnx.iter_graph'
-    )
-    @patch(
-        'tpu_inference.models.jax.utils.quantization.quantization_utils.get_random_sharded_array'
+        'tpu_inference.models.jax.utils.qwix.qwix_utils.get_random_sharded_array'
     )
     def test_successful_initialization(self, mock_get_random_array,
                                        mock_iter_graph):
@@ -485,6 +480,10 @@ class TestLoadRandomWeightsIntoQwixAbstractModel(unittest.TestCase):
         mock_random_array = jax.numpy.ones(1)
         mock_get_random_array.return_value = mock_random_array
 
+        self.model.weight_loader.scale_shape_map_for_random_weight_loading = {
+            'attention.wq': (1, 1)
+        }
+
         mock_iter_graph.return_value = [
             (('layers', '0', 'attention', 'wq', 'kernel'), mock_weight_param),
             (('layers', '0', 'attention', 'wq', 'array', 'scale'),
@@ -512,9 +511,7 @@ class TestLoadRandomWeightsIntoQwixAbstractModel(unittest.TestCase):
         quantize_qwix.load_random_weights_into_qwix_abstract_model(
             self.rng, self.model, self.mesh, invalid_config)
 
-    @patch(
-        'tpu_inference.models.jax.utils.quantization.quantization_utils.nnx.iter_graph'
-    )
+    @patch('tpu_inference.models.jax.utils.qwix.qwix_utils.nnx.iter_graph')
     def test_param_shape_setting_no_scale_map(self, mock_iter_graph):
         """Test correct scale shape calculation when not in the map."""
         old_weight_param_val = jnp.empty((128, 64))
@@ -528,26 +525,11 @@ class TestLoadRandomWeightsIntoQwixAbstractModel(unittest.TestCase):
             mock_scale_var),
         ]
 
-
-
-
-        new_weight_param_val = mock_weight_param.value
-        new_scale_var_val = mock_scale_var.value
-
-        expected_scale_shape = (128 // 64, 64 // 64)
-        actual_scale_shape = new_scale_var_val.shape
-
-        expected_weight_shape = (128, 64)
-        actual_weight_shape = new_weight_param_val.shape
-
-        self.assertEqual(expected_scale_shape, actual_scale_shape)
-        self.assertEqual(expected_weight_shape, actual_weight_shape)
-        self.assertNotEqual(old_scale_var_val.shape, new_scale_var_val.shape)
-        assert jnp.not_equal(old_weight_param_val, new_weight_param_val).all()
+        with self.assertRaises(ValueError):
+            quantize_qwix.load_random_weights_into_qwix_abstract_model(
+                self.rng, self.model, self.mesh, self.quantization_config)
 
-    @patch(
-        'tpu_inference.models.jax.utils.quantization.quantization_utils.nnx.iter_graph'
-    )
+    @patch('tpu_inference.models.jax.utils.qwix.qwix_utils.nnx.iter_graph')
     def test_param_shape_setting_with_scale_map(self, mock_iter_graph):
         """Test correct scale shape calculation when in the map."""
         old_weight_param_val = jnp.empty((128, 64))
@@ -557,8 +539,8 @@ class TestLoadRandomWeightsIntoQwixAbstractModel(unittest.TestCase):
 
         expected_scale_shape = (55, 34)
 
-        self.model.weight_loader.
-            'wq': expected_scale_shape
+        self.model.weight_loader.scale_shape_map_for_random_weight_loading = {
+            'attention.wq': expected_scale_shape
         }
 
         mock_iter_graph.return_value = [
@@ -607,9 +589,7 @@ class TestLoadRandomWeightsIntoQwixAbstractModel(unittest.TestCase):
         mock_randint.assert_not_called()
         mock_normal.assert_called_once()
 
-    @patch(
-        "tpu_inference.models.jax.utils.quantization.quantization_utils.logger.warning"
-    )
+    @patch("tpu_inference.models.jax.utils.qwix.qwix_utils.logger.warning")
     @patch("jax.make_array_from_callback")
     def test_get_random_sharded_array_sharding_fallback(
             self, mock_make_array, mock_logger_warning):
@@ -651,7 +631,7 @@ class TestManualQwixQuantization(unittest.TestCase):
         self.calibration_method = 'max'
 
     @patch(
-        'tpu_inference.models.jax.utils.quantization.quantization_utils.ptq.create_quantized_param'
+        'tpu_inference.models.jax.utils.qwix.qwix_utils.ptq.create_quantized_param'
     )
     def test_manually_quantize_qwix_weight(self, mock_create_param):
         """Test that manually_quantize_qwix_weight calls ptq.create_quantized_param correctly."""
@@ -675,9 +655,7 @@ class TestManualQwixQuantization(unittest.TestCase):
         self.assertEqual(passed_how_to_quantize.calibration_method,
                          self.calibration_method)
 
-    @patch(
-        'tpu_inference.models.jax.utils.quantization.quantization_utils.ptq.quantize_act'
-    )
+    @patch('tpu_inference.models.jax.utils.qwix.qwix_utils.ptq.quantize_act')
     @patch('qwix.pallas.get_current_rule')
     def test_manually_quantize_qwix_activation(self, mock_get_rule,
                                                mock_quantize_act):
@@ -835,5 +813,157 @@ class TestGetQuantDtypeFromQwixConfig(unittest.TestCase):
         self.assertIsNone(quant_dtype)
 
 
+class TestGetDefaultQwixQuantizationConfig(unittest.TestCase):
+    """Tests for the get_default_qwix_quantization_config function."""
+
+    def setUp(self):
+        # Mocking the default configs that the function expects to find in the module
+        self.mock_deepseek_config = {
+            "qwix": {
+                "rules": [{
+                    "module_path": ".*",
+                    "tile_size": 0
+                }]
+            }
+        }
+        self.mock_llama_config = {"qwix": {"rules": [{"name": "llama_rule"}]}}
+        self.mock_gpt_oss_config = {"qwix": {"rules": [{"name": "gpt_rule"}]}}
+
+        # Patch the constants in the module where the function resides
+        self.patchers = [
+            patch(
+                "tpu_inference.models.jax.utils.qwix.qwix_utils.DEFAULT_DEEPSEEK_FP4_MLP_MOE_FP8_ATTN_CONFIG",
+                self.mock_deepseek_config),
+            patch(
+                "tpu_inference.models.jax.utils.qwix.qwix_utils.DEFAULT_LLAMA4_FP8_CONFIG",
+                self.mock_llama_config),
+            patch(
+                "tpu_inference.models.jax.utils.qwix.qwix_utils.DEFAULT_GPT_OSS_FP4_CONFIG",
+                self.mock_gpt_oss_config),
+            patch("tpu_inference.models.jax.utils.qwix.qwix_utils.logger",
+                  MagicMock())
+        ]
+        for p in self.patchers:
+            p.start()
+
+    def tearDown(self):
+        for p in self.patchers:
+            p.stop()
+
+    def test_skip_quantization_returns_none(self):
+        """Test that skip_quantization=True returns None immediately."""
+        result = quantize_qwix.get_default_qwix_quantization_config(
+            MagicMock(), True)
+        self.assertIsNone(result)
+
+    def test_unsupported_model_returns_none(self):
+        """Test that an unknown model type returns None."""
+        hf_config = MagicMock()
+        hf_config.model_type = "unknown_model"
+        result = quantize_qwix.get_default_qwix_quantization_config(
+            hf_config, False)
+        self.assertIsNone(result)
+
+    def test_deepseek_v3_success(self):
+        """Test DeepSeek V3 config with valid weight_block_size."""
+        hf_config = MagicMock()
+        hf_config.model_type = "DeepSeek_V3"
+        hf_config.quantization_config = {
+            "quant_method": "fp8",
+            "weight_block_size": [1, 128]
+        }
+
+        result = quantize_qwix.get_default_qwix_quantization_config(
+            hf_config, False)
+
+        # Check if tile_size was updated from 0 to 128
+        self.assertEqual(result["qwix"]["rules"][0]["tile_size"], 128)
+        # Ensure it's a deep copy (original mock shouldn't change)
+        self.assertEqual(
+            self.mock_deepseek_config["qwix"]["rules"][0]["tile_size"], 0)
+
+    def test_deepseek_v3_invalid_block_size(self):
+        """Test DeepSeek V3 raises ValueError on invalid block size format."""
+        hf_config = MagicMock()
+        hf_config.model_type = "deepseek_v3"
+        hf_config.quantization_config = {
+            "quant_method": "fp8",
+            "weight_block_size": [128]
+        }
+
+        with self.assertRaisesRegex(ValueError, "Invalid weight_block_size"):
+            quantize_qwix.get_default_qwix_quantization_config(
+                hf_config, False)
+
+    def test_deepseek_v3_invalid_block_size_2d_subchannel(self):
+        """Test DeepSeek V3 raises AssertionError on a 2-D sub-channel block size."""
+        hf_config = MagicMock()
+        hf_config.model_type = "deepseek_v3"
+        hf_config.quantization_config = {
+            "quant_method": "fp8",
+            "weight_block_size": [512, 512]
+        }
+
+        with self.assertRaisesRegex(AssertionError,
+                                    "Expected first dimension to be 1"):
+            quantize_qwix.get_default_qwix_quantization_config(
+                hf_config, False)
+
+    def test_deepseek_v3_no_weight_block_size(self):
+        """Test DeepSeek V3 raises AssertionError when weight_block_size is missing."""
+        hf_config = MagicMock()
+        hf_config.model_type = "DeepSeek_V3"
+        hf_config.quantization_config = {
+            "quant_method": "fp8",
+        }
+
+        with self.assertRaisesRegex(
+                AssertionError,
+                "Expected weight_block_size in quantization_config"):
+
+            quantize_qwix.get_default_qwix_quantization_config(
+                hf_config, False)
+
+    def test_deepseek_v3_tile_size_assertion(self):
+        """Test DeepSeek V3 raises AssertionError if tile_size is <= 1."""
+        hf_config = MagicMock()
+        hf_config.model_type = "deepseek_v3"
+        hf_config.quantization_config = {
+            "quant_method": "fp8",
+            "weight_block_size": [1, 1]
+        }
+
+        with self.assertRaises(AssertionError):
+            quantize_qwix.get_default_qwix_quantization_config(
+                hf_config, False)
+
+    def test_llama4_success(self):
+        """Test Llama 4 default config path."""
+        hf_config = MagicMock()
+        hf_config.model_type = "llama4"
+        hf_config.quantization_config = {"quant_method": "compressed-tensors"}
+
+        result = quantize_qwix.get_default_qwix_quantization_config(
+            hf_config, False)
+        self.assertEqual(result, self.mock_llama_config)
+
+    def test_gpt_oss_success(self):
+        """Test GPT-OSS default config path."""
+        hf_config = MagicMock()
+        hf_config.model_type = "gpt_oss"
+        hf_config.quantization_config = {"quant_method": "mxfp4"}
+
+        result = quantize_qwix.get_default_qwix_quantization_config(
+            hf_config, False)
+        self.assertEqual(result, self.mock_gpt_oss_config)
+
+    def test_missing_attributes_handled(self):
+        """Test that the function handles hf_config objects missing model_type safely."""
+        hf_config = object()  # No attributes
+        result = quantize_qwix.get_default_qwix_quantization_config(
+            hf_config, False)
+        self.assertIsNone(result)
+
+
 if __name__ == '__main__':
     unittest.main()
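Taken together, the new TestGetDefaultQwixQuantizationConfig assertions determine the selection logic almost completely. Below is a minimal sketch consistent with those assertions. It is reconstructed from the tests, not copied from qwix_utils, and the DEFAULT_* placeholders stand in for the module constants patched in setUp:

import copy

# Placeholder defaults standing in for the module-level constants patched in setUp:
DEFAULT_DEEPSEEK_FP4_MLP_MOE_FP8_ATTN_CONFIG = {
    "qwix": {"rules": [{"module_path": ".*", "tile_size": 0}]}
}
DEFAULT_LLAMA4_FP8_CONFIG = {"qwix": {"rules": [{"name": "llama_rule"}]}}
DEFAULT_GPT_OSS_FP4_CONFIG = {"qwix": {"rules": [{"name": "gpt_rule"}]}}


def get_default_qwix_quantization_config(hf_config, skip_quantization):
    # Sketch only: behavior inferred from the assertions above.
    if skip_quantization:
        return None
    model_type = getattr(hf_config, "model_type", None)
    if not isinstance(model_type, str):
        return None  # covers test_missing_attributes_handled (a bare object())
    model_type = model_type.lower()  # "DeepSeek_V3" and "deepseek_v3" both match
    if model_type == "deepseek_v3":
        quant_config = hf_config.quantization_config
        assert "weight_block_size" in quant_config, (
            "Expected weight_block_size in quantization_config")
        block_size = quant_config["weight_block_size"]
        if len(block_size) != 2:
            raise ValueError(f"Invalid weight_block_size: {block_size}")
        assert block_size[0] == 1, "Expected first dimension to be 1"
        tile_size = block_size[1]
        assert tile_size > 1  # [1, 1] trips this
        config = copy.deepcopy(DEFAULT_DEEPSEEK_FP4_MLP_MOE_FP8_ATTN_CONFIG)
        for rule in config["qwix"]["rules"]:
            rule["tile_size"] = tile_size  # [1, 128] -> tile_size 128
        return config
    if model_type == "llama4":
        return DEFAULT_LLAMA4_FP8_CONFIG
    if model_type == "gpt_oss":
        return DEFAULT_GPT_OSS_FP4_CONFIG
    return None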
tests/layers/jax/test_rope.py (new file)

@@ -0,0 +1,93 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import jax
+from jax import numpy as jnp
+from jax._src import test_util as jtu
+from jax.sharding import Mesh
+
+from tpu_inference.layers.jax.rope import (DeepseekScalingRotaryEmbedding,
+                                           RotaryEmbedding)
+
+
+class RotaryEmbeddingTest(jtu.JaxTestCase):
+
+    def test_apply_rope(self):
+        head_dim = 2
+        rope_theta = 10000
+        original_max_position_embeddings = 2
+        rope = RotaryEmbedding(
+            rotary_dim=head_dim,
+            rope_theta=rope_theta,
+            original_max_position_embeddings=original_max_position_embeddings,
+            dtype=jnp.float32)
+        rope.initialize_cache()
+        self.assertTrue(
+            rope.sin_cos_cache.shape == (original_max_position_embeddings,
+                                         head_dim))
+        expected_sin_cos = jnp.array([[1, 0], [0.5403023, 0.841471]],
+                                     dtype=jnp.float32)
+        self.assertArraysAllClose(rope.sin_cos_cache, expected_sin_cos)
+
+        num_tokens = 2
+        num_heads = 1
+        positions = jnp.arange(num_tokens)
+        x = jnp.ones((num_tokens, num_heads, head_dim))
+        x_rope = rope.apply_rope(positions, x)
+        expected_x_rope = jnp.array([[[1, 1]], [[-0.30116874, 1.3817732]]],
+                                    dtype=jnp.float32)
+        self.assertTrue(x_rope.shape == x.shape)
+        self.assertArraysAllClose(x_rope, expected_x_rope)
+
+
+class DeepseekScalingRotaryEmbeddingTest(jtu.JaxTestCase):
+
+    def test_apply_rope(self):
+        head_dim = 2
+        rope_theta = 10000
+        original_max_position_embeddings = 1
+        scaling_factor = 2
+        devices = jax.devices()
+        mesh = Mesh(devices, ('data', ))
+
+        rope = DeepseekScalingRotaryEmbedding(
+            rotary_dim=head_dim,
+            rope_theta=rope_theta,
+            original_max_position_embeddings=original_max_position_embeddings,
+            scaling_factor=scaling_factor,
+            dtype=jnp.float32)
+        rope.initialize_cache(mesh)
+        expected_padded_dim = 128
+        self.assertTrue(
+            rope.sin_cos_cache.shape == (scaling_factor *
+                                         original_max_position_embeddings,
+                                         expected_padded_dim))
+
+        valid_cache_slice = rope.sin_cos_cache[:, :head_dim]
+
+        expected_sin_cos = jnp.array([[1.0693147, 0], [0.5777532, 0.8997973]],
+                                     dtype=jnp.float32)
+
+        self.assertArraysAllClose(valid_cache_slice, expected_sin_cos)
+
+        num_tokens = 2
+        num_heads = 1
+        positions = jnp.arange(num_tokens)
+        x = jnp.ones((num_tokens, num_heads, head_dim))
+        x_rope = rope.apply_rope(positions, x)
+        expected_x_rope = jnp.array(
+            [[[1.0693147, 1.0693147]], [[-0.32204413, 1.4775505]]],
+            dtype=jnp.float32)
+        self.assertTrue(x_rope.shape == x.shape)
+        self.assertArraysAllClose(x_rope, expected_x_rope)
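Every literal in these two tests can be reproduced by hand. A quick numerical check follows, under two assumptions not stated in the diff itself: the cache stores [cos(pos * f), sin(pos * f)] per position with inverse frequency f = rope_theta**0 = 1.0 for the single rotary pair, and the DeepSeek variant scales the cache by mscale = 0.1 * ln(scaling_factor) + 1.0 (the usual DeepSeek-YaRN attention factor):

import math

# Plain RoPE, position 1, first rotary pair (inv_freq = 10000**0 = 1.0):
cos1, sin1 = math.cos(1.0), math.sin(1.0)    # 0.5403023, 0.8414710
# Rotating x = [1, 1]: [x0*cos - x1*sin, x1*cos + x0*sin]
print(cos1 - sin1, cos1 + sin1)              # -0.3011687, 1.3817732

# DeepSeek scaling: mscale = 0.1 * ln(scaling_factor) + 1.0
mscale = 0.1 * math.log(2) + 1.0             # 1.0693147
print(mscale * cos1, mscale * sin1)                    # 0.5777532, 0.8997973
print(mscale * (cos1 - sin1), mscale * (cos1 + sin1))  # -0.3220441, 1.4775506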
tests/layers/jax/test_sharding.py (new file)

@@ -0,0 +1,159 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+from unittest.mock import MagicMock
+
+import jax
+
+from tpu_inference.layers.common.sharding import (Sharding, ShardingConfig,
+                                                  ShardingRulesConfig,
+                                                  ShardingStrategy)
+
+
+class TestSharding(unittest.TestCase):
+    """Unit test suite for the sharding configuration logic."""
+
+    def setUp(self):
+        """Sets up the testing environment before each test."""
+
+        self.mock_devices = [MagicMock(coords=i) for i in range(8)]
+        self.original_jax_devices = jax.devices
+        jax.devices = lambda: self.mock_devices
+
+    def tearDown(self):
+        """Restores the original jax.devices function after tests."""
+        jax.devices = self.original_jax_devices
+
+    def test_sharding_strategy_init(self):
+        """Tests the initialization of the ShardingStrategy."""
+        strategy = ShardingStrategy(
+            tensor_parallelism=2,
+            expert_parallelism=4,
+            data_parallelism=1,
+            sequence_parallelism=1,
+        )
+        self.assertEqual(strategy.tensor_parallelism, 2)
+        self.assertEqual(strategy.expert_parallelism, 4)
+
+    def test_sharding_config_init(self):
+        """Tests the initialization of ShardingConfig."""
+        config = ShardingConfig()
+        self.assertIsInstance(config.prefill_rules, ShardingRulesConfig)
+        self.assertIsInstance(config.generate_rules, ShardingRulesConfig)
+
+        custom_rules = ShardingRulesConfig(activation_ffw_td=("model", None))
+        config_with_rules = ShardingConfig(prefill_rules=custom_rules)
+        self.assertEqual(config_with_rules.prefill_rules.activation_ffw_td,
+                         ("model", None))
+
+    def test_apply_overrides(self):
+        """Tests the _apply_overrides method for valid and invalid keys."""
+        sharding = Sharding(
+            prefill_rules={},
+            generate_rules={},
+        )
+        config_obj = ShardingRulesConfig()
+
+        valid_overrides = {"activation_ffw_td": ("model", None)}
+        sharding._apply_overrides(config_obj, valid_overrides)
+        self.assertEqual(config_obj.activation_ffw_td, ("model", None))
+
+        invalid_overrides = {"non_existent_attribute": (None, "model")}
+        with self.assertRaises(AttributeError):
+            sharding._apply_overrides(config_obj, invalid_overrides)
+
+    def test_default_sharding_config(self):
+        """Tests that default sharding rules are created correctly."""
+        sharding = Sharding(
+            prefill_rules={},
+            generate_rules={},
+        )
+
+        sharding_cfg = sharding.get_sharding_cfg()
+        generate_rules = sharding_cfg.generate_rules
+
+        self.assertEqual(generate_rules.ffw_weight_df, (None, "model"))
+        self.assertEqual(generate_rules.moe_router_de, (None, "model"))
+        self.assertEqual(generate_rules.attn_q_weight_dnh,
+                         (None, "model", None))
+
+    def test_sharding_init_with_overrides(self):
+        """Tests Sharding initialization with programmatic overrides."""
+        generate_overrides = {"logits_tv": ("data", "model")}
+
+        sharding = Sharding(
+            generate_rules=generate_overrides,
+            prefill_rules={},
+        )
+
+        sharding_cfg = sharding.get_sharding_cfg()
+        self.assertNotEqual(sharding_cfg.generate_rules.logits_tv,
+                            (None, "model"))
+        self.assertEqual(sharding_cfg.generate_rules.logits_tv,
+                         ("data", "model"))
+
+    def test_get_overrides_from_vllm_config(self):
+        """Tests fetching sharding overrides from a mock VllmConfig."""
+
+        mock_vllm_config_prefill = MagicMock()
+        mock_vllm_config_prefill.additional_config = {
+            "sharding": {
+                "logical_rules": {
+                    "all": {
+                        "norm_scale": ("model", )
+                    },
+                    "prefill": {
+                        "activation_ffw_td": ("data", "model")
+                    },
+                }
+            }
+        }
+        sharding_prefill = Sharding(
+            vllm_config=mock_vllm_config_prefill,
+            prefill_rules={},
+            generate_rules={},
+        )
+        prefill_overrides = sharding_prefill._get_overrides("prefill")
+
+        self.assertEqual(prefill_overrides["norm_scale"], ("model", ))
+        self.assertEqual(prefill_overrides["activation_ffw_td"],
+                         ("data", "model"))
+
+        mock_vllm_config_generate = MagicMock()
+        mock_vllm_config_generate.additional_config = {
+            "sharding": {
+                "logical_rules": {
+                    "all": {
+                        "norm_scale": ("model", )
+                    },
+                    "prefill": {
+                        "activation_ffw_td": ("data", "model")
+                    },
+                }
+            }
+        }
+        sharding_generate = Sharding(
+            vllm_config=mock_vllm_config_generate,
+            prefill_rules={},
+            generate_rules={},
+        )
+        generate_overrides = sharding_generate._get_overrides("generate")
+
+        self.assertEqual(generate_overrides["norm_scale"], ("model", ))
+        self.assertNotIn("activation_ffw_td", generate_overrides)
+
+
+if __name__ == "__main__":
+    unittest.main()
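The last test pins down how _get_overrides treats additional_config: rules under "all" apply to every mode, and mode-specific rules are layered on top. A minimal sketch of that merge, inferred from the assertions alone (get_overrides below is a free-standing stand-in, not the real method):

def get_overrides(additional_config, mode):
    # Stand-in for Sharding._get_overrides; merge order inferred from the test.
    logical_rules = additional_config.get("sharding", {}).get("logical_rules", {})
    overrides = dict(logical_rules.get("all", {}))   # rules shared by every mode
    overrides.update(logical_rules.get(mode, {}))    # mode-specific rules win
    return overrides


cfg = {"sharding": {"logical_rules": {
    "all": {"norm_scale": ("model", )},
    "prefill": {"activation_ffw_td": ("data", "model")},
}}}
assert get_overrides(cfg, "prefill") == {
    "norm_scale": ("model", ),
    "activation_ffw_td": ("data", "model"),
}
assert "activation_ffw_td" not in get_overrides(cfg, "generate")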