tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +317 -34
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +406 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +320 -0
- tests/layers/vllm/test_unquantized.py +662 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +26 -6
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +110 -12
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +2 -45
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +15 -10
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +92 -8
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +25 -4
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +807 -230
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +218 -137
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +25 -12
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/fused_moe_gmm.py +506 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +282 -0
- tpu_inference/layers/common/sharding.py +32 -9
- tpu_inference/layers/common/utils.py +94 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +101 -494
- tpu_inference/layers/vllm/linear.py +64 -0
- tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
- tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
- tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
- tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
- tpu_inference/layers/vllm/quantization/__init__.py +19 -3
- tpu_inference/layers/vllm/quantization/awq.py +96 -82
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +23 -8
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +172 -176
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
- tpu_inference/layers/vllm/quantization/{common.py → configs.py} +42 -25
- tpu_inference/layers/vllm/quantization/fp8.py +119 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +137 -178
- tpu_inference/layers/vllm/quantization/unquantized.py +157 -233
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +112 -35
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +18 -5
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +179 -51
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +92 -32
- tpu_inference/models/jax/utils/weight_utils.py +234 -155
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +32 -8
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +51 -72
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +180 -80
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +55 -33
- tpu_inference/runner/lora_utils.py +16 -1
- tpu_inference/runner/multimodal_manager.py +16 -2
- tpu_inference/runner/persistent_batch_manager.py +54 -2
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +16 -3
- tpu_inference/runner/tpu_runner.py +124 -61
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +84 -22
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +72 -44
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +66 -52
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +8 -9
- tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
- tpu_inference/layers/vllm/linear_common.py +0 -186
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202511220812.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
@@ -1,203 +1,199 @@
-
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 
 import jax
-import jax.numpy as jnp
 import torch
-
-from jax.
-from jax.sharding import Mesh, NamedSharding
-from jax.sharding import PartitionSpec as P
+from compressed_tensors.quantization import QuantizationArgs
+from jax.sharding import Mesh
 from torch.nn.parameter import Parameter
-from torchax.interop import
+from torchax.interop import torch_view
 from torchax.ops.mappings import t2j
-from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEConfig
-from vllm.model_executor.layers.quantization.compressed_tensors.
-
-
-
-from
-
-
-from tpu_inference.layers.vllm.
+from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( # noqa: E501
+    CompressedTensorsMoEMethod, CompressedTensorsW8A8Fp8MoEMethod)
+
+from tpu_inference.layers.common.sharding import ShardingAxisName
+from tpu_inference.layers.vllm.fused_moe import (FusedMoEBackend,
+                                                 fused_moe_apply,
+                                                 select_moe_backend)
+from tpu_inference.layers.vllm.process_weights.fused_moe_weights import (
+    FusedMoEWeights, process_moe_weights, shard_moe_weights)
+from tpu_inference.layers.vllm.quantization.configs import VllmQuantConfig
+from tpu_inference.layers.vllm.quantization.unquantized import \
+    VllmUnquantizedFusedMoEMethod
+from tpu_inference.logger import init_logger
+from tpu_inference.utils import get_mesh_shape_product
 
 logger = init_logger(__name__)
 
 
