tpu-inference 0.12.0.dev20251222__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_core_tpu.py +513 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +67 -0
- tests/core/test_dp_scheduler.py +724 -0
- tests/core/test_init.py +63 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +393 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +291 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +388 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +498 -0
- tests/kernels/quantized_matmul_kernel_test.py +159 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/layers/jax/test_qwix.py +969 -0
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +297 -0
- tests/layers/vllm/test_unquantized.py +621 -0
- tests/layers/vllm/utils.py +72 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +46 -0
- tests/lora/test_bgmv.py +57 -0
- tests/lora/test_layers.py +666 -0
- tests/lora/test_lora.py +147 -0
- tests/lora/test_lora_perf.py +67 -0
- tests/lora/utils.py +88 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +606 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +202 -0
- tests/runner/test_tpu_runner_dp.py +1033 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +215 -0
- tests/test_envs.py +280 -0
- tests/test_tpu_info.py +134 -0
- tests/test_utils.py +193 -0
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +67 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/core_tpu.py +786 -0
- tpu_inference/core/disagg_executor.py +118 -0
- tpu_inference/core/disagg_utils.py +49 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +814 -0
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +81 -0
- tpu_inference/distributed/tpu_connector.py +732 -0
- tpu_inference/distributed/utils.py +112 -0
- tpu_inference/env_override.py +9 -0
- tpu_inference/envs.py +191 -0
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +399 -0
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +272 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +1340 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +403 -0
- tpu_inference/layers/common/attention_metadata.py +48 -0
- tpu_inference/layers/common/binary_search.py +295 -0
- tpu_inference/layers/common/quant_methods.py +23 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +600 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +268 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
- tpu_inference/layers/jax/base.py +165 -0
- tpu_inference/layers/jax/constants.py +101 -0
- tpu_inference/layers/jax/layers.py +315 -0
- tpu_inference/layers/jax/misc.py +30 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
- tpu_inference/layers/jax/moe/moe.py +249 -0
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +294 -0
- tpu_inference/layers/jax/rope_interface.py +228 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
- tpu_inference/layers/jax/sample/sampling.py +110 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
- tpu_inference/layers/jax/transformer_block.py +121 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +221 -0
- tpu_inference/layers/vllm/fused_moe.py +502 -0
- tpu_inference/layers/vllm/linear_common.py +221 -0
- tpu_inference/layers/vllm/quantization/__init__.py +55 -0
- tpu_inference/layers/vllm/quantization/awq.py +221 -0
- tpu_inference/layers/vllm/quantization/common.py +124 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
- tpu_inference/layers/vllm/sharding.py +244 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +98 -0
- tpu_inference/lora/torch_punica_tpu.py +310 -0
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +520 -0
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +978 -0
- tpu_inference/models/jax/gpt_oss.py +508 -0
- tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
- tpu_inference/models/jax/llama3.py +436 -0
- tpu_inference/models/jax/llama4.py +643 -0
- tpu_inference/models/jax/llama_eagle3.py +350 -0
- tpu_inference/models/jax/llama_guard_4.py +375 -0
- tpu_inference/models/jax/qwen2.py +390 -0
- tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
- tpu_inference/models/jax/qwen3.py +318 -0
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +110 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
- tpu_inference/models/jax/utils/weight_utils.py +621 -0
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
- tpu_inference/platforms/__init__.py +16 -0
- tpu_inference/platforms/tpu_platform.py +258 -0
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/block_table.py +122 -0
- tpu_inference/runner/compilation_manager.py +890 -0
- tpu_inference/runner/input_batch.py +435 -0
- tpu_inference/runner/kv_cache.py +166 -0
- tpu_inference/runner/kv_cache_manager.py +508 -0
- tpu_inference/runner/lora_utils.py +106 -0
- tpu_inference/runner/multimodal_manager.py +231 -0
- tpu_inference/runner/persistent_batch_manager.py +296 -0
- tpu_inference/runner/speculative_decoding_manager.py +262 -0
- tpu_inference/runner/structured_decoding_manager.py +101 -0
- tpu_inference/runner/tpu_runner.py +1768 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +430 -0
- tpu_inference/tpu_info.py +92 -0
- tpu_inference/utils.py +345 -0
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +468 -0
- tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
- tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
- tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
- tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
ADDED
@@ -0,0 +1,150 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import jax
import jax.numpy as jnp
import torch
from compressed_tensors.quantization import QuantizationStrategy
from jax.sharding import NamedSharding, PartitionSpec
from torchax.interop import jax_view, torch_view
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_int8 import \
    CompressedTensorsW8A8Int8
from vllm.model_executor.layers.quantization.utils.w8a8_utils import \
    convert_to_channelwise

from tpu_inference.layers.vllm.linear_common import (
    sharded_quantized_matmul, slice_sharded_tensor_for_concatenation,
    torch_to_jax_param)
from tpu_inference.layers.vllm.quantization.common import JaxCommonLinearConfig

P = PartitionSpec
logger = init_logger(__name__)


class VllmCompressedTensorsW8A8Int8(CompressedTensorsW8A8Int8):

    def __init__(self, strategy: str, is_static_input_scheme: bool,
                 input_symmetric: bool, jax_config: JaxCommonLinearConfig):
        super().__init__(strategy, is_static_input_scheme, input_symmetric)

        self.jax_config = jax_config
        self.is_channelwise = (self.strategy == QuantizationStrategy.CHANNEL)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        weight = torch_to_jax_param(
            layer.weight,
            NamedSharding(self.jax_config.mesh,
                          self.jax_config.weight_sharding),
            self.jax_config.output_sizes,
            self.jax_config.n_shards,
            self.jax_config.fuse_matmuls,
        )
        delattr(layer, "weight")
        layer.weight = weight

        weight_scale = layer.weight_scale
        is_fused_module = len(layer.logical_widths) > 1
        if is_fused_module and not self.is_channelwise:
            weight_scale = convert_to_channelwise(weight_scale,
                                                  layer.logical_widths)
        weight_scale = weight_scale.squeeze(-1)

        weight_scale = torch_to_jax_param(
            weight_scale,
            NamedSharding(self.jax_config.mesh, self.jax_config.bias_sharding),
            self.jax_config.output_sizes,
            self.jax_config.n_shards,
            self.jax_config.fuse_matmuls,
        )
        delattr(layer, "weight_scale")
        layer.weight_scale = weight_scale

        if layer.bias is not None and not layer.skip_bias_add:
            if layer.return_bias:
                logger.warning_once("Bias might return incorrect value.")

            bias = torch_to_jax_param(
                layer.bias,
                NamedSharding(self.jax_config.mesh,
                              self.jax_config.bias_sharding),
                self.jax_config.output_sizes,
                self.jax_config.n_shards,
                self.jax_config.fuse_matmuls,
            )
            delattr(layer, "bias")
            layer.bias = bias

        # TODO(kyuyeunk): Support static range input quantization.
        assert getattr(layer, "input_scale", None) is None
        assert getattr(layer, "input_zero_point", None) is None
        assert getattr(layer, "azp_adj", None) is None

    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
        with jax.named_scope(layer._get_name()):
            if self.jax_config.fuse_matmuls:
                out = self._apply_fused(layer, x, bias)
            else:
                out = self._apply_split(layer, x, bias)

        return out

    def _apply_fused(self, layer: torch.nn.Module, x: torch.Tensor,
                     bias: Optional[torch.Tensor]) -> torch.Tensor:
        x_jax = jax_view(x)
        weight_jax = jax_view(layer.weight)
        weight_scale_jax = jax_view(layer.weight_scale)

        outs = sharded_quantized_matmul(
            x_jax,
            weight_jax,
            weight_scale_jax,
            self.jax_config.mesh,
            self.jax_config.weight_sharding,
        )
        if bias is not None and not layer.skip_bias_add:
            outs += jax_view(bias)

        outs = slice_sharded_tensor_for_concatenation(
            outs, self.jax_config.output_sizes, self.jax_config.n_shards)
        out = jnp.concatenate(outs, axis=-1)
        return torch_view(out)

    def _apply_split(self, layer: torch.nn.Module, x: torch.Tensor,
                     bias: Optional[torch.Tensor]) -> torch.Tensor:
        assert isinstance(layer.weight, torch.nn.ParameterList)

        x_jax = jax_view(x)
        outs = []
        for i, (weight, weight_scale) in enumerate(
                zip(layer.weight, layer.weight_scale)):
            weight_jax = jax_view(weight)
            weight_scale_jax = jax_view(weight_scale)

            out = sharded_quantized_matmul(
                x_jax,
                weight_jax,
                weight_scale_jax,
                self.jax_config.mesh,
                self.jax_config.weight_sharding,
            )
            if bias is not None and not layer.skip_bias_add:
                out += jax_view(bias[i])

            outs.append(out)
        out = jnp.concatenate(outs, axis=-1)
        return torch_view(out)
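The scheme above keeps an int8 weight plus a per-output-channel `weight_scale`, and `_apply_fused`/`_apply_split` rescale the matmul result with that scale. A minimal standalone sketch of the channelwise int8 math this relies on; the helpers below are hypothetical and stand in for the package's `sharded_quantized_matmul` kernel, which additionally handles sharding across the mesh:

```python
# Minimal sketch, not package code: per-output-channel symmetric int8
# quantization and the dequantizing matmul that a W8A8 int8 scheme relies on.
import jax
import jax.numpy as jnp


def quantize_channelwise(w: jax.Array):
    """w: (rows, cols) -> (int8 values, one float scale per row)."""
    scale = jnp.max(jnp.abs(w), axis=1) / 127.0
    w_q = jnp.clip(jnp.round(w / scale[:, None]), -128, 127).astype(jnp.int8)
    return w_q, scale


def channelwise_int8_matmul(x_q: jax.Array, x_scale: jax.Array,
                            w_q: jax.Array, w_scale: jax.Array) -> jax.Array:
    """Accumulate in int32, then apply per-token and per-channel scales."""
    acc = jnp.matmul(x_q.astype(jnp.int32), w_q.astype(jnp.int32).T)
    return acc.astype(jnp.float32) * x_scale[:, None] * w_scale[None, :]


if __name__ == "__main__":
    w = jax.random.normal(jax.random.PRNGKey(0), (8, 16))
    x = jax.random.normal(jax.random.PRNGKey(1), (4, 16))
    w_q, w_scale = quantize_channelwise(w)
    x_q, x_scale = quantize_channelwise(x)  # dynamic per-token activation quant
    out = channelwise_int8_matmul(x_q, x_scale, w_q, w_scale)
    print(jnp.max(jnp.abs(out - x @ w.T)))  # small quantization error
```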
tpu_inference/layers/vllm/quantization/fp8.py
ADDED
@@ -0,0 +1,118 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Union

import jax
import torch
from jax.sharding import PartitionSpec
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import \
    register_quantization_config
from vllm.model_executor.layers.quantization.base_config import \
    QuantizeMethodBase
from vllm.model_executor.layers.quantization.fp8 import (Fp8Config,
                                                         Fp8LinearMethod)
from vllm.model_executor.layers.quantization.utils.quant_utils import \
    is_layer_skipped

from tpu_inference.layers.common.quant_methods import FP8, get_tpu_quant_method
from tpu_inference.layers.vllm.quantization.common import (
    JaxCommonConfig, JaxCommonLinearConfig)
from tpu_inference.layers.vllm.quantization.unquantized import \
    VllmUnquantizedLinearMethod

P = PartitionSpec
logger = init_logger(__name__)


@register_quantization_config(get_tpu_quant_method(FP8))
class VllmFp8Config(Fp8Config, JaxCommonConfig):

    @classmethod
    def get_name(cls):
        return FP8

    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        return [torch.bfloat16]

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional[Union["LinearMethodBase", "QuantizeMethodBase"]]:
        if isinstance(layer, LinearBase):
            linear_config = self.get_linear_config(layer)
            if is_layer_skipped(prefix, self.ignored_layers):
                return VllmUnquantizedLinearMethod(linear_config)
            return VllmFp8LinearMethod(self, linear_config)
        elif isinstance(layer, FusedMoE):
            raise NotImplementedError(
                "FP8 FusedMoE is currently not supported in torchax-jax")
        return None


class VllmFp8LinearMethod(Fp8LinearMethod):

    def __init__(self, quant_config: VllmFp8Config,
                 jax_config: JaxCommonLinearConfig):
        super().__init__(quant_config)
        self.jax_config = jax_config
        self._configure_sharding()

    def _configure_sharding(self) -> None:

        raise NotImplementedError(
            "Configure PartitionSpec for weight_sharding and scale_sharding "
            "based on layer type (RowParallel/ColumnParallel)")

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:

        raise NotImplementedError(
            "Convert layer.weight, layer.weight_scale, and optionally "
            "layer.input_scale and layer.bias from torch tensors to JAX arrays "
            "using torch_to_jax_param() with appropriate sharding")

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:

        with jax.named_scope(layer._get_name()):
            if self.jax_config.fuse_matmuls:
                out = self._apply_fused(layer, x, bias)
            else:
                out = self._apply_split(layer, x, bias)

        return out

    def _apply_fused(self,
                     layer: torch.nn.Module,
                     x: torch.Tensor,
                     bias: Optional[torch.Tensor] = None) -> torch.Tensor:

        raise NotImplementedError(
            "Implement single matmul for fused outputs: "
            "quantize input to fp8, perform fp8 matmul with weight and scales, "
            "dequantize output, and add bias if present")

    def _apply_split(self,
                     layer: torch.nn.Module,
                     x: torch.Tensor,
                     bias: Optional[torch.Tensor] = None) -> torch.Tensor:

        raise NotImplementedError(
            "Implement separate matmuls per output partition: "
            "split weight/scale by output_sizes, perform fp8 matmul for each, "
            "concatenate results, and add bias if present")
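The `VllmFp8LinearMethod` stubs above spell out the intended flow in their error messages: quantize the activation to fp8, run the matmul against the fp8 weight and its scales, dequantize, and add the bias. A minimal sketch of that flow, assuming simple per-tensor scales; the function name is illustrative and the float32 upcast before the matmul is only for portability (a real kernel would keep the operands in fp8):

```python
# Minimal sketch of the fp8 linear flow described in the stubs above,
# assuming per-tensor scales. Not the package's implementation.
from typing import Optional

import jax
import jax.numpy as jnp

FP8_MAX = float(jnp.finfo(jnp.float8_e4m3fn).max)


def fp8_linear(x: jax.Array, w_fp8: jax.Array, w_scale: jax.Array,
               bias: Optional[jax.Array] = None) -> jax.Array:
    """x: (tokens, in), w_fp8: (out, in) float8_e4m3fn, w_scale: scalar."""
    # Dynamically quantize the activation to fp8 with a per-tensor scale.
    x_scale = jnp.max(jnp.abs(x)) / FP8_MAX
    x_fp8 = (x / x_scale).astype(jnp.float8_e4m3fn)

    # Matmul (upcast here for portability), then dequantize with both scales.
    acc = jnp.matmul(x_fp8.astype(jnp.float32), w_fp8.astype(jnp.float32).T)
    out = acc * (x_scale * w_scale)

    if bias is not None:
        out = out + bias
    return out.astype(x.dtype)


if __name__ == "__main__":
    w = jax.random.normal(jax.random.PRNGKey(0), (8, 16))
    w_scale = jnp.max(jnp.abs(w)) / FP8_MAX
    w_fp8 = (w / w_scale).astype(jnp.float8_e4m3fn)
    x = jax.random.normal(jax.random.PRNGKey(1), (4, 16))
    print(jnp.max(jnp.abs(fp8_linear(x, w_fp8, w_scale) - x @ w.T)))
```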
tpu_inference/layers/vllm/quantization/mxfp4.py
ADDED
@@ -0,0 +1,396 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Union

import jax
import jax.numpy as jnp
import torch
from jax.experimental.layout import Format, Layout
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from torch.nn.parameter import Parameter
from torchax.interop import jax_view, torch_view
from torchax.ops.mappings import t2j
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig, FusedMoEQuantConfig, mxfp4_w4a16_moe_quant_config)
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
                                                        FusedMoEMethodBase)
from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.quantization import \
    register_quantization_config
from vllm.model_executor.layers.quantization.base_config import \
    QuantizeMethodBase
from vllm.model_executor.layers.quantization.mxfp4 import (Mxfp4Backend,
                                                           Mxfp4Config,
                                                           Mxfp4MoEMethod)
from vllm.model_executor.layers.quantization.utils.quant_utils import \
    is_layer_skipped

from tpu_inference import envs
from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe
from tpu_inference.layers.common.quant_methods import (MXFP4,
                                                       get_tpu_quant_method)
from tpu_inference.layers.common.quantization import (
    dequantize_tensor_from_mxfp4_packed, quantize_tensor)
from tpu_inference.layers.vllm.fused_moe import fused_moe_func
from tpu_inference.layers.vllm.linear_common import \
    reorder_concatenated_tensor_for_sharding
from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
from tpu_inference.layers.vllm.quantization.unquantized import \
    VllmUnquantizedLinearMethod

REQUANTIZED_BLOCK_SIZE = 512

P = PartitionSpec

logger = init_logger(__name__)


@register_quantization_config(get_tpu_quant_method(MXFP4))
class VllmMxfp4Config(Mxfp4Config, JaxCommonConfig):

    @classmethod
    def get_name(cls):
        return MXFP4

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        from vllm.attention.layer import Attention  # Avoid circular import

        if isinstance(layer, LinearBase):
            linear_config = self.get_linear_config(layer)
            if self.ignored_layers and is_layer_skipped(
                    prefix=prefix,
                    ignored_layers=self.ignored_layers,
                    fused_mapping=self.packed_modules_mapping,
            ):
                return VllmUnquantizedLinearMethod(linear_config)
            logger.warning_once(
                "MXFP4 linear layer is not implemented - falling back to "
                "UnquantizedLinearMethod.")
            return VllmUnquantizedLinearMethod(linear_config)
        elif isinstance(layer, FusedMoE):
            moe_config = self.get_moe_config(layer)
            return VllmMxfp4MoEMethod(moe_config, self.mesh)
        elif isinstance(layer, Attention):
            logger.warning_once("MXFP4 attention layer is not implemented. "
                                "Skipping quantization for this layer.")
        return None


class VllmMxfp4MoEMethod(Mxfp4MoEMethod):

    def __init__(self,
                 moe: FusedMoEConfig,
                 mesh: Mesh,
                 ep_axis_name: str = 'model'):
        FusedMoEMethodBase.__init__(self, moe)

        # We piggyback on the triton implementation as it applies minimal
        # hardware-specific post-processing to the weights.
        self.mxfp4_backend = Mxfp4Backend.TRITON

        self.mesh = mesh
        self.use_kernel = envs.USE_MOE_EP_KERNEL and moe.use_ep
        self.ep_axis_name = ep_axis_name
        # TODO: Use autotune table once we have it.
        self.block_size = {
            "bt": 256,
            "bf": 1024,
            "bd1": 1024,
            "bd2": 1024,
            "btc": 256,
            "bfc": 1024,
            "bd1c": 1024,
            "bd2c": 1024,
        }

    def get_fused_moe_quant_config(
            self, layer: torch.nn.Module) -> FusedMoEQuantConfig | None:
        return mxfp4_w4a16_moe_quant_config(
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            w1_bias=layer.w13_bias,
            w2_bias=layer.w2_bias,
        )

    def process_weights_after_loading(self, layer: torch.nn.Module):
        assert isinstance(layer, FusedMoE)
        assert layer.moe_config.has_bias, "mxfp4 quantization always uses bias."

        w13_weight = t2j(layer.w13_weight, use_dlpack=False)
        w13_weight_scale = t2j(layer.w13_weight_scale, use_dlpack=False)
        w13_bias = t2j(layer.w13_bias, use_dlpack=False)

        w2_weight = t2j(layer.w2_weight, use_dlpack=False)
        w2_weight_scale = t2j(layer.w2_weight_scale, use_dlpack=False)
        w2_bias = t2j(layer.w2_bias, use_dlpack=False)

        # Wrap functions in jit to speed up requantization.
        @jax.jit
        def wrapper(w13_weight, w13_weight_scale, w13_bias, w2_weight,
                    w2_weight_scale, w2_bias):
            # Dequantize fp4 weights into fp32.
            w13_weight = dequantize_tensor_from_mxfp4_packed(
                w13_weight, w13_weight_scale, 2)
            w2_weight = dequantize_tensor_from_mxfp4_packed(
                w2_weight, w2_weight_scale, 2)

            num_experts, orig_hidden_size, orig_intermediate_size = w2_weight.shape

            # Requantize the weights into a TPU-friendly block size.
            w13_weight, w13_weight_scale = quantize_tensor(
                jnp.float4_e2m1fn, w13_weight, 2, REQUANTIZED_BLOCK_SIZE, True)
            w2_weight, w2_weight_scale = quantize_tensor(
                jnp.float4_e2m1fn, w2_weight, 2, REQUANTIZED_BLOCK_SIZE, True)

            intermediate_size = w2_weight.shape[-1]
            hidden_size = w13_weight.shape[-1]

            # Dims may have been padded to align with the subchannel size
            # during quantization. We pad the corresponding dim on the other
            # weight.
            # NOTE: We perform padding after quantization as the padding value
            # can affect quantization numerics.
            intermediate_padding_size = 2 * (intermediate_size -
                                             orig_intermediate_size)
            w13_weight = jnp.pad(w13_weight,
                                 ((0, 0), (0, intermediate_padding_size),
                                  (0, 0)))
            w13_weight_scale = jnp.pad(w13_weight_scale,
                                       ((0, 0), (0, intermediate_padding_size),
                                        (0, 0)))
            w13_bias = jnp.pad(w13_bias,
                               ((0, 0), (0, intermediate_padding_size)))

            hidden_padding_size = hidden_size - orig_hidden_size
            w2_weight = jnp.pad(w2_weight,
                                ((0, 0), (0, hidden_padding_size), (0, 0)))
            w2_weight_scale = jnp.pad(w2_weight_scale,
                                      ((0, 0), (0, hidden_padding_size),
                                       (0, 0)))
            w2_bias = jnp.pad(w2_bias, ((0, 0), (0, hidden_padding_size)))

            if layer.activation == "swigluoai":
                # When using swigluoai, vLLM splits the gmm output in an
                # interleaved way. However, an interleaved split is not
                # performant on TPU. Therefore, we preprocess the weight so
                # that splitting the gmm output down the middle still gives
                # the same result.
                w1_weight = w13_weight[:, ::2, :]
                w3_weight = w13_weight[:, 1::2, :]
                w13_weight = jnp.concat([w1_weight, w3_weight], axis=1)

                w1_weight_scale = w13_weight_scale[:, ::2, :]
                w3_weight_scale = w13_weight_scale[:, 1::2, :]
                w13_weight_scale = jnp.concat(
                    [w1_weight_scale, w3_weight_scale], axis=1)

                w1_bias = w13_bias[:, ::2]
                w3_bias = w13_bias[:, 1::2]
                w13_bias = jnp.concat([w1_bias, w3_bias], axis=1)

            if self.use_kernel:
                # Kernel expects:
                #   w13: (num_experts, 2, hidden_size, intermediate_size)
                #   w2: (num_experts, intermediate_size, hidden_size)
                # Current format:
                #   w13_weight: (num_experts, 2*intermediate_size, hidden_size)
                #   w2_weight: (num_experts, hidden_size, intermediate_size)

                w13_weight = w13_weight.reshape(num_experts, 2,
                                                intermediate_size, hidden_size)

                w13_weight_scale = w13_weight_scale.reshape(
                    num_experts, 2, intermediate_size, 1, -1)
                w2_weight_scale = w2_weight_scale.reshape(
                    num_experts, hidden_size, 1, -1)

                w13_bias = w13_bias.astype(jnp.float32).reshape(
                    num_experts, 2, 1, intermediate_size)
                w2_bias = w2_bias.astype(jnp.float32).reshape(
                    num_experts, 1, hidden_size)

                # Transpose the non-contracting dim to the rightmost dim.
                w13_weight = jnp.swapaxes(w13_weight, 2, 3)
                w2_weight = jnp.swapaxes(w2_weight, 1, 2)

                w13_weight_scale = jnp.swapaxes(w13_weight_scale, 2, 4)
                w2_weight_scale = jnp.swapaxes(w2_weight_scale, 1, 3)

                # Apply EP sharding.
                ep_sharding = NamedSharding(self.mesh, P("model"))

                w13_weight = jax.lax.with_sharding_constraint(
                    w13_weight, Format(Layout((0, 1, 2, 3)), ep_sharding))
                w2_weight = jax.lax.with_sharding_constraint(
                    w2_weight, Format(Layout((0, 1, 2)), ep_sharding))

                w13_weight_scale = jax.lax.with_sharding_constraint(
                    w13_weight_scale,
                    Format(Layout((0, 1, 2, 3, 4)), ep_sharding))
                w2_weight_scale = jax.lax.with_sharding_constraint(
                    w2_weight_scale, Format(Layout((0, 1, 2, 3)), ep_sharding))

                w13_bias = jax.lax.with_sharding_constraint(
                    w13_bias, Format(Layout((0, 1, 2, 3)), ep_sharding))
                w2_bias = jax.lax.with_sharding_constraint(
                    w2_bias, Format(Layout((0, 1, 2)), ep_sharding))
            else:
                w13_weight_scale = jnp.swapaxes(w13_weight_scale, 1, 2)
                w13_weight_scale = jnp.expand_dims(w13_weight_scale, 2)
                w2_weight_scale = jnp.swapaxes(w2_weight_scale, 1, 2)
                w2_weight_scale = jnp.expand_dims(w2_weight_scale, 2)

                w13_bias = jnp.expand_dims(w13_bias, 1)
                w2_bias = jnp.expand_dims(w2_bias, 1)

                if layer.use_ep:
                    ep_sharding = NamedSharding(self.mesh, P("model"))

                    w13_weight = jax.lax.with_sharding_constraint(
                        w13_weight, ep_sharding)
                    w2_weight = jax.lax.with_sharding_constraint(
                        w2_weight, ep_sharding)

                    w13_weight_scale = jax.lax.with_sharding_constraint(
                        w13_weight_scale, ep_sharding)
                    w2_weight_scale = jax.lax.with_sharding_constraint(
                        w2_weight_scale, ep_sharding)

                    w13_bias = jax.lax.with_sharding_constraint(
                        w13_bias, ep_sharding)
                    w2_bias = jax.lax.with_sharding_constraint(
                        w2_bias, ep_sharding)

                else:
                    output_sizes = [intermediate_size, intermediate_size]
                    n_shards = self.mesh.shape["model"]
                    assert intermediate_size % n_shards == 0

                    # Reorder w13 weights so that splitting the w1 and w3
                    # output can happen locally without any collective
                    # operations.
                    w13_weight = reorder_concatenated_tensor_for_sharding(
                        w13_weight,
                        output_sizes,
                        n_shards,
                        dim=1,
                    )
                    w13_weight_scale = reorder_concatenated_tensor_for_sharding(
                        w13_weight_scale,
                        output_sizes,
                        n_shards,
                        dim=3,
                    )
                    w13_bias = reorder_concatenated_tensor_for_sharding(
                        w13_bias,
                        output_sizes,
                        n_shards,
                        dim=2,
                    )

                    w13_weight = jax.lax.with_sharding_constraint(
                        w13_weight,
                        NamedSharding(self.mesh, P(None, "model", None)))
                    w2_weight = jax.lax.with_sharding_constraint(
                        w2_weight,
                        NamedSharding(self.mesh, P(None, None, "model")))
                    w13_weight_scale = jax.lax.with_sharding_constraint(
                        w13_weight_scale,
                        NamedSharding(self.mesh, P(None, None, None, "model")))
                    w2_weight_scale = jax.lax.with_sharding_constraint(
                        w2_weight_scale,
                        NamedSharding(self.mesh, P(None, "model", None, None)))
                    w13_bias = jax.lax.with_sharding_constraint(
                        w13_bias,
                        NamedSharding(self.mesh, P(None, None, "model")))
                    w2_bias = jax.lax.with_sharding_constraint(
                        w2_bias, NamedSharding(self.mesh, P(None, None, None)))

            return w13_weight, w13_weight_scale, w13_bias, w2_weight, w2_weight_scale, w2_bias

        w13_weight, w13_weight_scale, w13_bias, w2_weight, w2_weight_scale, w2_bias = wrapper(
            w13_weight, w13_weight_scale, w13_bias, w2_weight, w2_weight_scale,
            w2_bias)

        layer.w13_weight = Parameter(torch_view(w13_weight),
                                     requires_grad=False)
        layer.w2_weight = Parameter(torch_view(w2_weight), requires_grad=False)

        layer.w13_weight_scale = Parameter(torch_view(w13_weight_scale),
                                           requires_grad=False)
        layer.w2_weight_scale = Parameter(torch_view(w2_weight_scale),
                                          requires_grad=False)

        layer.w13_bias = Parameter(torch_view(w13_bias), requires_grad=False)
        layer.w2_bias = Parameter(torch_view(w2_bias), requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        assert isinstance(layer, FusedMoE)
        if layer.scoring_func != "softmax":
            raise NotImplementedError(
                "Only softmax is supported for scoring_func")

        x = jax_view(x)
        w13_weight = jax_view(layer.w13_weight)
        w2_weight = jax_view(layer.w2_weight)
        w13_weight_scale = jax_view(layer.w13_weight_scale)
        w2_weight_scale = jax_view(layer.w2_weight_scale)
        w13_bias = jax_view(layer.w13_bias)
        w2_bias = jax_view(layer.w2_bias)
        gating_output = jax_view(router_logits)

        if self.use_kernel:
            actual_hidden_size = x.shape[-1]
            padding_size = w13_weight.shape[-2] - actual_hidden_size
            x = jnp.pad(x, ((0, 0), (0, padding_size)))
            output = fused_ep_moe(
                mesh=self.mesh,
                tokens=x,
                w1=w13_weight,
                w2=w2_weight,
                w1_scale=w13_weight_scale,
                w2_scale=w2_weight_scale,
                b1=w13_bias,
                b2=w2_bias,
                gating_output=gating_output,
                subc_quant_wsz=REQUANTIZED_BLOCK_SIZE,
                top_k=layer.top_k,
                ep_axis_name=self.ep_axis_name,
                renormalize_topk_logits=layer.renormalize,
                act_fn=layer.activation,
                **self.block_size,
            )[:, :actual_hidden_size]
        else:
            output = fused_moe_func(
                hidden_states=x,
                w1=w13_weight,
                w2=w2_weight,
                w1_scale=w13_weight_scale,
                w2_scale=w2_weight_scale,
                w1_bias=w13_bias,
                w2_bias=w2_bias,
                gating_output=gating_output,
                topk=layer.top_k,
                renormalize=layer.renormalize,
                mesh=self.mesh,
                use_ep=layer.use_ep,
                activation=layer.activation,
            )

        return torch_view(output)
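The swigluoai branch in `process_weights_after_loading` above de-interleaves the fused w1/w3 rows so the gmm output can be split down the middle instead of with a strided (interleaved) gather. A small standalone check of that equivalence; the shapes and names below are illustrative, not taken from the package:

```python
# Standalone check of the swigluoai weight reordering used above: after moving
# the even rows (w1) and odd rows (w3) into contiguous halves, splitting the
# matmul output down the middle matches the interleaved split of the original.
import jax
import jax.numpy as jnp

num_experts, intermediate, hidden, tokens = 2, 8, 16, 4
w13 = jax.random.normal(jax.random.PRNGKey(0), (num_experts, 2 * intermediate, hidden))
x = jax.random.normal(jax.random.PRNGKey(1), (tokens, hidden))

# Reference: interleaved split of the original per-expert matmul output.
out_ref = jnp.einsum("td,eid->eti", x, w13)
w1_out_ref, w3_out_ref = out_ref[..., ::2], out_ref[..., 1::2]

# Reordered weights: even rows first, odd rows second (as in the code above).
w13_reordered = jnp.concatenate([w13[:, ::2, :], w13[:, 1::2, :]], axis=1)
out = jnp.einsum("td,eid->eti", x, w13_reordered)
w1_out, w3_out = jnp.split(out, 2, axis=-1)

assert jnp.allclose(w1_out, w1_out_ref)
assert jnp.allclose(w3_out, w3_out_ref)
print("middle split of reordered weights matches interleaved split")
```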