tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic; see the package's registry page for more details.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +317 -34
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +406 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +320 -0
- tests/layers/vllm/test_unquantized.py +662 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +26 -6
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +110 -12
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +2 -45
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +15 -10
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +92 -8
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +25 -4
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +807 -230
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +218 -137
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +25 -12
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/fused_moe_gmm.py +506 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +282 -0
- tpu_inference/layers/common/sharding.py +32 -9
- tpu_inference/layers/common/utils.py +94 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +101 -494
- tpu_inference/layers/vllm/linear.py +64 -0
- tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
- tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
- tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
- tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
- tpu_inference/layers/vllm/quantization/__init__.py +19 -3
- tpu_inference/layers/vllm/quantization/awq.py +96 -82
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +23 -8
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +172 -176
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
- tpu_inference/layers/vllm/quantization/{common.py → configs.py} +42 -25
- tpu_inference/layers/vllm/quantization/fp8.py +119 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +137 -178
- tpu_inference/layers/vllm/quantization/unquantized.py +157 -233
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +112 -35
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +18 -5
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +179 -51
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +92 -32
- tpu_inference/models/jax/utils/weight_utils.py +234 -155
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +32 -8
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +51 -72
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +180 -80
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +55 -33
- tpu_inference/runner/lora_utils.py +16 -1
- tpu_inference/runner/multimodal_manager.py +16 -2
- tpu_inference/runner/persistent_batch_manager.py +54 -2
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +16 -3
- tpu_inference/runner/tpu_runner.py +124 -61
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +84 -22
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +72 -44
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +66 -52
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +8 -9
- tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
- tpu_inference/layers/vllm/linear_common.py +0 -186
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202511220812.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
|
@@ -1,16 +1,29 @@
|
|
|
1
|
-
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from typing import Optional
|
|
2
16
|
|
|
3
17
|
import jax
|
|
4
18
|
import jax.numpy as jnp
|
|
5
19
|
import torch
|
|
6
|
-
from jax.
|
|
7
|
-
from jax.sharding import Mesh, NamedSharding, PartitionSpec
|
|
20
|
+
from jax.sharding import Mesh, PartitionSpec
|
|
8
21
|
from torch.nn.parameter import Parameter
|
|
9
|
-
from torchax.interop import
|
|
22
|
+
from torchax.interop import torch_view
|
|
10
23
|
from torchax.ops.mappings import t2j
|
|
11
|
-
from vllm.
|
|
24
|
+
from vllm.attention.layer import Attention
|
|
12
25
|
from vllm.model_executor.layers.fused_moe.config import (
|
|
13
|
-
FusedMoEConfig, FusedMoEQuantConfig,
|
|
26
|
+
FusedMoEConfig, FusedMoEQuantConfig, mxfp4_w4a16_moe_quant_config)
|
|
14
27
|
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
|
|
15
28
|
FusedMoEMethodBase)
|
|
16
29
|
from vllm.model_executor.layers.linear import LinearBase
|
|
@@ -26,48 +39,30 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import \
|
|
|
26
39
|
|
|
27
40
|
from tpu_inference.layers.common.quant_methods import (MXFP4,
|
|
28
41
|
get_tpu_quant_method)
|
|
29
|
-
from tpu_inference.layers.
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
from tpu_inference.layers.vllm.
|
|
42
|
+
from tpu_inference.layers.common.quantization import \
|
|
43
|
+
dequantize_tensor_from_mxfp4_packed
|
|
44
|
+
from tpu_inference.layers.common.sharding import ShardingAxisName
|
|
45
|
+
from tpu_inference.layers.vllm.fused_moe import (FusedMoEBackend,
|
|
46
|
+
fused_moe_apply,
|
|
47
|
+
select_moe_backend)
|
|
48
|
+
from tpu_inference.layers.vllm.process_weights.fused_moe_weights import (
|
|
49
|
+
FusedMoEWeights, process_moe_weights, quantize_moe_weights,
|
|
50
|
+
shard_moe_weights)
|
|
51
|
+
from tpu_inference.layers.vllm.quantization.configs import VllmQuantConfig
|
|
33
52
|
from tpu_inference.layers.vllm.quantization.unquantized import \
|
|
34
53
|
VllmUnquantizedLinearMethod
|
|
54
|
+
from tpu_inference.logger import init_logger
|
|
55
|
+
from tpu_inference.utils import get_mesh_shape_product
|
|
35
56
|
|
|
36
|
-
|
|
57
|
+
REQUANTIZED_BLOCK_SIZE = 512
|
|
37
58
|
|
|
38
59
|
P = PartitionSpec
|
|
39
|
-
logger = init_logger(__name__)
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
# TODO(kyuyeunk): Move these functions into a common utility file.
|
|
43
|
-
def u8_unpack_e2m1(u8_packed_e2m1: jax.Array) -> jax.Array:
|
|
44
|
-
assert u8_packed_e2m1.dtype == jnp.uint8
|
|
45
|
-
e2m1 = jax.lax.bitcast_convert_type(u8_packed_e2m1, jnp.float4_e2m1fn)
|
|
46
|
-
# bitcast creates one more dimension that splits 8 bits into two e2m1.
|
|
47
|
-
# we flatten them with the last dim.
|
|
48
|
-
return jnp.reshape(e2m1, e2m1.shape[:-2] + (-1, ))
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
def e8m0_to_fp32(u8: jax.Array) -> jax.Array:
|
|
52
|
-
e8_finfo = jnp.finfo(jnp.float8_e8m0fnu)
|
|
53
|
-
exponents = u8.astype(jnp.int32) + e8_finfo.minexp
|
|
54
|
-
ones = jnp.ones_like(u8, dtype=jnp.float32)
|
|
55
|
-
return jnp.ldexp(ones, exponents)
|
|
56
60
|
|
|
57
|
-
|
|
58
|
-
def dequantize_block_weight(weight: jax.Array,
|
|
59
|
-
scale: jax.Array,
|
|
60
|
-
block_size: int,
|
|
61
|
-
out_dtype: jnp.dtype = jnp.bfloat16) -> jax.Array:
|
|
62
|
-
orig_shape = weight.shape
|
|
63
|
-
weight_block = weight.reshape(orig_shape[:-1] + (-1, block_size))
|
|
64
|
-
weight_dequantized = weight_block.astype(jnp.float32) * jnp.expand_dims(
|
|
65
|
-
scale, -1)
|
|
66
|
-
return weight_dequantized.reshape(orig_shape).astype(out_dtype)
|
|
61
|
+
logger = init_logger(__name__)
|
|
67
62
|
|
|
68
63
|
|
|
69
64
|
@register_quantization_config(get_tpu_quant_method(MXFP4))
|
|
70
|
-
class VllmMxfp4Config(Mxfp4Config,
|
|
65
|
+
class VllmMxfp4Config(Mxfp4Config, VllmQuantConfig):
|
|
71
66
|
|
|
72
67
|
@classmethod
|
|
73
68
|
def get_name(cls):
|
|
@@ -75,7 +70,6 @@ class VllmMxfp4Config(Mxfp4Config, JaxCommonConfig):
|
|
|
75
70
|
|
|
76
71
|
def get_quant_method(self, layer: torch.nn.Module,
|
|
77
72
|
prefix: str) -> Optional["QuantizeMethodBase"]:
|
|
78
|
-
from vllm.attention.layer import Attention # Avoid circular import
|
|
79
73
|
|
|
80
74
|
if isinstance(layer, LinearBase):
|
|
81
75
|
linear_config = self.get_linear_config(layer)
|
|
@@ -85,17 +79,14 @@ class VllmMxfp4Config(Mxfp4Config, JaxCommonConfig):
|
|
|
85
79
|
fused_mapping=self.packed_modules_mapping,
|
|
86
80
|
):
|
|
87
81
|
return VllmUnquantizedLinearMethod(linear_config)
|
|
88
|
-
# TODO: Add support for MXFP4 Linear Method.
|
|
89
|
-
# MXFP4 LinearMethod is available in AMD-Quark, refer to that
|
|
90
|
-
# implementation if you are interested in enabling MXFP4 here.
|
|
91
82
|
logger.warning_once(
|
|
92
83
|
"MXFP4 linear layer is not implemented - falling back to "
|
|
93
84
|
"UnquantizedLinearMethod.")
|
|
94
85
|
return VllmUnquantizedLinearMethod(linear_config)
|
|
95
86
|
elif isinstance(layer, FusedMoE):
|
|
96
|
-
|
|
87
|
+
moe_config = self.get_moe_config(layer)
|
|
88
|
+
return VllmMxfp4MoEMethod(moe_config, self.mesh)
|
|
97
89
|
elif isinstance(layer, Attention):
|
|
98
|
-
# TODO: Add support for MXFP4 Attention.
|
|
99
90
|
logger.warning_once("MXFP4 attention layer is not implemented. "
|
|
100
91
|
"Skipping quantization for this layer.")
|
|
101
92
|
return None
|
|
@@ -103,164 +94,132 @@ class VllmMxfp4Config(Mxfp4Config, JaxCommonConfig):
|
|
|
103
94
|
|
|
104
95
|
class VllmMxfp4MoEMethod(Mxfp4MoEMethod):
|
|
105
96
|
|
|
106
|
-
def __init__(
|
|
97
|
+
def __init__(
|
|
98
|
+
self,
|
|
99
|
+
moe: FusedMoEConfig,
|
|
100
|
+
mesh: Mesh,
|
|
101
|
+
ep_axis_name: str = "model",
|
|
102
|
+
):
|
|
107
103
|
FusedMoEMethodBase.__init__(self, moe)
|
|
108
104
|
|
|
109
105
|
# We piggyback on triton implementation as it applies minimal hardware
|
|
110
106
|
# specific post processing to the weights.
|
|
111
107
|
self.mxfp4_backend = Mxfp4Backend.TRITON
|
|
108
|
+
|
|
112
109
|
self.mesh = mesh
|
|
110
|
+
self.moe_backend = select_moe_backend(self.moe)
|
|
111
|
+
|
|
112
|
+
self.extra_backend_kwargs = {}
|
|
113
|
+
if self.moe_backend == FusedMoEBackend.FUSED_MOE:
|
|
114
|
+
# When fused moe kernle is used, we pass extra arguments like
|
|
115
|
+
# tuned block sizes to the kernel.
|
|
116
|
+
self.extra_backend_kwargs = dict(
|
|
117
|
+
subc_quant_wsz=REQUANTIZED_BLOCK_SIZE,
|
|
118
|
+
ep_axis_name=ep_axis_name,
|
|
119
|
+
# TODO: Use autotune table once we have it.
|
|
120
|
+
bt=256,
|
|
121
|
+
bf=1024,
|
|
122
|
+
bd1=1024,
|
|
123
|
+
bd2=1024,
|
|
124
|
+
btc=256,
|
|
125
|
+
bfc=1024,
|
|
126
|
+
bd1c=1024,
|
|
127
|
+
bd2c=1024,
|
|
128
|
+
)
|
|
113
129
|
|
|
114
130
|
def get_fused_moe_quant_config(
|
|
115
131
|
self, layer: torch.nn.Module) -> FusedMoEQuantConfig | None:
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
layer.w13_bias,
|
|
120
|
-
layer.w2_bias,
|
|
132
|
+
return mxfp4_w4a16_moe_quant_config(
|
|
133
|
+
w1_scale=layer.w13_weight_scale,
|
|
134
|
+
w2_scale=layer.w2_weight_scale,
|
|
135
|
+
w1_bias=layer.w13_bias,
|
|
136
|
+
w2_bias=layer.w2_bias,
|
|
121
137
|
)
|
|
122
138
|
|
|
123
139
|
def process_weights_after_loading(self, layer: torch.nn.Module):
|
|
124
140
|
assert isinstance(layer, FusedMoE)
|
|
141
|
+
assert layer.moe_config.has_bias, "mxfp4 quantization alwyas use bias."
|
|
125
142
|
|
|
126
|
-
w13_weight =
|
|
127
|
-
w13_weight_scale =
|
|
128
|
-
t2j(layer.w13_weight_scale, use_dlpack=False))
|
|
143
|
+
w13_weight = t2j(layer.w13_weight, use_dlpack=False)
|
|
144
|
+
w13_weight_scale = t2j(layer.w13_weight_scale, use_dlpack=False)
|
|
129
145
|
w13_bias = t2j(layer.w13_bias, use_dlpack=False)
|
|
130
146
|
|
|
131
|
-
w2_weight =
|
|
132
|
-
w2_weight_scale =
|
|
133
|
-
t2j(layer.w2_weight_scale, use_dlpack=False))
|
|
147
|
+
w2_weight = t2j(layer.w2_weight, use_dlpack=False)
|
|
148
|
+
w2_weight_scale = t2j(layer.w2_weight_scale, use_dlpack=False)
|
|
134
149
|
w2_bias = t2j(layer.w2_bias, use_dlpack=False)
|
|
135
150
|
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
n_shards = self.mesh.shape["model"]
|
|
185
|
-
assert intermediate_size % n_shards == 0
|
|
186
|
-
w13_weight = reorder_concatenated_tensor_for_sharding(w13_weight,
|
|
187
|
-
output_sizes,
|
|
188
|
-
n_shards,
|
|
189
|
-
dim=1)
|
|
190
|
-
w13_weight = jax.device_put(
|
|
191
|
-
w13_weight,
|
|
192
|
-
Format(Layout((0, 1, 2)),
|
|
193
|
-
NamedSharding(self.mesh, P(None, "model", None))))
|
|
194
|
-
w2_weight = jax.device_put(
|
|
195
|
-
w2_weight,
|
|
196
|
-
Format(Layout((0, 1, 2)),
|
|
197
|
-
NamedSharding(self.mesh, P(None, None, "model"))))
|
|
198
|
-
|
|
199
|
-
w13_bias = reorder_concatenated_tensor_for_sharding(w13_bias,
|
|
200
|
-
output_sizes,
|
|
201
|
-
n_shards,
|
|
202
|
-
dim=1)
|
|
203
|
-
w13_bias = jax.device_put(
|
|
204
|
-
w13_bias,
|
|
205
|
-
Format(Layout((0, 1)),
|
|
206
|
-
NamedSharding(self.mesh, P(None, "model"))))
|
|
207
|
-
w2_bias = jax.device_put(
|
|
208
|
-
w2_bias,
|
|
209
|
-
Format(Layout((0, 1)), NamedSharding(self.mesh, P(None,
|
|
210
|
-
None))))
|
|
151
|
+
@jax.jit
|
|
152
|
+
def process_mxfp4_moe_weights(
|
|
153
|
+
w13_weight: jax.Array,
|
|
154
|
+
w13_weight_scale: jax.Array,
|
|
155
|
+
w13_bias: jax.Array,
|
|
156
|
+
w2_weight: jax.Array,
|
|
157
|
+
w2_weight_scale: jax.Array,
|
|
158
|
+
w2_bias: jax.Array,
|
|
159
|
+
) -> FusedMoEWeights:
|
|
160
|
+
# Dequantize fp4 weights into fp32.
|
|
161
|
+
w13_weight = dequantize_tensor_from_mxfp4_packed(
|
|
162
|
+
w13_weight, w13_weight_scale, 2)
|
|
163
|
+
w2_weight = dequantize_tensor_from_mxfp4_packed(
|
|
164
|
+
w2_weight, w2_weight_scale, 2)
|
|
165
|
+
|
|
166
|
+
w13_interleave = layer.activation == "swigluoai"
|
|
167
|
+
w13_reorder_size = get_mesh_shape_product(
|
|
168
|
+
self.mesh, ShardingAxisName.MLP_TENSOR)
|
|
169
|
+
|
|
170
|
+
weights = quantize_moe_weights(
|
|
171
|
+
FusedMoEWeights(
|
|
172
|
+
w13_weight=w13_weight,
|
|
173
|
+
w13_weight_scale=None,
|
|
174
|
+
w13_bias=w13_bias,
|
|
175
|
+
w2_weight=w2_weight,
|
|
176
|
+
w2_weight_scale=None,
|
|
177
|
+
w2_bias=w2_bias,
|
|
178
|
+
),
|
|
179
|
+
jnp.float4_e2m1fn,
|
|
180
|
+
REQUANTIZED_BLOCK_SIZE,
|
|
181
|
+
)
|
|
182
|
+
return process_moe_weights(
|
|
183
|
+
weights,
|
|
184
|
+
moe_backend=self.moe_backend,
|
|
185
|
+
w13_reorder_size=w13_reorder_size,
|
|
186
|
+
w13_interleave=w13_interleave,
|
|
187
|
+
)
|
|
188
|
+
|
|
189
|
+
weights = process_mxfp4_moe_weights(
|
|
190
|
+
w13_weight,
|
|
191
|
+
w13_weight_scale,
|
|
192
|
+
w13_bias,
|
|
193
|
+
w2_weight,
|
|
194
|
+
w2_weight_scale,
|
|
195
|
+
w2_bias,
|
|
196
|
+
)
|
|
197
|
+
weights = torch_view(
|
|
198
|
+
shard_moe_weights(weights, self.moe_backend, self.mesh))
|
|
211
199
|
|
|
212
|
-
layer.w13_weight = Parameter(
|
|
213
|
-
|
|
214
|
-
layer.w13_bias = Parameter(torch_view(w13_bias), requires_grad=False)
|
|
200
|
+
layer.w13_weight = Parameter(weights.w13_weight, requires_grad=False)
|
|
201
|
+
layer.w2_weight = Parameter(weights.w2_weight, requires_grad=False)
|
|
215
202
|
|
|
216
|
-
layer.
|
|
217
|
-
|
|
203
|
+
layer.w13_weight_scale = Parameter(weights.w13_weight_scale,
|
|
204
|
+
requires_grad=False)
|
|
205
|
+
layer.w2_weight_scale = Parameter(weights.w2_weight_scale,
|
|
206
|
+
requires_grad=False)
|
|
218
207
|
|
|
219
|
-
|
|
208
|
+
layer.w13_bias = Parameter(weights.w13_bias, requires_grad=False)
|
|
209
|
+
layer.w2_bias = Parameter(weights.w2_bias, requires_grad=False)
|
|
220
210
|
|
|
221
211
|
def apply(
|
|
222
212
|
self,
|
|
223
213
|
layer: torch.nn.Module,
|
|
224
214
|
x: torch.Tensor,
|
|
225
215
|
router_logits: torch.Tensor,
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
routed_scaling_factor: float = 1.0,
|
|
236
|
-
e_score_correction_bias: Optional[torch.Tensor] = None,
|
|
237
|
-
apply_router_weight_on_input: bool = False,
|
|
238
|
-
activation: str = "silu",
|
|
239
|
-
enable_eplb: bool = False,
|
|
240
|
-
expert_load_view: Optional[torch.Tensor] = None,
|
|
241
|
-
logical_to_physical_map: Optional[torch.Tensor] = None,
|
|
242
|
-
logical_replica_count: Optional[torch.Tensor] = None,
|
|
243
|
-
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
|
244
|
-
assert isinstance(layer, FusedMoE)
|
|
245
|
-
if scoring_func != "softmax":
|
|
246
|
-
raise NotImplementedError(
|
|
247
|
-
"Only softmax is supported for scoring_func")
|
|
248
|
-
|
|
249
|
-
# Use the original implementation
|
|
250
|
-
output = fused_moe_func_padded(
|
|
251
|
-
jax_view(x),
|
|
252
|
-
jax_view(layer.w13_weight),
|
|
253
|
-
jax_view(layer.w2_weight),
|
|
254
|
-
jax_view(layer.w13_bias) if self.moe.has_bias else None,
|
|
255
|
-
jax_view(layer.w2_bias) if self.moe.has_bias else None,
|
|
256
|
-
jax_view(router_logits),
|
|
257
|
-
topk=top_k,
|
|
258
|
-
global_num_experts=global_num_experts,
|
|
259
|
-
renormalize=renormalize,
|
|
260
|
-
reduce_results=layer.reduce_results,
|
|
261
|
-
mesh=self.mesh,
|
|
262
|
-
use_ep=layer.use_ep,
|
|
263
|
-
activation=activation,
|
|
216
|
+
) -> torch.Tensor:
|
|
217
|
+
|
|
218
|
+
return fused_moe_apply(
|
|
219
|
+
layer,
|
|
220
|
+
x,
|
|
221
|
+
router_logits,
|
|
222
|
+
self.moe_backend,
|
|
223
|
+
self.mesh,
|
|
224
|
+
self.extra_backend_kwargs,
|
|
264
225
|
)
|
|
265
|
-
|
|
266
|
-
return torch_view(output)
|