+class VllmCompressedTensorsMoEMethod(CompressedTensorsMoEMethod):
+
+    @staticmethod
+    def get_moe_method(
+        quant_config: "VllmCompressedTensorsConfig", # type: ignore # noqa E501
+        layer: torch.nn.Module,
+        layer_name: str,
+    ) -> CompressedTensorsMoEMethod:
+        assert isinstance(layer, FusedMoE)
+
+        # FusedMoE was made by combining multiple Linears so need to
+        # make sure quantization config for Linear can target it
+        quant_config._add_fused_moe_to_target_scheme_map()
+        unfused_names = [
+            layer_name + proj_name
+            for proj_name in [".0.gate_proj", ".0.up_proj", ".0.down_proj"]
+        ]
+        # TODO: refactor this to use expert_mapping and check all layer numbers
+        all_scheme_dicts = [
+            quant_config.get_scheme_dict(layer, name) for name in unfused_names
+        ]
+        scheme_dict = all_scheme_dicts.pop()
+
+        # multiple schemes found
+        if not all([cur_dict == scheme_dict for cur_dict in all_scheme_dicts]):
+            raise ValueError("All MoE projections need to have same "
+                             "quantization scheme but found multiple")
+
+        if scheme_dict is None:
+            return VllmUnquantizedFusedMoEMethod(layer.moe_config,
+                                                 quant_config.mesh)
+
+        weight_quant = scheme_dict.get("weights")
+        input_quant = scheme_dict.get("input_activations")
+
+        if quant_config._is_fp8_w8a8(weight_quant, input_quant):
+            return VllmCompressedTensorsW8A8Fp8MoEMethod(
+                weight_quant, input_quant, layer.moe_config, quant_config.mesh)
+        else:
+            raise RuntimeError(
+                f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}")
+
+
 class VllmCompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsW8A8Fp8MoEMethod,
-
+                                            VllmQuantConfig):
+
+    def __init__(
+        self,
+        weight_quant: QuantizationArgs,
+        input_quant: QuantizationArgs,
+        moe: FusedMoEConfig,
+        mesh: Mesh,
+    ):
+        super().__init__(weight_quant, input_quant, moe)
 
-    def __init__(self, quant_config: "CompressedTensorsConfig",
-                 moe: FusedMoEConfig, mesh: Mesh):
-        super().__init__(quant_config, moe)
         self.mesh = mesh
-        self.
+        self.moe_backend = select_moe_backend(self.moe)
 
-
-        self.
-
-
-
-        self.disable_expert_map = False
+        self.extra_backend_kwargs = {}
+        if self.moe_backend == FusedMoEBackend.FUSED_MOE:
+            raise NotImplementedError(
+                "Per-channel quantization is not supported in FusedMoE kernel."
+            )
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        """
+        Docstring for process_weights_after_loading
+
+        :param self: Description
+        :param layer: Description
+        :type layer: torch.nn.Module
+
+        Steps:
+        1. Read weights from layer object and convert to jax arrays
+        2. Interleave concat w13 weights
+        3. Shard weights for tp (rowwise w13, colwise w2)
+        4. Initialize Params as torch.nn.Parameter
+            a. w13_weight - float8_e4m3fn shape: (num_experts, 2 x intermediate_size, input_size)
+            b. w2_weight - float8_e4m3fn shape: (num_experts, output_size, intermediate_size)
+            c. w13_weight_scale - FP32 shape: (num_experts, 2 x intermediate_size, 1)
+            d. w2_weight_scale - FP32shape: (num_experts, output_size, 1)
+        """
         assert isinstance(layer, FusedMoE)
 
-
-
-        w3_weight = layer.w13_weight[:, intermediate_size:]
-        w1_weight_scale = layer.w13_weight_scale[:, :intermediate_size]
-        w3_weight_scale = layer.w13_weight_scale[:, intermediate_size:]
-
+        w13_weight = t2j(layer.w13_weight, use_dlpack=False)
+        w13_weight_scale = t2j(layer.w13_weight_scale, use_dlpack=False)
         w2_weight = t2j(layer.w2_weight, use_dlpack=False)
