tpu-inference 0.11.1.dev202512030818__py3-none-any.whl → 0.13.2rc3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +14 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +312 -0
- tests/layers/vllm/test_unquantized.py +651 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +21 -3
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +78 -1
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +1 -43
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +14 -9
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +38 -7
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +17 -0
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +95 -78
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +26 -19
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +28 -5
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +210 -260
- tpu_inference/layers/vllm/linear_common.py +57 -22
- tpu_inference/layers/vllm/quantization/__init__.py +16 -0
- tpu_inference/layers/vllm/quantization/awq.py +15 -1
- tpu_inference/layers/vllm/quantization/common.py +33 -18
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +278 -209
- tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
- tpu_inference/layers/vllm/sharding.py +21 -4
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +74 -35
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +14 -0
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +18 -4
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +89 -26
- tpu_inference/models/jax/utils/weight_utils.py +39 -2
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +20 -3
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +47 -64
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +72 -37
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +46 -17
- tpu_inference/runner/lora_utils.py +14 -0
- tpu_inference/runner/multimodal_manager.py +15 -1
- tpu_inference/runner/persistent_batch_manager.py +14 -0
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +14 -0
- tpu_inference/runner/tpu_runner.py +44 -17
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +13 -0
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +42 -36
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +63 -50
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/METADATA +7 -9
- tpu_inference-0.13.2rc3.dist-info/RECORD +261 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202512030818.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,18 @@
-
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Union
 
 import jax
 import jax.numpy as jnp
@@ -10,7 +24,7 @@ from torchax.interop import jax_view, torch_view
 from torchax.ops.mappings import t2j
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import (
-    FusedMoEConfig, FusedMoEQuantConfig,
+    FusedMoEConfig, FusedMoEQuantConfig, mxfp4_w4a16_moe_quant_config)
 from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
                                                         FusedMoEMethodBase)
 from vllm.model_executor.layers.linear import LinearBase
@@ -28,44 +42,22 @@ from tpu_inference import envs
 from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe
 from tpu_inference.layers.common.quant_methods import (MXFP4,
                                                        get_tpu_quant_method)
-from tpu_inference.layers.
+from tpu_inference.layers.common.quantization import (
+    dequantize_tensor_from_mxfp4_packed, quantize_tensor)
+from tpu_inference.layers.common.sharding import ShardingAxisName
+from tpu_inference.layers.vllm.fused_moe import fused_moe_func
 from tpu_inference.layers.vllm.linear_common import \
     reorder_concatenated_tensor_for_sharding
 from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
 from tpu_inference.layers.vllm.quantization.unquantized import \
     VllmUnquantizedLinearMethod
+from tpu_inference.utils import get_mesh_shape_product
 
-
+REQUANTIZED_BLOCK_SIZE = 512
 
 P = PartitionSpec
-logger = init_logger(__name__)
-
-
-# TODO(kyuyeunk): Move these functions into a common utility file.
-def u8_unpack_e2m1(u8_packed_e2m1: jax.Array) -> jax.Array:
-    assert u8_packed_e2m1.dtype == jnp.uint8
-    e2m1 = jax.lax.bitcast_convert_type(u8_packed_e2m1, jnp.float4_e2m1fn)
-    # bitcast creates one more dimension that splits 8 bits into two e2m1.
-    # we flatten them with the last dim.
-    return jnp.reshape(e2m1, e2m1.shape[:-2] + (-1, ))
-
-
-def e8m0_to_fp32(u8: jax.Array) -> jax.Array:
-    e8_finfo = jnp.finfo(jnp.float8_e8m0fnu)
-    exponents = u8.astype(jnp.int32) + e8_finfo.minexp
-    ones = jnp.ones_like(u8, dtype=jnp.float32)
-    return jnp.ldexp(ones, exponents)
 
-
-def dequantize_block_weight(weight: jax.Array,
-                            scale: jax.Array,
-                            block_size: int,
-                            out_dtype: jnp.dtype = jnp.bfloat16) -> jax.Array:
-    orig_shape = weight.shape
-    weight_block = weight.reshape(orig_shape[:-1] + (-1, block_size))
-    weight_dequantized = weight_block.astype(jnp.float32) * jnp.expand_dims(
-        scale, -1)
-    return weight_dequantized.reshape(orig_shape).astype(out_dtype)
+logger = init_logger(__name__)
 
 
 @register_quantization_config(get_tpu_quant_method(MXFP4))
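
The three helpers removed here (u8_unpack_e2m1, e8m0_to_fp32, dequantize_block_weight) are superseded by the shared dequantize_tensor_from_mxfp4_packed and quantize_tensor imports added in the same hunk. Below is a minimal standalone sketch of the decode path they implemented, reusing the deleted bodies verbatim; the composed dequantize_mxfp4 helper and the 32-element block size are illustrative assumptions (the OCP MX format uses 32-value blocks), not part of the package API:

import jax
import jax.numpy as jnp


def u8_unpack_e2m1(u8_packed_e2m1: jax.Array) -> jax.Array:
    # Each uint8 packs two e2m1 values; the bitcast adds a trailing axis of
    # size 2, which is folded back into the last dimension.
    assert u8_packed_e2m1.dtype == jnp.uint8
    e2m1 = jax.lax.bitcast_convert_type(u8_packed_e2m1, jnp.float4_e2m1fn)
    return jnp.reshape(e2m1, e2m1.shape[:-2] + (-1, ))


def e8m0_to_fp32(u8: jax.Array) -> jax.Array:
    # e8m0 stores only a biased exponent, so the scale is 2**(u8 + minexp).
    e8_finfo = jnp.finfo(jnp.float8_e8m0fnu)
    exponents = u8.astype(jnp.int32) + e8_finfo.minexp
    return jnp.ldexp(jnp.ones_like(u8, dtype=jnp.float32), exponents)


def dequantize_mxfp4(packed: jax.Array, scale_u8: jax.Array,
                     block_size: int = 32) -> jax.Array:
    # packed: uint8 with two fp4 values per byte; scale_u8: one e8m0 scale
    # per `block_size` unpacked values along the last dimension.
    w = u8_unpack_e2m1(packed).astype(jnp.float32)
    scale = e8m0_to_fp32(scale_u8)
    blocks = w.reshape(w.shape[:-1] + (-1, block_size))
    return (blocks * scale[..., None]).reshape(w.shape)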
@@ -87,9 +79,6 @@ class VllmMxfp4Config(Mxfp4Config, JaxCommonConfig):
                 fused_mapping=self.packed_modules_mapping,
             ):
                 return VllmUnquantizedLinearMethod(linear_config)
-            # TODO: Add support for MXFP4 Linear Method.
-            # MXFP4 LinearMethod is available in AMD-Quark, refer to that
-            # implementation if you are interested in enabling MXFP4 here.
             logger.warning_once(
                 "MXFP4 linear layer is not implemented - falling back to "
                 "UnquantizedLinearMethod.")
@@ -98,7 +87,6 @@ class VllmMxfp4Config(Mxfp4Config, JaxCommonConfig):
             moe_config = self.get_moe_config(layer)
             return VllmMxfp4MoEMethod(moe_config, self.mesh)
         elif isinstance(layer, Attention):
-            # TODO: Add support for MXFP4 Attention.
             logger.warning_once("MXFP4 attention layer is not implemented. "
                                 "Skipping quantization for this layer.")
             return None
@@ -117,225 +105,306 @@ class VllmMxfp4MoEMethod(Mxfp4MoEMethod):
         self.mxfp4_backend = Mxfp4Backend.TRITON
 
         self.mesh = mesh
-        self.use_kernel = envs.USE_MOE_EP_KERNEL
+        self.use_kernel = envs.USE_MOE_EP_KERNEL and moe.use_ep
         self.ep_axis_name = ep_axis_name
         # TODO: Use autotune table once we have it.
         self.block_size = {
-            "bt":
+            "bt": 256,
             "bf": 1024,
-            "bd1":
-            "bd2":
-            "btc":
+            "bd1": 1024,
+            "bd2": 1024,
+            "btc": 256,
             "bfc": 1024,
-            "bd1c":
-            "bd2c":
+            "bd1c": 1024,
+            "bd2c": 1024,
         }
 
     def get_fused_moe_quant_config(
             self, layer: torch.nn.Module) -> FusedMoEQuantConfig | None:
-
-
-
-            layer.w13_bias,
-            layer.w2_bias,
+        return mxfp4_w4a16_moe_quant_config(
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            w1_bias=layer.w13_bias,
+            w2_bias=layer.w2_bias,
         )
 
     def process_weights_after_loading(self, layer: torch.nn.Module):
         assert isinstance(layer, FusedMoE)
         assert layer.moe_config.has_bias, "mxfp4 quantization always uses bias."
 
-        w13_weight =
-        w13_weight_scale =
-            t2j(layer.w13_weight_scale, use_dlpack=False))
+        w13_weight = t2j(layer.w13_weight, use_dlpack=False)
+        w13_weight_scale = t2j(layer.w13_weight_scale, use_dlpack=False)
         w13_bias = t2j(layer.w13_bias, use_dlpack=False)
 
-        w2_weight =
-        w2_weight_scale =
-            t2j(layer.w2_weight_scale, use_dlpack=False))
+        w2_weight = t2j(layer.w2_weight, use_dlpack=False)
+        w2_weight_scale = t2j(layer.w2_weight_scale, use_dlpack=False)
         w2_bias = t2j(layer.w2_bias, use_dlpack=False)
 
-[... deleted block (old lines 157-211) not preserved in this rendering ...]
+        # Wrap functions in jit to speed up requantization.
+        @jax.jit
+        def wrapper(w13_weight, w13_weight_scale, w13_bias, w2_weight,
+                    w2_weight_scale, w2_bias):
+            # Dequantize fp4 weights into fp32.
+            w13_weight = dequantize_tensor_from_mxfp4_packed(
+                w13_weight, w13_weight_scale, 2)
+            w2_weight = dequantize_tensor_from_mxfp4_packed(
+                w2_weight, w2_weight_scale, 2)
+
+            num_experts, orig_hidden_size, orig_intermediate_size = w2_weight.shape
+
+            # Requantize the weights into TPU friendly block size.
+            w13_weight, w13_weight_scale = quantize_tensor(
+                jnp.float4_e2m1fn, w13_weight, 2, REQUANTIZED_BLOCK_SIZE, True)
+            w2_weight, w2_weight_scale = quantize_tensor(
+                jnp.float4_e2m1fn, w2_weight, 2, REQUANTIZED_BLOCK_SIZE, True)
+
+            intermediate_size = w2_weight.shape[-1]
+            hidden_size = w13_weight.shape[-1]
+
+            # Dims may have been padded to align with subchannel size during
+            # quantization. We pad the corresponding dim on the other weight.
+            # NOTE: We perform padding after quantization as the padding value
+            # can affect quantization numerics.
+            intermediate_padding_size = 2 * (intermediate_size -
+                                             orig_intermediate_size)
+            w13_weight = jnp.pad(w13_weight,
+                                 ((0, 0), (0, intermediate_padding_size),
+                                  (0, 0)))
+            w13_weight_scale = jnp.pad(w13_weight_scale,
+                                       ((0, 0), (0, intermediate_padding_size),
+                                        (0, 0)))
+            w13_bias = jnp.pad(w13_bias,
+                               ((0, 0), (0, intermediate_padding_size)))
+
+            hidden_padding_size = hidden_size - orig_hidden_size
+            w2_weight = jnp.pad(w2_weight,
+                                ((0, 0), (0, hidden_padding_size), (0, 0)))
+            w2_weight_scale = jnp.pad(w2_weight_scale,
+                                      ((0, 0), (0, hidden_padding_size),
+                                       (0, 0)))
+            w2_bias = jnp.pad(w2_bias, ((0, 0), (0, hidden_padding_size)))
+
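The factor of 2 in intermediate_padding_size comes from w13 stacking w1 and w3 along its intermediate axis: whatever padding quantize_tensor added to w2's intermediate dimension has to appear twice on w13. A toy shape check, assuming quantize_tensor rounds the quantized axis up to a multiple of REQUANTIZED_BLOCK_SIZE (the sizes below are made up):

# Made-up sizes; 512-alignment mirrors REQUANTIZED_BLOCK_SIZE above.
orig_hidden_size, orig_intermediate_size = 700, 900
hidden_size, intermediate_size = 1024, 1024   # rounded up to 512 blocks

# w13: (E, 2 * intermediate, hidden); w2: (E, hidden, intermediate).
intermediate_padding_size = 2 * (intermediate_size - orig_intermediate_size)
assert 2 * orig_intermediate_size + intermediate_padding_size \
    == 2 * intermediate_size

hidden_padding_size = hidden_size - orig_hidden_size
assert orig_hidden_size + hidden_padding_size == hidden_size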
+            if layer.activation == "swigluoai":
+                # When using swigluoai, vLLM splits gmm output in an
+                # interleaved way. However, an interleaved split is not
+                # performant on TPU, so we preprocess the weight so that
+                # splitting the gmm output down the middle gives the same
+                # result.
+                w1_weight = w13_weight[:, ::2, :]
+                w3_weight = w13_weight[:, 1::2, :]
+                w13_weight = jnp.concat([w1_weight, w3_weight], axis=1)
+
+                w1_weight_scale = w13_weight_scale[:, ::2, :]
+                w3_weight_scale = w13_weight_scale[:, 1::2, :]
+                w13_weight_scale = jnp.concat(
+                    [w1_weight_scale, w3_weight_scale], axis=1)
+
+                w1_bias = w13_bias[:, ::2]
+                w3_bias = w13_bias[:, 1::2]
+                w13_bias = jnp.concat([w1_bias, w3_bias], axis=1)
+
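The reordering above can be checked numerically: after moving even rows into the first half and odd rows into the second, splitting at the midpoint must reproduce vLLM's interleaved w1/w3 split. A small self-contained check with toy sizes:

import jax.numpy as jnp

w13 = jnp.arange(2 * 8 * 3).reshape(2, 8, 3)  # (experts, 2*intermediate, hidden)

w1_interleaved, w3_interleaved = w13[:, ::2, :], w13[:, 1::2, :]

reordered = jnp.concatenate([w13[:, ::2, :], w13[:, 1::2, :]], axis=1)
mid = reordered.shape[1] // 2
w1_middle, w3_middle = reordered[:, :mid, :], reordered[:, mid:, :]

assert bool((w1_middle == w1_interleaved).all())
assert bool((w3_middle == w3_interleaved).all())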
+            if self.use_kernel:
+                # Kernel expects:
+                #   w13: (num_experts, 2, hidden_size, intermediate_size)
+                #   w2: (num_experts, intermediate_size, hidden_size)
+                # Current format:
+                #   w13_weight: (num_experts, 2*intermediate_size, hidden_size)
+                #   w2_weight: (num_experts, hidden_size, intermediate_size)
+
+                w13_weight = w13_weight.reshape(num_experts, 2,
+                                                intermediate_size, hidden_size)
+
+                w13_weight_scale = w13_weight_scale.reshape(
+                    num_experts, 2, intermediate_size, 1, -1)
+                w2_weight_scale = w2_weight_scale.reshape(
+                    num_experts, hidden_size, 1, -1)
+
+                w13_bias = w13_bias.astype(jnp.float32).reshape(
+                    num_experts, 2, 1, intermediate_size)
+                w2_bias = w2_bias.astype(jnp.float32).reshape(
+                    num_experts, 1, hidden_size)
+
+                # Transpose non-contracting dim to right-most dim.
+                w13_weight = jnp.swapaxes(w13_weight, 2, 3)
+                w2_weight = jnp.swapaxes(w2_weight, 1, 2)
+
+                w13_weight_scale = jnp.swapaxes(w13_weight_scale, 2, 4)
+                w2_weight_scale = jnp.swapaxes(w2_weight_scale, 1, 3)
 
                 # Apply EP sharding
-[... deleted lines (old 214-230) not preserved in this rendering ...]
-                    Format(Layout((0, 1, 2)),
-                           NamedSharding(self.mesh, P("model", None, None))))
-
-                w13_bias = jax.device_put(
-                    w13_bias,
-                    Format(Layout((0, 1)),
-                           NamedSharding(self.mesh, P("model", None))))
-                w2_bias = jax.device_put(
-                    w2_bias,
-                    Format(Layout((0, 1)),
-                           NamedSharding(self.mesh, P("model", None))))
-
+                ep_sharding = NamedSharding(self.mesh, P("model"))
+
+                w13_weight = jax.lax.with_sharding_constraint(
+                    w13_weight, Format(Layout((0, 1, 2, 3)), ep_sharding))
+                w2_weight = jax.lax.with_sharding_constraint(
+                    w2_weight, Format(Layout((0, 1, 2)), ep_sharding))
+
+                w13_weight_scale = jax.lax.with_sharding_constraint(
+                    w13_weight_scale,
+                    Format(Layout((0, 1, 2, 3, 4)), ep_sharding))
+                w2_weight_scale = jax.lax.with_sharding_constraint(
+                    w2_weight_scale, Format(Layout((0, 1, 2, 3)), ep_sharding))
+
+                w13_bias = jax.lax.with_sharding_constraint(
+                    w13_bias, Format(Layout((0, 1, 2, 3)), ep_sharding))
+                w2_bias = jax.lax.with_sharding_constraint(
+                    w2_bias, Format(Layout((0, 1, 2)), ep_sharding))
             else:
-[... deleted lines (old 244-269) not preserved in this rendering ...]
+                w13_weight_scale = jnp.swapaxes(w13_weight_scale, 1, 2)
+                w13_weight_scale = jnp.expand_dims(w13_weight_scale, 2)
+                w2_weight_scale = jnp.swapaxes(w2_weight_scale, 1, 2)
+                w2_weight_scale = jnp.expand_dims(w2_weight_scale, 2)
+
+                w13_bias = jnp.expand_dims(w13_bias, 1)
+                w2_bias = jnp.expand_dims(w2_bias, 1)
+
+                if layer.use_ep:
+                    ep_sharding = NamedSharding(self.mesh,
+                                                P(ShardingAxisName.EXPERT))
+
+                    w13_weight = jax.lax.with_sharding_constraint(
+                        w13_weight, ep_sharding)
+                    w2_weight = jax.lax.with_sharding_constraint(
+                        w2_weight, ep_sharding)
+
+                    w13_weight_scale = jax.lax.with_sharding_constraint(
+                        w13_weight_scale, ep_sharding)
+                    w2_weight_scale = jax.lax.with_sharding_constraint(
+                        w2_weight_scale, ep_sharding)
+
+                    w13_bias = jax.lax.with_sharding_constraint(
+                        w13_bias, ep_sharding)
+                    w2_bias = jax.lax.with_sharding_constraint(
+                        w2_bias, ep_sharding)
+
+                else:
+                    output_sizes = [intermediate_size, intermediate_size]
+                    n_shards = get_mesh_shape_product(
+                        self.mesh, ShardingAxisName.MLP_TENSOR)
+                    assert intermediate_size % n_shards == 0
+
+                    # Reorder w13 weights so that splitting w1 and w3 output
+                    # can happen locally without any collective operations.
+                    w13_weight = reorder_concatenated_tensor_for_sharding(
+                        w13_weight,
+                        output_sizes,
+                        n_shards,
+                        dim=1,
+                    )
+                    w13_weight_scale = reorder_concatenated_tensor_for_sharding(
+                        w13_weight_scale,
+                        output_sizes,
+                        n_shards,
+                        dim=3,
+                    )
+                    w13_bias = reorder_concatenated_tensor_for_sharding(
+                        w13_bias,
+                        output_sizes,
+                        n_shards,
+                        dim=2,
+                    )
+
+                    w13_weight = jax.lax.with_sharding_constraint(
+                        w13_weight,
+                        NamedSharding(
+                            self.mesh,
+                            P(None, ShardingAxisName.MLP_TENSOR, None)))
+                    w2_weight = jax.lax.with_sharding_constraint(
+                        w2_weight,
+                        NamedSharding(
+                            self.mesh,
+                            P(None, None, ShardingAxisName.MLP_TENSOR)))
+                    w13_weight_scale = jax.lax.with_sharding_constraint(
+                        w13_weight_scale,
+                        NamedSharding(
+                            self.mesh,
+                            P(None, None, None, ShardingAxisName.MLP_TENSOR)))
+                    w2_weight_scale = jax.lax.with_sharding_constraint(
+                        w2_weight_scale,
+                        NamedSharding(
+                            self.mesh,
+                            P(None, ShardingAxisName.MLP_TENSOR, None, None)))
+                    w13_bias = jax.lax.with_sharding_constraint(
+                        w13_bias,
+                        NamedSharding(
+                            self.mesh,
+                            P(None, None, ShardingAxisName.MLP_TENSOR)))
+                    w2_bias = jax.lax.with_sharding_constraint(
+                        w2_bias, NamedSharding(self.mesh, P(None, None, None)))
+
+            return w13_weight, w13_weight_scale, w13_bias, w2_weight, w2_weight_scale, w2_bias
+
+        w13_weight, w13_weight_scale, w13_bias, w2_weight, w2_weight_scale, w2_bias = wrapper(
+            w13_weight, w13_weight_scale, w13_bias, w2_weight, w2_weight_scale,
+            w2_bias)
 
         layer.w13_weight = Parameter(torch_view(w13_weight),
                                      requires_grad=False)
-        layer.w13_bias = Parameter(torch_view(w13_bias), requires_grad=False)
-
         layer.w2_weight = Parameter(torch_view(w2_weight), requires_grad=False)
-        layer.w2_bias = Parameter(torch_view(w2_bias), requires_grad=False)
 
-
+        layer.w13_weight_scale = Parameter(torch_view(w13_weight_scale),
+                                           requires_grad=False)
+        layer.w2_weight_scale = Parameter(torch_view(w2_weight_scale),
+                                          requires_grad=False)
+
+        layer.w13_bias = Parameter(torch_view(w13_bias), requires_grad=False)
+        layer.w2_bias = Parameter(torch_view(w2_bias), requires_grad=False)
 
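Note that both branches place weights with jax.lax.with_sharding_constraint where the old code used jax.device_put: the constraint form works inside the jit-compiled wrapper, so XLA can materialize each weight directly in its target layout. A minimal illustration of the two placement styles on a small mesh (the axis names and 2x2 shape are illustrative; it needs at least four devices, e.g. via XLA_FLAGS=--xla_force_host_platform_device_count=4 on CPU):

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices()[:4]).reshape(2, 2)
mesh = Mesh(devices, axis_names=("model", "mlp"))

@jax.jit
def place(w13, w2):
    # EP-style: shard the leading experts axis across "model".
    w13 = jax.lax.with_sharding_constraint(
        w13, NamedSharding(mesh, P("model", None, None)))
    # TP-style: shard w2's trailing intermediate axis across "mlp".
    w2 = jax.lax.with_sharding_constraint(
        w2, NamedSharding(mesh, P(None, None, "mlp")))
    return w13, w2

w13, w2 = place(jnp.zeros((8, 6, 4)), jnp.zeros((8, 4, 6)))
print(w13.sharding, w2.sharding)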
     def apply(
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
         router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool = False,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        global_num_experts: int = -1,
-        expert_map: Optional[torch.Tensor] = None,
-        custom_routing_function: Optional[Callable] = None,
-        scoring_func: str = "softmax",
-        routed_scaling_factor: float = 1.0,
-        e_score_correction_bias: Optional[torch.Tensor] = None,
-        apply_router_weight_on_input: bool = False,
-        activation: str = "silu",
-        enable_eplb: bool = False,
-        expert_load_view: Optional[torch.Tensor] = None,
-        logical_to_physical_map: Optional[torch.Tensor] = None,
-        logical_replica_count: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
         assert isinstance(layer, FusedMoE)
-        if scoring_func != "softmax":
+        if layer.scoring_func != "softmax":
             raise NotImplementedError(
                 "Only softmax is supported for scoring_func")
 
-
+        x = jax_view(x)
+        w13_weight = jax_view(layer.w13_weight)
+        w2_weight = jax_view(layer.w2_weight)
+        w13_weight_scale = jax_view(layer.w13_weight_scale)
+        w2_weight_scale = jax_view(layer.w2_weight_scale)
+        w13_bias = jax_view(layer.w13_bias)
+        w2_bias = jax_view(layer.w2_bias)
+        gating_output = jax_view(router_logits)
+
+        if self.use_kernel:
+            actual_hidden_size = x.shape[-1]
+            padding_size = w13_weight.shape[-2] - actual_hidden_size
+            x = jnp.pad(x, ((0, 0), (0, padding_size)))
             output = fused_ep_moe(
                 mesh=self.mesh,
-                tokens=
-                w1=
-                w2=
-
-
-
-
+                tokens=x,
+                w1=w13_weight,
+                w2=w2_weight,
+                w1_scale=w13_weight_scale,
+                w2_scale=w2_weight_scale,
+                b1=w13_bias,
+                b2=w2_bias,
+                gating_output=gating_output,
+                subc_quant_wsz=REQUANTIZED_BLOCK_SIZE,
+                top_k=layer.top_k,
                 ep_axis_name=self.ep_axis_name,
-                renormalize_topk_logits=renormalize,
-                act_fn=activation,
+                renormalize_topk_logits=layer.renormalize,
+                act_fn=layer.activation,
                 **self.block_size,
-            )
+            )[:, :actual_hidden_size]
         else:
-[... deleted lines (old 324-333) not preserved in this rendering ...]
-            renormalize=renormalize,
-            reduce_results=layer.reduce_results,
+            output = fused_moe_func(
+                hidden_states=x,
+                w1=w13_weight,
+                w2=w2_weight,
+                w1_scale=w13_weight_scale,
+                w2_scale=w2_weight_scale,
+                w1_bias=w13_bias,
+                w2_bias=w2_bias,
+                gating_output=gating_output,
+                topk=layer.top_k,
+                renormalize=layer.renormalize,
                 mesh=self.mesh,
                 use_ep=layer.use_ep,
-                activation=activation,
+                activation=layer.activation,
             )
 
         return torch_view(output)
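
The kernel path in apply() pads the activations up to the block-aligned hidden size the weights were padded to at load time, then slices the padding back off the kernel output. The same pattern in isolation, with the fused kernel replaced by a plain matmul stand-in and made-up sizes:

import jax.numpy as jnp

def kernel_stub(tokens, w):              # stand-in for fused_ep_moe
    return tokens @ w

actual_hidden_size, padded_hidden_size = 700, 1024
x = jnp.ones((8, actual_hidden_size))
w = jnp.zeros((padded_hidden_size, padded_hidden_size))

padding_size = padded_hidden_size - actual_hidden_size
x_padded = jnp.pad(x, ((0, 0), (0, padding_size)))

# Zero padding contributes nothing to the contraction, so trimming the
# output columns recovers the result for the real hidden size.
out = kernel_stub(x_padded, w)[:, :actual_hidden_size]
assert out.shape == (8, actual_hidden_size)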