tpu-inference 0.12.0.dev20251213__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of tpu-inference might be problematic.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +14 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +14 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +14 -0
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +180 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +406 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +320 -0
- tests/layers/vllm/test_unquantized.py +662 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +25 -8
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +14 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +1 -43
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +14 -9
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +20 -3
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +171 -163
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +20 -26
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +112 -69
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +85 -65
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +374 -194
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +13 -0
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +26 -19
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/fused_moe_gmm.py +506 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +282 -0
- tpu_inference/layers/common/sharding.py +22 -3
- tpu_inference/layers/common/utils.py +94 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +52 -27
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +19 -6
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +100 -455
- tpu_inference/layers/vllm/linear.py +64 -0
- tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
- tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
- tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
- tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
- tpu_inference/layers/vllm/quantization/__init__.py +19 -3
- tpu_inference/layers/vllm/quantization/awq.py +96 -82
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +19 -5
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +119 -132
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
- tpu_inference/layers/vllm/quantization/{common.py → configs.py} +38 -26
- tpu_inference/layers/vllm/quantization/fp8.py +119 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +133 -220
- tpu_inference/layers/vllm/quantization/unquantized.py +154 -253
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +37 -16
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +113 -124
- tpu_inference/models/jax/gpt_oss.py +23 -7
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +14 -0
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +18 -4
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +85 -24
- tpu_inference/models/jax/utils/weight_utils.py +32 -1
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +22 -4
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +27 -29
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +69 -35
- tpu_inference/runner/kv_cache.py +14 -0
- tpu_inference/runner/kv_cache_manager.py +15 -2
- tpu_inference/runner/lora_utils.py +16 -1
- tpu_inference/runner/multimodal_manager.py +16 -2
- tpu_inference/runner/persistent_batch_manager.py +14 -0
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +14 -0
- tpu_inference/runner/tpu_runner.py +30 -10
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +13 -0
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +31 -30
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +23 -7
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +1 -1
- tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
- tpu_inference/layers/vllm/linear_common.py +0 -208
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.12.0.dev20251213.dist-info/RECORD +0 -175
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
--- a/tpu_inference/layers/vllm/quantization/mxfp4.py
+++ b/tpu_inference/layers/vllm/quantization/mxfp4.py
@@ -1,16 +1,29 @@
-
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional

 import jax
 import jax.numpy as jnp
 import torch
-from jax.
-from jax.sharding import Mesh, NamedSharding, PartitionSpec
+from jax.sharding import Mesh, PartitionSpec
 from torch.nn.parameter import Parameter
-from torchax.interop import
+from torchax.interop import torch_view
 from torchax.ops.mappings import t2j
-from vllm.
+from vllm.attention.layer import Attention
 from vllm.model_executor.layers.fused_moe.config import (
-    FusedMoEConfig, FusedMoEQuantConfig,
+    FusedMoEConfig, FusedMoEQuantConfig, mxfp4_w4a16_moe_quant_config)
 from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
                                                         FusedMoEMethodBase)
 from vllm.model_executor.layers.linear import LinearBase
@@ -24,52 +37,32 @@ from vllm.model_executor.layers.quantization.mxfp4 import (Mxfp4Backend,
 from vllm.model_executor.layers.quantization.utils.quant_utils import \
     is_layer_skipped

-from tpu_inference import envs
-from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe
 from tpu_inference.layers.common.quant_methods import (MXFP4,
                                                         get_tpu_quant_method)
-from tpu_inference.layers.
-
-
-from tpu_inference.layers.vllm.
+from tpu_inference.layers.common.quantization import \
+    dequantize_tensor_from_mxfp4_packed
+from tpu_inference.layers.common.sharding import ShardingAxisName
+from tpu_inference.layers.vllm.fused_moe import (FusedMoEBackend,
+                                                 fused_moe_apply,
+                                                 select_moe_backend)
+from tpu_inference.layers.vllm.process_weights.fused_moe_weights import (
+    FusedMoEWeights, process_moe_weights, quantize_moe_weights,
+    shard_moe_weights)
+from tpu_inference.layers.vllm.quantization.configs import VllmQuantConfig
 from tpu_inference.layers.vllm.quantization.unquantized import \
     VllmUnquantizedLinearMethod
+from tpu_inference.logger import init_logger
+from tpu_inference.utils import get_mesh_shape_product

-
+REQUANTIZED_BLOCK_SIZE = 512

 P = PartitionSpec
-logger = init_logger(__name__)
-
-
-# TODO(kyuyeunk): Move these functions into a common utility file.
-def u8_unpack_e2m1(u8_packed_e2m1: jax.Array) -> jax.Array:
-    assert u8_packed_e2m1.dtype == jnp.uint8
-    e2m1 = jax.lax.bitcast_convert_type(u8_packed_e2m1, jnp.float4_e2m1fn)
-    # bitcast creates one more dimension that splits 8 bits into two e2m1.
-    # we flatten them with the last dim.
-    return jnp.reshape(e2m1, e2m1.shape[:-2] + (-1, ))
-

-
-    e8_finfo = jnp.finfo(jnp.float8_e8m0fnu)
-    exponents = u8.astype(jnp.int32) + e8_finfo.minexp
-    ones = jnp.ones_like(u8, dtype=jnp.float32)
-    return jnp.ldexp(ones, exponents)
-
-
-def dequantize_block_weight(weight: jax.Array,
-                            scale: jax.Array,
-                            block_size: int,
-                            out_dtype: jnp.dtype = jnp.bfloat16) -> jax.Array:
-    orig_shape = weight.shape
-    weight_block = weight.reshape(orig_shape[:-1] + (-1, block_size))
-    weight_dequantized = weight_block.astype(jnp.float32) * jnp.expand_dims(
-        scale, -1)
-    return weight_dequantized.reshape(orig_shape).astype(out_dtype)
+logger = init_logger(__name__)


 @register_quantization_config(get_tpu_quant_method(MXFP4))
-class VllmMxfp4Config(Mxfp4Config,
+class VllmMxfp4Config(Mxfp4Config, VllmQuantConfig):

     @classmethod
     def get_name(cls):
@@ -77,7 +70,6 @@ class VllmMxfp4Config(Mxfp4Config, JaxCommonConfig):

     def get_quant_method(self, layer: torch.nn.Module,
                          prefix: str) -> Optional["QuantizeMethodBase"]:
-        from vllm.attention.layer import Attention # Avoid circular import

         if isinstance(layer, LinearBase):
             linear_config = self.get_linear_config(layer)
@@ -102,10 +94,12 @@ class VllmMxfp4Config(Mxfp4Config, JaxCommonConfig):

 class VllmMxfp4MoEMethod(Mxfp4MoEMethod):

-    def __init__(
-
-
-
+    def __init__(
+        self,
+        moe: FusedMoEConfig,
+        mesh: Mesh,
+        ep_axis_name: str = "model",
+    ):
         FusedMoEMethodBase.__init__(self, moe)

         # We piggyback on triton implementation as it applies minimal hardware
@@ -113,200 +107,119 @@ class VllmMxfp4MoEMethod(Mxfp4MoEMethod):
         self.mxfp4_backend = Mxfp4Backend.TRITON

         self.mesh = mesh
-        self.
-
-
-        self.
-
-
-
-
-
-
-
-
-
+        self.moe_backend = select_moe_backend(self.moe)
+
+        self.extra_backend_kwargs = {}
+        if self.moe_backend == FusedMoEBackend.FUSED_MOE:
+            # When fused moe kernle is used, we pass extra arguments like
+            # tuned block sizes to the kernel.
+            self.extra_backend_kwargs = dict(
+                subc_quant_wsz=REQUANTIZED_BLOCK_SIZE,
+                ep_axis_name=ep_axis_name,
+                # TODO: Use autotune table once we have it.
+                bt=256,
+                bf=1024,
+                bd1=1024,
+                bd2=1024,
+                btc=256,
+                bfc=1024,
+                bd1c=1024,
+                bd2c=1024,
+            )

     def get_fused_moe_quant_config(
             self, layer: torch.nn.Module) -> FusedMoEQuantConfig | None:
-
-
-
-            layer.w13_bias,
-            layer.w2_bias,
+        return mxfp4_w4a16_moe_quant_config(
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            w1_bias=layer.w13_bias,
+            w2_bias=layer.w2_bias,
         )

     def process_weights_after_loading(self, layer: torch.nn.Module):
         assert isinstance(layer, FusedMoE)
         assert layer.moe_config.has_bias, "mxfp4 quantization alwyas use bias."

-        w13_weight =
-        w13_weight_scale =
-            t2j(layer.w13_weight_scale, use_dlpack=False))
+        w13_weight = t2j(layer.w13_weight, use_dlpack=False)
+        w13_weight_scale = t2j(layer.w13_weight_scale, use_dlpack=False)
         w13_bias = t2j(layer.w13_bias, use_dlpack=False)

-        w2_weight =
-        w2_weight_scale =
-            t2j(layer.w2_weight_scale, use_dlpack=False))
+        w2_weight = t2j(layer.w2_weight, use_dlpack=False)
+        w2_weight_scale = t2j(layer.w2_weight_scale, use_dlpack=False)
         w2_bias = t2j(layer.w2_bias, use_dlpack=False)

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            # Transpose non-constracting dim to right most dim
-            w13_weight_transposed = jnp.swapaxes(w13_reshaped, 2, 3)
-            w2_weight_transposed = jnp.swapaxes(w2_weight, 1, 2)
-
-            # Apply EP sharding
-            ep_sharding = NamedSharding(self.mesh, P("model"))
-
-            w13_weight = jax.device_put(
-                w13_weight_transposed, Format(Layout((0, 1, 2, 3)),
-                                              ep_sharding))
-            w2_weight = jax.device_put(w2_weight_transposed,
-                                       Format(Layout((0, 1, 2)), ep_sharding))
-
-            w13_bias = w13_bias.reshape(num_experts, 2, intermediate_size)
-            w13_bias = jax.device_put(w13_bias,
-                                      Format(Layout((0, 1, 2)), ep_sharding))
-            w2_bias = jax.device_put(w2_bias,
-                                     Format(Layout((0, 1)), ep_sharding))
-
-        else:
-            if layer.use_ep:
-                ep_sharding = NamedSharding(self.mesh, P("model"))
-                w13_weight = jax.device_put(
-                    w13_weight, Format(Layout((0, 1, 2)), ep_sharding))
-                w2_weight = jax.device_put(
-                    w2_weight, Format(Layout((0, 1, 2)), ep_sharding))
-
-                w13_bias = jax.device_put(w13_bias,
-                                          Format(Layout((0, 1)), ep_sharding))
-                w2_bias = jax.device_put(w2_bias,
-                                         Format(Layout((0, 1)), ep_sharding))
-
-            else:
-                output_sizes = [intermediate_size, intermediate_size]
-                n_shards = self.mesh.shape["model"]
-                assert intermediate_size % n_shards == 0
+        @jax.jit
+        def process_mxfp4_moe_weights(
+            w13_weight: jax.Array,
+            w13_weight_scale: jax.Array,
+            w13_bias: jax.Array,
+            w2_weight: jax.Array,
+            w2_weight_scale: jax.Array,
+            w2_bias: jax.Array,
+        ) -> FusedMoEWeights:
+            # Dequantize fp4 weights into fp32.
+            w13_weight = dequantize_tensor_from_mxfp4_packed(
+                w13_weight, w13_weight_scale, 2)
+            w2_weight = dequantize_tensor_from_mxfp4_packed(
+                w2_weight, w2_weight_scale, 2)
+
+            w13_interleave = layer.activation == "swigluoai"
+            w13_reorder_size = get_mesh_shape_product(
+                self.mesh, ShardingAxisName.MLP_TENSOR)
+
+            weights = quantize_moe_weights(
+                FusedMoEWeights(
+                    w13_weight=w13_weight,
+                    w13_weight_scale=None,
+                    w13_bias=w13_bias,
+                    w2_weight=w2_weight,
+                    w2_weight_scale=None,
+                    w2_bias=w2_bias,
+                ),
+                jnp.float4_e2m1fn,
+                REQUANTIZED_BLOCK_SIZE,
+            )
+            return process_moe_weights(
+                weights,
+                moe_backend=self.moe_backend,
+                w13_reorder_size=w13_reorder_size,
+                w13_interleave=w13_interleave,
+            )

-
-
-
-
-
-
-
-
-
-
-            w2_weight = jax.device_put(
-                w2_weight,
-                Format(Layout((0, 1, 2)),
-                       NamedSharding(self.mesh, P(None, None, "model"))))
+        weights = process_mxfp4_moe_weights(
+            w13_weight,
+            w13_weight_scale,
+            w13_bias,
+            w2_weight,
+            w2_weight_scale,
+            w2_bias,
+        )
+        weights = torch_view(
+            shard_moe_weights(weights, self.moe_backend, self.mesh))

-
-
-                output_sizes,
-                n_shards,
-                dim=1,
-            )
-            w13_bias = jax.device_put(
-                w13_bias,
-                Format(Layout((0, 1)),
-                       NamedSharding(self.mesh, P(None, "model"))))
-            w2_bias = jax.device_put(
-                w2_bias,
-                Format(Layout((0, 1)),
-                       NamedSharding(self.mesh, P(None, None))))
+        layer.w13_weight = Parameter(weights.w13_weight, requires_grad=False)
+        layer.w2_weight = Parameter(weights.w2_weight, requires_grad=False)

-        layer.
-
-        layer.
+        layer.w13_weight_scale = Parameter(weights.w13_weight_scale,
+                                           requires_grad=False)
+        layer.w2_weight_scale = Parameter(weights.w2_weight_scale,
+                                          requires_grad=False)

-        layer.w13_bias = Parameter(
-        layer.w2_bias = Parameter(
+        layer.w13_bias = Parameter(weights.w13_bias, requires_grad=False)
+        layer.w2_bias = Parameter(weights.w2_bias, requires_grad=False)

     def apply(
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
         router_logits: torch.Tensor,
-    ) ->
-
-
-
-
-
-
-
-
-
-        w2_bias = jax_view(layer.w2_bias)
-        gating_output = jax_view(router_logits)
-
-        if self.use_kernel:
-            output = fused_ep_moe(
-                mesh=self.mesh,
-                tokens=x,
-                w1=w13_weight,
-                w2=w2_weight,
-                b1=w13_bias,
-                b2=w2_bias,
-                gating_output=gating_output,
-                top_k=layer.top_k,
-                ep_axis_name=self.ep_axis_name,
-                renormalize_topk_logits=layer.renormalize,
-                act_fn=layer.activation,
-                **self.block_size,
-            )
-        else:
-            output = fused_moe_func(
-                hidden_states=x,
-                w1=w13_weight,
-                w2=w2_weight,
-                w1_bias=w13_bias,
-                w2_bias=w2_bias,
-                gating_output=gating_output,
-                topk=layer.top_k,
-                renormalize=layer.renormalize,
-                mesh=self.mesh,
-                use_ep=layer.use_ep,
-                activation=layer.activation,
-            )
-
-        return torch_view(output)
+    ) -> torch.Tensor:
+
+        return fused_moe_apply(
+            layer,
+            x,
+            router_logits,
+            self.moe_backend,
+            self.mesh,
+            self.extra_backend_kwargs,
+        )
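The module-level helpers this diff removes (u8_unpack_e2m1, the e8m0 scale expansion, and dequantize_block_weight) spelled out how MXFP4-packed weights are dequantized; that logic now appears to live behind dequantize_tensor_from_mxfp4_packed in tpu_inference/layers/common/quantization.py. The sketch below reassembles the deleted lines into a self-contained JAX example of the same idea; the function names, the 32-element block default, and the bfloat16 output dtype are illustrative assumptions rather than the package's API, and it needs a JAX build that exposes jnp.float4_e2m1fn and jnp.float8_e8m0fnu, as the deleted code did.

import jax
import jax.numpy as jnp


def u8_unpack_e2m1(packed: jax.Array) -> jax.Array:
    # Each uint8 packs two 4-bit e2m1 values; the bitcast adds a trailing
    # dimension of size 2, which is then flattened into the last axis.
    assert packed.dtype == jnp.uint8
    e2m1 = jax.lax.bitcast_convert_type(packed, jnp.float4_e2m1fn)
    return jnp.reshape(e2m1, e2m1.shape[:-2] + (-1, ))


def e8m0_scale_to_f32(scale_u8: jax.Array) -> jax.Array:
    # An e8m0 scale stores only a biased exponent; ldexp(1.0, exponent)
    # recovers the power-of-two scale factor.
    exponents = scale_u8.astype(jnp.int32) + jnp.finfo(jnp.float8_e8m0fnu).minexp
    ones = jnp.ones_like(scale_u8, dtype=jnp.float32)
    return jnp.ldexp(ones, exponents)


def dequantize_mxfp4(packed: jax.Array,
                     scale_u8: jax.Array,
                     block_size: int = 32,
                     out_dtype: jnp.dtype = jnp.bfloat16) -> jax.Array:
    # MXFP4 groups values into blocks (32 per the OCP MX spec) that share one
    # e8m0 scale: unpack, reshape into blocks, scale, then flatten back.
    weight = u8_unpack_e2m1(packed).astype(jnp.float32)
    blocks = weight.reshape(weight.shape[:-1] + (-1, block_size))
    scaled = blocks * jnp.expand_dims(e8m0_scale_to_f32(scale_u8), -1)
    return scaled.reshape(weight.shape).astype(out_dtype)

On top of this dequantization, the new process_mxfp4_moe_weights path shown in the diff requantizes the recovered weights into 512-element blocks (REQUANTIZED_BLOCK_SIZE) via quantize_moe_weights before sharding them for the selected MoE backend.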