-        w2_weight_scale = t2j(layer.w2_weight_scale
-
-
-
-
-        w3_weight = t2j(w3_weight, use_dlpack=False)
-        w3_weight_scale = t2j(w3_weight_scale.to(torch.bfloat16),
-                              use_dlpack=False)
-
-        if layer.use_ep:
-            format = Format(Layout((0, 1, 2)),
-                            NamedSharding(self.mesh, P("model", None, None)))
-            w1_weight = jax.device_put(w1_weight, format)
-            w1_weight_scale = jax.device_put(w1_weight_scale, format)
-            w3_weight = jax.device_put(w3_weight, format)
-            w3_weight_scale = jax.device_put(w3_weight_scale, format)
-            w2_weight = jax.device_put(w2_weight, format)
-            w2_weight_scale = jax.device_put(w2_weight_scale, format)
+        w2_weight_scale = t2j(layer.w2_weight_scale, use_dlpack=False)
+
+        if self.moe.has_bias:
+            w13_bias = t2j(layer.w13_bias, use_dlpack=False)
+            w2_bias = t2j(layer.w2_bias, use_dlpack=False)
         else:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            w13_bias = w2_bias = None
+
+        @jax.jit
+        def process_fp8_moe_weights(
+            w13_weight: jax.Array,
+            w13_weight_scale: jax.Array,
+            w13_bias: jax.Array | None,
+            w2_weight: jax.Array,
+            w2_weight_scale: jax.Array,
+            w2_bias: jax.Array | None,
+        ) -> FusedMoEWeights:
+            w13_interleave = layer.activation == "swigluoai"
+            w13_reorder_size = get_mesh_shape_product(
+                self.mesh, ShardingAxisName.MLP_TENSOR)
+
+            return process_moe_weights(
+                weights=FusedMoEWeights(
+                    w13_weight=w13_weight,
+                    w13_weight_scale=w13_weight_scale,
+                    w13_bias=w13_bias,
+                    w2_weight=w2_weight,
+                    w2_weight_scale=w2_weight_scale,
+                    w2_bias=w2_bias,
+                ),
+                moe_backend=self.moe_backend,
+                w13_reorder_size=w13_reorder_size,
+                w13_interleave=w13_interleave,
             )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        layer.
-
-
-
-
+
+        weights = process_fp8_moe_weights(
+            w13_weight,
+            w13_weight_scale,
+            w13_bias,
+            w2_weight,
+            w2_weight_scale,
+            w2_bias,
+        )
+        weights = torch_view(
+            shard_moe_weights(weights, self.moe_backend, self.mesh))
+
+        layer.w13_weight = Parameter(weights.w13_weight, requires_grad=False)
+        layer.w2_weight = Parameter(weights.w2_weight, requires_grad=False)
+
+        layer.w13_weight_scale = Parameter(weights.w13_weight_scale,
+                                           requires_grad=False)
+        layer.w2_weight_scale = Parameter(weights.w2_weight_scale,
+                                          requires_grad=False)
+
+        if self.moe.has_bias:
+            layer.w13_bias = Parameter(weights.w13_bias, requires_grad=False)
+            layer.w2_bias = Parameter(weights.w2_bias, requires_grad=False)
 
     def apply(
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
         router_logits: torch.Tensor,
-
-
-
-
-
-
-
-
-
-        routed_scaling_factor: float = 1.0,
-        e_score_correction_bias: Optional[torch.Tensor] = None,
-        apply_router_weight_on_input: bool = False,
-        activation: str = "silu",
-        enable_eplb: bool = False,
-        expert_load_view: Optional[torch.Tensor] = None,
-        logical_to_physical_map: Optional[torch.Tensor] = None,
-        logical_replica_count: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
-        assert isinstance(layer, FusedMoE)
-        if activation != "silu":
-            raise NotImplementedError(
-                "Only silu is supported for activation function.")
-        if scoring_func != "softmax":
-            raise NotImplementedError(
-                "Only softmax is supported for scoring_func")
-
-        # import sys
-        # sys.stdin = open(0)
-        # breakpoint()
-
-        # TODO: Use MoE kernel when it supports fp8
-
-        seqlen = x.shape[0]
-
-        expert_weights = F.softmax(router_logits, dim=-1)
-        expert_weights, expert_indices = torch.topk(expert_weights,
-                                                    top_k,
-                                                    dim=-1)
-        if renormalize:
-            expert_weights /= expert_weights.sum(dim=-1, keepdim=True)
-
-        # cond ffn
-        # e = total num of exp = 160
-        # t = seqlen
-        # o = config.imtermediate size
-        # i = config.dim
-        #torch.einsum("ti, eoi -> teo", x, layer.w13_weight) * self.w13_weight_scale)
-        ux1 = call_jax(jax.lax.dot,
-                       x,
-                       layer.w13_weight,
-                       dimension_numbers=(((1, ), (2, )), ((), ())),
-                       preferred_element_type=jnp.bfloat16.dtype)
-        x1 = F.silu(ux1 * layer.w13_weight_scale.squeeze(2))
-
-        #x3 = torch.einsum("ti, eoi -> teo", x, layer.w3_weight) * self.w3_weight_scale
-        x3 = call_jax(jax.lax.dot,
-                      x,
-                      layer.w3_weight,
-                      dimension_numbers=(((1, ), (2, )), ((), ())),
-                      preferred_element_type=jnp.bfloat16.dtype
-                      ) * layer.w3_weight_scale.squeeze(2)
-
-        #expert_outs = torch.einsum("teo, eio -> tei", (x1 * x3), self.w2_weight) * self.w2_weight_scale
-        expert_outs = call_jax(
-            jax.lax.dot,
-            x1 * x3,
-            layer.w2_weight,
-            dimension_numbers=(((2, ), (2, )), ((1, ), (0, ))),
-            preferred_element_type=jnp.bfloat16.dtype).transpose(
-                0, 1) * layer.w2_weight_scale.squeeze(2)
-
-        seq_indexes = torch.arange(seqlen, device='jax').unsqueeze(1)
-        expert_outs = expert_outs[seq_indexes, expert_indices]
-
-        # out = torch.einsum("tai,ta -> ti", expert_outs, expert_weights)
-        out = call_jax(jax.lax.dot,
-                       expert_outs,
-                       expert_weights,
-                       dimension_numbers=(((1, ), (1, )), ((0, ), (0, ))),
-                       preferred_element_type=jnp.bfloat16.dtype)
-
-        return out
+    ) -> torch.Tensor:
+        return fused_moe_apply(
+            layer,
+            x,
+            router_logits,
+            self.moe_backend,
+            self.mesh,
+            self.extra_backend_kwargs,
+        )
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
CHANGED
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import Optional
 
 import jax
@@ -6,49 +20,27 @@ import torch
 from compressed_tensors.quantization import (QuantizationArgs,
                                              QuantizationStrategy)
 from jax.sharding import NamedSharding, PartitionSpec
+from torch.nn.parameter import Parameter
 from torchax.interop import jax_view, torch_view
 from torchax.ops.mappings import t2j
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_fp8 import \
     CompressedTensorsW8A8Fp8
-from vllm.model_executor.layers.quantization.utils.w8a8_utils import \
-    per_tensor_dequantize
 
-from tpu_inference.layers.
-
-
-
+from tpu_inference.layers.common.quantization import (dequantize_tensor,
+                                                       quantize_tensor)
+from tpu_inference.layers.common.utils import \
+    slice_sharded_tensor_for_concatenation
+from tpu_inference.layers.vllm.linear import sharded_quantized_matmul
+from tpu_inference.layers.vllm.process_weights.linear_weights import (
+    LinearWeights, process_lienar_weights, shard_linear_weights,
+    to_parameter_list)
+from tpu_inference.layers.vllm.quantization.configs import \
+    VllmQuantLinearConfig
+from tpu_inference.logger import init_logger
 
 P = PartitionSpec
 
-
-def requantize_with_max_scale(
-        weight: torch.Tensor, weight_scale: torch.Tensor,
-        logical_widths: list[int]) -> tuple[torch.Tensor, torch.Tensor]:
-    dtype = weight.dtype
-    dtype_info = torch.finfo(dtype)
-    maxval = float(dtype_info.max)
-    minval = float(dtype_info.min)
-
-    max_w_scale = weight_scale.max()
-
-    unfused_module_in_checkpoint = (weight_scale[-1]
-                                    > torch.finfo(torch.float8_e4m3fn).min)
-
-    # If unfused checkpoint, need requanize with the single scale.
-    if unfused_module_in_checkpoint:
-        start = 0
-        for idx, logical_width in enumerate(logical_widths):
-            # Skip any component with zero width.
-            if logical_width == 0:
-                continue
-            end = start + logical_width
-            weight_dq = per_tensor_dequantize(weight[start:end, :],
-                                              weight_scale[idx])
-            weight_q = weight_dq / max_w_scale
-            weight[start:end, :] = weight_q.clamp(minval, maxval).to(dtype)
-            start = end
-
-    return max_w_scale, weight
+logger = init_logger(__name__)
 
 
 class VllmCompressedTensorsW8A8Fp8(CompressedTensorsW8A8Fp8):
@@ -57,15 +49,86 @@ class VllmCompressedTensorsW8A8Fp8(CompressedTensorsW8A8Fp8):
         self,
         weight_quant: QuantizationArgs,
         is_static_input_scheme: bool,
-
+        linear_config: VllmQuantLinearConfig,
     ):
         super().__init__(weight_quant, is_static_input_scheme)
 
-        self.
+        self.linear_config = linear_config
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        weight = layer.weight
-
+        weight = t2j(layer.weight, use_dlpack=False)
+        delattr(layer, "weight")
+        weight_scale = t2j(layer.weight_scale, use_dlpack=False)
+        delattr(layer, "weight_scale")
+
+        if layer.bias is not None and not layer.skip_bias_add:
+            if layer.return_bias:
+                logger.warning_once("Bias might return incorrect value.")
+            bias = t2j(layer.bias, use_dlpack=False)
+            delattr(layer, "bias")
+        else:
+            bias = None
+
+        per_tensor = self.strategy == QuantizationStrategy.TENSOR
+
+        @jax.jit
+        def process_fp8_linear_weights(
+            weight: jax.Array,
+            weight_scale: jax.Array,
+            bias: jax.Array | None,
+        ) -> LinearWeights:
+            if per_tensor:
+                weights = []
+                start = 0
+                # Multiple weights may have been concatenated. Loop through
+                # each weight and perform dequantization.
+                for i, output_size in enumerate(
+                        self.linear_config.output_sizes):
+                    end = start + output_size
+                    weights.append(
+                        dequantize_tensor(weight[start:end], weight_scale[i]))
+                    start = end
+                weight = jnp.concat(weights, axis=0)
+                # Requantize into per-tensor.
+                weight, weight_scale = quantize_tensor(jnp.float8_e4m3fn,
+                                                       weight, None)
+            else:
+                weight_scale = jnp.squeeze(weight_scale, -1)
+
+            return process_lienar_weights(
+                LinearWeights(
+                    weight=weight,
+                    weight_scale=weight_scale,
+                    zero_point=None,
+                    bias=bias,
+                ),
+                fused=self.linear_config.fuse_matmuls,
+                output_sizes=self.linear_config.output_sizes,
+                reorder_size=self.linear_config.n_shards,
+                per_tensor=per_tensor,
+            )
+
+        weights = process_fp8_linear_weights(weight, weight_scale, bias)
+        weights = torch_view(
+            shard_linear_weights(
+                weights,
+                mesh=self.linear_config.mesh,
+                weight_p_spec=self.linear_config.weight_sharding,
+                bias_p_spec=self.linear_config.bias_sharding,
+                per_tensor=per_tensor,
+            ))
+
+        if self.linear_config.fuse_matmuls:
+            layer.weight = Parameter(weights.weight, requires_grad=False)
+            layer.weight_scale = Parameter(weights.weight_scale,
+                                           requires_grad=False)
+            if bias is not None:
+                layer.bias = Parameter(weights.bias, requires_grad=False)
+        else:
+            layer.weight = to_parameter_list(weights.weight)
+            layer.weight_scale = to_parameter_list(weights.weight_scale)
+            if bias is not None:
+                layer.bias = to_parameter_list(weights.bias)
 
         if self.is_static_input_scheme:
             # In static quant, all input_scales share the same value.
@@ -74,59 +137,16 @@ class VllmCompressedTensorsW8A8Fp8(CompressedTensorsW8A8Fp8):
 
             input_scale = jax.device_put(
                 t2j(input_scale_first, use_dlpack=False),
-                NamedSharding(self.
+                NamedSharding(self.linear_config.mesh, P()))
             input_scale = torch.nn.Parameter(torch_view(input_scale),
                                              requires_grad=False)
             delattr(layer, "input_scale")
             layer.input_scale = input_scale
 
-        # TODO(kyuyeunk): Investigate performance gain from merging scales.
-        # By merging input and weight scales, we reduce the number of muls
-        # required for dequantization from 2 (for each scales) to 1.
-        # weight_scale *= input_scale_first
-
-        if self.strategy == QuantizationStrategy.TENSOR:
-            weight_scale, weight = requantize_with_max_scale(
-                weight, weight_scale, self.jax_config.output_sizes)
-            weight_scale = jax.device_put(
-                t2j(weight_scale, use_dlpack=False),
-                NamedSharding(self.jax_config.mesh, P()))
-            weight_scale = torch.nn.Parameter(torch_view(weight_scale),
-                                              requires_grad=False)
-        else:
-            weight_scale = weight_scale.squeeze(-1)
-            weight_scale = torch_to_jax_param(
-                weight_scale,
-                NamedSharding(self.jax_config.mesh,
-                              self.jax_config.bias_sharding),
-                self.jax_config.output_sizes, self.jax_config.n_shards,
-                self.jax_config.fuse_matmuls)
-        delattr(layer, "weight_scale")
-        layer.weight_scale = weight_scale
-
-        weight = torch_to_jax_param(
-            layer.weight,
-            NamedSharding(self.jax_config.mesh,
-                          self.jax_config.weight_sharding),
-            self.jax_config.output_sizes, self.jax_config.n_shards,
-            self.jax_config.fuse_matmuls)
-        delattr(layer, "weight")
-        layer.weight = weight
-
-        if layer.bias is not None:
-            bias = torch_to_jax_param(
-                layer.bias,
-                NamedSharding(self.jax_config.mesh,
-                              self.jax_config.bias_sharding),
-                self.jax_config.output_sizes, self.jax_config.n_shards,
-                self.jax_config.fuse_matmuls)
-            delattr(layer, "bias")
-            layer.bias = bias
-
     def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
                       bias: Optional[torch.Tensor]) -> torch.Tensor:
         with jax.named_scope(layer._get_name()):
-            if self.
+            if self.linear_config.fuse_matmuls:
                 return self._apply_fused(layer, x, bias)
             else:
                 return self._apply_split(layer, x, bias)
@@ -157,13 +177,13 @@ class VllmCompressedTensorsW8A8Fp8(CompressedTensorsW8A8Fp8):
         else:
             outs = sharded_quantized_matmul(x_jax, weight_jax,
                                             weight_scale_jax,
-                                            self.
-                                            self.
+                                            self.linear_config.mesh,
+                                            self.linear_config.weight_sharding)
 
         if bias is not None and not layer.skip_bias_add:
             outs += jax_view(bias)
         outs = slice_sharded_tensor_for_concatenation(
-            outs, self.
+            outs, self.linear_config.output_sizes, self.linear_config.n_shards)
         return torch_view(jnp.concatenate(outs, axis=-1))
 
     def _apply_split(self, layer: torch.nn.Module, x: torch.Tensor,
@@ -197,10 +217,10 @@ class VllmCompressedTensorsW8A8Fp8(CompressedTensorsW8A8Fp8):
                 out *= weight_scale_jax * input_scale
                 out = out.astype(x_jax.dtype)
             else:
-                out = sharded_quantized_matmul(
-
-
-
+                out = sharded_quantized_matmul(
+                    x_jax, weight_jax, weight_scale_jax,
+                    self.linear_config.mesh,
+                    self.linear_config.weight_sharding)
 
             if bias is not None and not layer.skip_bias_add:
                 out += jax_view(bias[i])