tpu-inference 0.12.0.dev20251222__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_core_tpu.py +513 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +67 -0
- tests/core/test_dp_scheduler.py +724 -0
- tests/core/test_init.py +63 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +393 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +291 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +388 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +498 -0
- tests/kernels/quantized_matmul_kernel_test.py +159 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/layers/jax/test_qwix.py +969 -0
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +297 -0
- tests/layers/vllm/test_unquantized.py +621 -0
- tests/layers/vllm/utils.py +72 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +46 -0
- tests/lora/test_bgmv.py +57 -0
- tests/lora/test_layers.py +666 -0
- tests/lora/test_lora.py +147 -0
- tests/lora/test_lora_perf.py +67 -0
- tests/lora/utils.py +88 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +606 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +202 -0
- tests/runner/test_tpu_runner_dp.py +1033 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +215 -0
- tests/test_envs.py +280 -0
- tests/test_tpu_info.py +134 -0
- tests/test_utils.py +193 -0
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +67 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/core_tpu.py +786 -0
- tpu_inference/core/disagg_executor.py +118 -0
- tpu_inference/core/disagg_utils.py +49 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +814 -0
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +81 -0
- tpu_inference/distributed/tpu_connector.py +732 -0
- tpu_inference/distributed/utils.py +112 -0
- tpu_inference/env_override.py +9 -0
- tpu_inference/envs.py +191 -0
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +399 -0
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +272 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +1340 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +403 -0
- tpu_inference/layers/common/attention_metadata.py +48 -0
- tpu_inference/layers/common/binary_search.py +295 -0
- tpu_inference/layers/common/quant_methods.py +23 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +600 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +268 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
- tpu_inference/layers/jax/base.py +165 -0
- tpu_inference/layers/jax/constants.py +101 -0
- tpu_inference/layers/jax/layers.py +315 -0
- tpu_inference/layers/jax/misc.py +30 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
- tpu_inference/layers/jax/moe/moe.py +249 -0
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +294 -0
- tpu_inference/layers/jax/rope_interface.py +228 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
- tpu_inference/layers/jax/sample/sampling.py +110 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
- tpu_inference/layers/jax/transformer_block.py +121 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +221 -0
- tpu_inference/layers/vllm/fused_moe.py +502 -0
- tpu_inference/layers/vllm/linear_common.py +221 -0
- tpu_inference/layers/vllm/quantization/__init__.py +55 -0
- tpu_inference/layers/vllm/quantization/awq.py +221 -0
- tpu_inference/layers/vllm/quantization/common.py +124 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
- tpu_inference/layers/vllm/sharding.py +244 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +98 -0
- tpu_inference/lora/torch_punica_tpu.py +310 -0
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +520 -0
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +978 -0
- tpu_inference/models/jax/gpt_oss.py +508 -0
- tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
- tpu_inference/models/jax/llama3.py +436 -0
- tpu_inference/models/jax/llama4.py +643 -0
- tpu_inference/models/jax/llama_eagle3.py +350 -0
- tpu_inference/models/jax/llama_guard_4.py +375 -0
- tpu_inference/models/jax/qwen2.py +390 -0
- tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
- tpu_inference/models/jax/qwen3.py +318 -0
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +110 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
- tpu_inference/models/jax/utils/weight_utils.py +621 -0
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
- tpu_inference/platforms/__init__.py +16 -0
- tpu_inference/platforms/tpu_platform.py +258 -0
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/block_table.py +122 -0
- tpu_inference/runner/compilation_manager.py +890 -0
- tpu_inference/runner/input_batch.py +435 -0
- tpu_inference/runner/kv_cache.py +166 -0
- tpu_inference/runner/kv_cache_manager.py +508 -0
- tpu_inference/runner/lora_utils.py +106 -0
- tpu_inference/runner/multimodal_manager.py +231 -0
- tpu_inference/runner/persistent_batch_manager.py +296 -0
- tpu_inference/runner/speculative_decoding_manager.py +262 -0
- tpu_inference/runner/structured_decoding_manager.py +101 -0
- tpu_inference/runner/tpu_runner.py +1768 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +430 -0
- tpu_inference/tpu_info.py +92 -0
- tpu_inference/utils.py +345 -0
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +468 -0
- tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
- tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
- tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
- tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py
ADDED
@@ -0,0 +1,135 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+import torch
+from jax.sharding import PartitionSpec
+from vllm.attention.layer import Attention
+from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe.layer import FusedMoE
+from vllm.model_executor.layers.linear import LinearBase
+from vllm.model_executor.layers.quantization import \
+    register_quantization_config
+from vllm.model_executor.layers.quantization.base_config import \
+    QuantizeMethodBase  # noqa: E501
+from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
+    CompressedTensorsConfig, CompressedTensorsKVCacheMethod,
+    CompressedTensorsLinearMethod, CompressedTensorsScheme)
+from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
+    find_matched_target, should_ignore_layer)
+
+from tpu_inference.layers.common.quant_methods import (COMPRESSED_TENSORS,
+                                                       get_tpu_quant_method)
+from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
+from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors_moe import \
+    VllmCompressedTensorsMoEMethod
+from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_fp8 import \
+    VllmCompressedTensorsW8A8Fp8
+from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_int8 import \
+    VllmCompressedTensorsW8A8Int8
+from tpu_inference.layers.vllm.quantization.unquantized import \
+    VllmUnquantizedConfig
+
+P = PartitionSpec
+logger = init_logger(__name__)
+
+
+@register_quantization_config(get_tpu_quant_method(COMPRESSED_TENSORS))
+class VllmCompressedTensorsConfig(CompressedTensorsConfig, JaxCommonConfig):
+
+    @classmethod
+    def get_name(cls) -> str:
+        return COMPRESSED_TENSORS
+
+    def get_scheme(self,
+                   layer: torch.nn.Module,
+                   layer_name: Optional[str] = None
+                   ) -> Optional["CompressedTensorsScheme"]:
+        """
+        compressed-tensors supports non uniform in the following way:
+
+        targets of config_groups: There can be N config_groups which each
+        have a quantization scheme. Each config_group has a list of targets
+        which can be a full layer_name, a regex for a layer_name, or
+        an nn.Module name.
+
+        Detect whether a layer_name is found in any target and
+        use the quantization scheme corresponding to the matched target
+        to select the CompressedTensorsScheme used for inference.
+        """
+
+        # Will be empty for models with only sparsity
+        weight_quant = input_quant = None
+        if self.target_scheme_map:
+            matched_target = find_matched_target(
+                layer_name=layer_name,
+                module=layer,
+                targets=self.target_scheme_map.keys(),
+                fused_mapping=self.packed_modules_mapping,
+            )
+
+            scheme_dict = self.target_scheme_map[matched_target]
+            weight_quant = scheme_dict.get("weights")
+            input_quant = scheme_dict.get("input_activations")
+
+        if weight_quant is None:
+            logger.warning_once("Acceleration for non-quantized schemes is "
+                                "not supported by Compressed Tensors. "
+                                "Falling back to UnquantizedLinearMethod")
+            return None
+
+        # TODO(kyuyeunk): Add support for different act_quant_format
+
+        linear_config = self.get_linear_config(layer)
+        if self._is_fp8_w8a8(weight_quant, input_quant):
+            is_static_input_scheme = input_quant and not input_quant.dynamic
+            return VllmCompressedTensorsW8A8Fp8(
+                weight_quant=weight_quant,
+                is_static_input_scheme=is_static_input_scheme,
+                jax_config=linear_config,
+            )
+        if self._is_dynamic_token_w8a8(weight_quant, input_quant):
+            return VllmCompressedTensorsW8A8Int8(
+                strategy=weight_quant.strategy,
+                is_static_input_scheme=False,
+                input_symmetric=input_quant.symmetric,
+                jax_config=linear_config,
+            )
+        raise NotImplementedError(
+            "No compressed-tensors compatible scheme was found.")
+
+    def get_quant_method(
+        self,
+        layer: torch.nn.Module,
+        prefix: str,
+    ) -> Optional[QuantizeMethodBase]:
+        if should_ignore_layer(prefix,
+                               ignore=self.ignore,
+                               fused_mapping=self.packed_modules_mapping):
+            return VllmUnquantizedConfig.get_quant_method(self, layer, prefix)
+        if isinstance(layer, LinearBase):
+            scheme = self.get_scheme(layer=layer, layer_name=prefix)
+            if scheme is None:
+                return VllmUnquantizedConfig.get_quant_method(
+                    self, layer, prefix)
+            layer.scheme = scheme
+            return CompressedTensorsLinearMethod(self)
+        if isinstance(layer, FusedMoE):
+            layer.moe_config = self.get_moe_config(layer)
+            return VllmCompressedTensorsMoEMethod.get_moe_method(
+                self, layer, layer_name=prefix)
+        if isinstance(layer, Attention):
+            return CompressedTensorsKVCacheMethod(self)
+        return None
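The `get_scheme` docstring above describes how a layer name is matched against the checkpoint's `config_groups` targets and how the matched entry decides between the fp8 and int8 W8A8 schemes. The standalone sketch below mirrors that dispatch with plain dicts; the helper name and dict keys are illustrative only and are not part of this package or of vLLM.

from typing import Optional


def pick_scheme(scheme_dict: Optional[dict]) -> str:
    # No matched config_group or no "weights" entry: fall back to the
    # unquantized method, as the warning branch above does.
    if not scheme_dict or scheme_dict.get("weights") is None:
        return "unquantized"
    weights = scheme_dict["weights"]
    acts = scheme_dict.get("input_activations") or {}
    # fp8 weights + fp8 activations -> W8A8 fp8; static when the activation
    # scale is fixed (not dynamic).
    if weights.get("type") == "float" and acts.get("type") == "float":
        return "w8a8_fp8_static" if not acts.get("dynamic") else "w8a8_fp8_dynamic"
    # int8 weights + dynamic per-token activations -> W8A8 int8.
    if weights.get("type") == "int" and acts.get("dynamic"):
        return "w8a8_int8_dynamic"
    raise NotImplementedError("No compressed-tensors compatible scheme was found.")


# Example: a config_group that quantizes weights and activations to fp8
# with a static activation scale.
print(pick_scheme({"weights": {"type": "float", "num_bits": 8},
                   "input_activations": {"type": "float", "dynamic": False}}))
# -> w8a8_fp8_static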
tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py
ADDED
@@ -0,0 +1,266 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Union
+
+import jax
+import jax.numpy as jnp
+import torch
+from compressed_tensors.quantization import QuantizationArgs
+from jax.experimental.layout import Format, Layout
+from jax.sharding import Mesh, NamedSharding
+from jax.sharding import PartitionSpec as P
+from torch.nn.parameter import Parameter
+from torchax.interop import jax_view, torch_view
+from torchax.ops.mappings import t2j
+from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEConfig
+from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (  # noqa: E501
+    CompressedTensorsMoEMethod, CompressedTensorsW8A8Fp8MoEMethod)
+
+from tpu_inference.layers.vllm.fused_moe import fused_moe_func
+from tpu_inference.layers.vllm.linear_common import \
+    reorder_concatenated_tensor_for_sharding
+from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
+from tpu_inference.layers.vllm.quantization.unquantized import \
+    VllmUnquantizedFusedMoEMethod
+
+logger = init_logger(__name__)
+
+
+class VllmCompressedTensorsMoEMethod(CompressedTensorsMoEMethod):
+
+    @staticmethod
+    def get_moe_method(
+        quant_config: "VllmCompressedTensorsConfig",  # type: ignore # noqa E501
+        layer: torch.nn.Module,
+        layer_name: str,
+    ) -> CompressedTensorsMoEMethod:
+        assert isinstance(layer, FusedMoE)
+
+        # FusedMoE was made by combining multiple Linears so need to
+        # make sure quantization config for Linear can target it
+        quant_config._add_fused_moe_to_target_scheme_map()
+        unfused_names = [
+            layer_name + proj_name
+            for proj_name in [".0.gate_proj", ".0.up_proj", ".0.down_proj"]
+        ]
+        # TODO: refactor this to use expert_mapping and check all layer numbers
+        all_scheme_dicts = [
+            quant_config.get_scheme_dict(layer, name) for name in unfused_names
+        ]
+        scheme_dict = all_scheme_dicts.pop()
+
+        # multiple schemes found
+        if not all([cur_dict == scheme_dict for cur_dict in all_scheme_dicts]):
+            raise ValueError("All MoE projections need to have same "
+                             "quantization scheme but found multiple")
+
+        if scheme_dict is None:
+            return VllmUnquantizedFusedMoEMethod(layer.moe_config,
+                                                 quant_config.mesh)
+
+        weight_quant = scheme_dict.get("weights")
+        input_quant = scheme_dict.get("input_activations")
+
+        if quant_config._is_fp8_w8a8(weight_quant, input_quant):
+            return VllmCompressedTensorsW8A8Fp8MoEMethod(
+                weight_quant, input_quant, layer.moe_config, quant_config.mesh)
+        else:
+            raise RuntimeError(
+                f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}")
+
+
+class VllmCompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsW8A8Fp8MoEMethod,
+                                            JaxCommonConfig):
+
+    def __init__(
+        self,
+        weight_quant: QuantizationArgs,
+        input_quant: QuantizationArgs,
+        moe: FusedMoEConfig,
+        mesh: Mesh,
+    ):
+        super().__init__(weight_quant, input_quant, moe)
+        self.mesh = mesh
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        """
+        Docstring for process_weights_after_loading
+
+        :param self: Description
+        :param layer: Description
+        :type layer: torch.nn.Module
+
+        Steps:
+        1. Read weights from layer object and convert to jax arrays
+        2. Interleave concat w13 weights
+        3. Shard weights for tp (rowwise w13, colwise w2)
+        4. Initialize Params as torch.nn.Parameter
+            a. w13_weight - float8_e4m3fn shape: (num_experts, 2 x intermediate_size, input_size)
+            b. w2_weight - float8_e4m3fn shape: (num_experts, output_size, intermediate_size)
+            c. w13_weight_scale - FP32 shape: (num_experts, 2 x intermediate_size, 1)
+            d. w2_weight_scale - FP32 shape: (num_experts, output_size, 1)
+        """
+        assert isinstance(layer, FusedMoE)
+
+        # Read weights from layer object
+        w13_weight = t2j(
+            layer.w13_weight, use_dlpack=False
+        )  # float8_e4m3fn shape: (num_experts, 2 x intermediate_size, input_size)
+        w13_weight_scale = t2j(
+            layer.w13_weight_scale, use_dlpack=False
+        )  # FP32 shape: (num_experts, 2 x intermediate_size, 1)
+        w2_weight = t2j(
+            layer.w2_weight, use_dlpack=False
+        )  # float8_e4m3fn shape: (num_experts, output_size, intermediate_size)
+        w2_weight_scale = t2j(layer.w2_weight_scale, use_dlpack=False)
+        w13_weight_scale = w13_weight_scale.astype(jnp.bfloat16)
+        w2_weight_scale = w2_weight_scale.astype(jnp.bfloat16)
+        intermediate_size = layer.w13_weight.shape[1] // 2
+        assert intermediate_size == w2_weight.shape[-1]
+        n_shards = self.mesh.shape["model"]
+        assert intermediate_size % n_shards == 0
+        num_experts, hidden_size, intermediate_size = w2_weight.shape
+        assert w2_weight_scale.shape == (num_experts, hidden_size, 1)
+        assert w13_weight.shape == (num_experts, 2 * intermediate_size,
+                                    hidden_size)
+        assert w13_weight_scale.shape == (num_experts, 2 * intermediate_size,
+                                          1)
+
+        if not layer.use_ep:
+            # Interleave concat w13 weights
+            w13_weight = reorder_concatenated_tensor_for_sharding(
+                w13_weight,
+                split_sizes=(intermediate_size, intermediate_size),
+                dim=1,
+                n_shards=n_shards,
+            )
+            # Interleave concat w13 weight scales
+            w13_weight_scale = reorder_concatenated_tensor_for_sharding(
+                w13_weight_scale,
+                split_sizes=(intermediate_size, intermediate_size),
+                dim=1,
+                n_shards=n_shards,
+            )
+
+        # 160,5120,1 -> 160,1,5120
+        w13_weight_scale = jnp.swapaxes(w13_weight_scale, 1, 2)
+        # 160,1,5120 -> 160, 1, 1, 5120 (num_experts, num_blocks, 1, outer_dim)
+        w13_weight_scale = jnp.expand_dims(w13_weight_scale, 2)
+        w2_weight_scale = jnp.swapaxes(w2_weight_scale, 1, 2)
+        w2_weight_scale = jnp.expand_dims(w2_weight_scale, 2)
+
+        if layer.use_ep:
+            # Apply EP sharding
+            ep_sharding = NamedSharding(self.mesh, P("model"))
+
+            w13_weight = jax.lax.with_sharding_constraint(
+                w13_weight, ep_sharding)
+            w2_weight = jax.lax.with_sharding_constraint(
+                w2_weight, ep_sharding)
+
+            w13_weight_scale = jax.lax.with_sharding_constraint(
+                w13_weight_scale, ep_sharding)
+            w2_weight_scale = jax.lax.with_sharding_constraint(
+                w2_weight_scale, ep_sharding)
+
+        else:
+            # Shard weights for tp (rowwise w13, colwise w2)
+            w13_format = Format(
+                Layout((0, 1, 2)),  # expert, 2xintermed, input
+                NamedSharding(self.mesh, P(None, "model", None)),
+            )  # rowwise sharding on intermed dim
+
+            w13_scale_format = Format(
+                Layout(
+                    (0, 1, 2, 3)),  # (num_experts, num_blocks, 1, outer_dim)
+                NamedSharding(self.mesh, P(None, None, None, "model")),
+            )  # col wise GMM sharding on intermed dim
+
+            # Local shard shape: (num_experts, 2 x (intermediate_size // n_shards), input_size)
+            w13_weight = jax.lax.with_sharding_constraint(
+                w13_weight, w13_format)
+            # Local shard shape: (num_experts, (intermediate_size // n_shards), 1)
+            w13_weight_scale = jax.lax.with_sharding_constraint(
+                w13_weight_scale, w13_scale_format)
+
+            # Shard weights for tp (colwise w2)
+            w2_format = Format(
+                Layout((0, 1, 2)),  # expert, intermed, hidden
+                NamedSharding(self.mesh, P(None, None, "model")),
+            )
+            # Local shard shape: (num_experts, hidden, (intermediate_size // n_shards))
+            # # (num_experts, num_blocks, 1, outer_dim)
+            w2_weight = jax.lax.with_sharding_constraint(w2_weight, w2_format)
+
+            w2_scale_format = Format(
+                Layout((0, 1, 2, 3)),  # expert, intermed, 1
+                NamedSharding(self.mesh, P(None, None, None, None)),
+            )
+            # Local shard shape: (num_experts, intermediate_size // n_shards, 1)
+            w2_weight_scale = jax.lax.with_sharding_constraint(
+                w2_weight_scale, w2_scale_format)
+
+        w13_weight = Parameter(torch_view(w13_weight), requires_grad=False)
+        w13_weight_scale = Parameter(torch_view(w13_weight_scale),
+                                     requires_grad=False)
+        w2_weight = Parameter(torch_view(w2_weight), requires_grad=False)
+        w2_weight_scale = Parameter(torch_view(w2_weight_scale),
+                                    requires_grad=False)
+
+        layer.w13_weight = w13_weight
+        layer.w13_weight_scale = w13_weight_scale
+        layer.w2_weight = w2_weight
+        layer.w2_weight_scale = w2_weight_scale
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+        assert isinstance(layer, FusedMoE)
+        if layer.activation != "silu":
+            raise NotImplementedError(
+                "Only silu is supported for activation function.")
+        if layer.scoring_func != "softmax":
+            raise NotImplementedError(
+                "Only softmax is supported for scoring_func")
+
+        # TODO: Use MoE kernel when it supports fp8
+        x = jax_view(x)
+        w13_weight = jax_view(layer.w13_weight)
+        w2_weight = jax_view(layer.w2_weight)
+        w13_weight_scale = jax_view(layer.w13_weight_scale)
+        w2_weight_scale = jax_view(layer.w2_weight_scale)
+        gating_output = jax_view(router_logits)
+        out = torch_view(
+            fused_moe_func(
+                hidden_states=x,
+                w1=w13_weight,
+                w2=w2_weight,
+                w1_scale=w13_weight_scale,
+                w2_scale=w2_weight_scale,
+                w1_bias=None,
+                w2_bias=None,
+                gating_output=gating_output,
+                topk=layer.top_k,
+                renormalize=layer.renormalize,
+                mesh=self.mesh,
+                use_ep=layer.use_ep,
+                activation=layer.activation,
+            ))
+
+        return out
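Step 2 of the docstring above ("Interleave concat w13 weights") exists because `w13_weight` stores the gate and up projections concatenated along the intermediate dimension; an even tensor-parallel split of that axis would otherwise leave some shards holding only gate rows and others only up rows. The NumPy sketch below illustrates the kind of reordering this implies; it is an illustration of the idea, not the package's `reorder_concatenated_tensor_for_sharding` implementation.

import numpy as np


def interleave_concat_for_sharding(w13: np.ndarray, n_shards: int) -> np.ndarray:
    """Reorder a [gate; up] concatenation along axis 1 so that an even split
    into n_shards gives every shard matching gate and up slices."""
    num_experts, two_inter, hidden = w13.shape
    inter = two_inter // 2
    gate, up = w13[:, :inter, :], w13[:, inter:, :]
    gate_parts = np.split(gate, n_shards, axis=1)
    up_parts = np.split(up, n_shards, axis=1)
    # Shard i of the reordered tensor holds [gate_i; up_i].
    return np.concatenate(
        [blk for pair in zip(gate_parts, up_parts) for blk in pair], axis=1)


# 2 experts, intermediate_size=4, hidden=3, 2 shards: each shard of the
# result holds 2 gate rows followed by 2 up rows per expert.
w13 = np.arange(2 * 8 * 3).reshape(2, 8, 3)
print(interleave_concat_for_sharding(w13, n_shards=2).shape)  # (2, 8, 3)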
tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py
ADDED
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
ADDED
@@ -0,0 +1,222 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+import jax
+import jax.numpy as jnp
+import torch
+from compressed_tensors.quantization import (QuantizationArgs,
+                                             QuantizationStrategy)
+from jax.sharding import NamedSharding, PartitionSpec
+from torchax.interop import jax_view, torch_view
+from torchax.ops.mappings import t2j
+from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_fp8 import \
+    CompressedTensorsW8A8Fp8
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import \
+    per_tensor_dequantize
+
+from tpu_inference.layers.vllm.linear_common import (
+    sharded_quantized_matmul, slice_sharded_tensor_for_concatenation,
+    torch_to_jax_param)
+from tpu_inference.layers.vllm.quantization.common import JaxCommonLinearConfig
+
+P = PartitionSpec
+
+
+def requantize_with_max_scale(
+        weight: torch.Tensor, weight_scale: torch.Tensor,
+        logical_widths: list[int]) -> tuple[torch.Tensor, torch.Tensor]:
+    dtype = weight.dtype
+    dtype_info = torch.finfo(dtype)
+    maxval = float(dtype_info.max)
+    minval = float(dtype_info.min)
+
+    max_w_scale = weight_scale.max()
+
+    unfused_module_in_checkpoint = (weight_scale[-1]
+                                    > torch.finfo(torch.float8_e4m3fn).min)
+
+    # If unfused checkpoint, need requanize with the single scale.
+    if unfused_module_in_checkpoint:
+        start = 0
+        for idx, logical_width in enumerate(logical_widths):
+            # Skip any component with zero width.
+            if logical_width == 0:
+                continue
+            end = start + logical_width
+            weight_dq = per_tensor_dequantize(weight[start:end, :],
+                                              weight_scale[idx])
+            weight_q = weight_dq / max_w_scale
+            weight[start:end, :] = weight_q.clamp(minval, maxval).to(dtype)
+            start = end
+
+    return max_w_scale, weight
+
+
+class VllmCompressedTensorsW8A8Fp8(CompressedTensorsW8A8Fp8):
+
+    def __init__(
+        self,
+        weight_quant: QuantizationArgs,
+        is_static_input_scheme: bool,
+        jax_config: JaxCommonLinearConfig,
+    ):
+        super().__init__(weight_quant, is_static_input_scheme)
+
+        self.jax_config = jax_config
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        weight = layer.weight
+        weight_scale = layer.weight_scale
+
+        if self.is_static_input_scheme:
+            # In static quant, all input_scales share the same value.
+            assert layer.input_scale.min() == layer.input_scale.max()
+            input_scale_first = layer.input_scale[0]
+
+            input_scale = jax.device_put(
+                t2j(input_scale_first, use_dlpack=False),
+                NamedSharding(self.jax_config.mesh, P()))
+            input_scale = torch.nn.Parameter(torch_view(input_scale),
+                                             requires_grad=False)
+            delattr(layer, "input_scale")
+            layer.input_scale = input_scale
+
+            # TODO(kyuyeunk): Investigate performance gain from merging scales.
+            # By merging input and weight scales, we reduce the number of muls
+            # required for dequantization from 2 (for each scales) to 1.
+            # weight_scale *= input_scale_first
+
+        if self.strategy == QuantizationStrategy.TENSOR:
+            weight_scale, weight = requantize_with_max_scale(
+                weight, weight_scale, self.jax_config.output_sizes)
+            weight_scale = jax.device_put(
+                t2j(weight_scale, use_dlpack=False),
+                NamedSharding(self.jax_config.mesh, P()))
+            weight_scale = torch.nn.Parameter(torch_view(weight_scale),
+                                              requires_grad=False)
+        else:
+            weight_scale = weight_scale.squeeze(-1)
+            weight_scale = torch_to_jax_param(
+                weight_scale,
+                NamedSharding(self.jax_config.mesh,
+                              self.jax_config.bias_sharding),
+                self.jax_config.output_sizes, self.jax_config.n_shards,
+                self.jax_config.fuse_matmuls)
+        delattr(layer, "weight_scale")
+        layer.weight_scale = weight_scale
+
+        weight = torch_to_jax_param(
+            layer.weight,
+            NamedSharding(self.jax_config.mesh,
+                          self.jax_config.weight_sharding),
+            self.jax_config.output_sizes, self.jax_config.n_shards,
+            self.jax_config.fuse_matmuls)
+        delattr(layer, "weight")
+        layer.weight = weight
+
+        if layer.bias is not None:
+            bias = torch_to_jax_param(
+                layer.bias,
+                NamedSharding(self.jax_config.mesh,
+                              self.jax_config.bias_sharding),
+                self.jax_config.output_sizes, self.jax_config.n_shards,
+                self.jax_config.fuse_matmuls)
+            delattr(layer, "bias")
+            layer.bias = bias
+
+    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
+                      bias: Optional[torch.Tensor]) -> torch.Tensor:
+        with jax.named_scope(layer._get_name()):
+            if self.jax_config.fuse_matmuls:
+                return self._apply_fused(layer, x, bias)
+            else:
+                return self._apply_split(layer, x, bias)
+
+    def _apply_fused(self, layer: torch.nn.Module, x: torch.Tensor,
+                     bias: Optional[torch.Tensor]) -> torch.Tensor:
+        x_jax = jax_view(x)
+        weight_jax = jax_view(layer.weight)
+        weight_scale_jax = jax_view(layer.weight_scale)
+
+        if self.is_static_input_scheme:
+            # TODO(kyuyeunk): Add kernel support for static quant
+            input_scale = jax_view(layer.input_scale)
+            dtype_info = jnp.finfo(weight_jax.dtype)
+            maxval = float(dtype_info.max)
+            minval = float(dtype_info.min)
+            x_q = jnp.clip(x_jax / input_scale.astype(x_jax.dtype), minval,
+                           maxval).astype(weight_jax.dtype)
+
+            outs = jax.lax.dot_general(
+                x_q,
+                weight_jax,
+                (((1, ), (1, )), ((), ())),
+                preferred_element_type=jnp.float32,
+            )
+            outs *= weight_scale_jax
+            outs = outs.astype(x_jax.dtype)
+        else:
+            outs = sharded_quantized_matmul(x_jax, weight_jax,
+                                            weight_scale_jax,
+                                            self.jax_config.mesh,
+                                            self.jax_config.weight_sharding)
+
+        if bias is not None and not layer.skip_bias_add:
+            outs += jax_view(bias)
+        outs = slice_sharded_tensor_for_concatenation(
+            outs, self.jax_config.output_sizes, self.jax_config.n_shards)
+        return torch_view(jnp.concatenate(outs, axis=-1))
+
+    def _apply_split(self, layer: torch.nn.Module, x: torch.Tensor,
+                     bias: Optional[torch.Tensor]) -> torch.Tensor:
+        assert isinstance(layer.weight, torch.nn.ParameterList)
+
+        x_jax = jax_view(x)
+        outs = []
+        for i, (weight, weight_scale) in enumerate(
+                zip(layer.weight, layer.weight_scale)):
+            weight_jax = jax_view(weight)
+            weight_scale_jax = jax_view(weight_scale)
+
+            if self.is_static_input_scheme:
+                # TODO(kyuyeunk): Add kernel support for static quant
+                input_scale = jax_view(layer.input_scale)
+                dtype_info = jnp.finfo(weight_jax.dtype)
+                maxval = float(dtype_info.max)
+                minval = float(dtype_info.min)
+                x_q = jnp.clip(x_jax / input_scale.astype(x_jax.dtype), minval,
+                               maxval).astype(weight_jax.dtype)
+
+                out = jax.lax.dot_general(
+                    x_q,
+                    weight_jax,
+                    (((1, ), (1, )), ((), ())),
+                    preferred_element_type=jnp.float32,
+                )
+                # TODO(kyuyeunk): Investigate performance gain from merging scales.
+                # out *= weight_scale_jax
+                out *= weight_scale_jax * input_scale
+                out = out.astype(x_jax.dtype)
+            else:
+                out = sharded_quantized_matmul(x_jax, weight_jax,
+                                               weight_scale_jax,
+                                               self.jax_config.mesh,
+                                               self.jax_config.weight_sharding)
+
+            if bias is not None and not layer.skip_bias_add:
+                out += jax_view(bias[i])
+            outs.append(out)
+        return torch_view(jnp.concatenate(outs, axis=-1))
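In the static-input path of `_apply_fused` and `_apply_split` above, activations are quantized by dividing by the input scale and clipping to the fp8 range, the matmul runs on the quantized operands with a float32 accumulator, and the result is rescaled. The NumPy sketch below emulates that per-tensor W8A8 dequantization identity; it skips the actual float8 casts and the TPU matmul, and hard-codes the e4m3 maximum of 448 that the real code reads from `jnp.finfo`.

import numpy as np

FP8_E4M3_MAX = 448.0  # assumed clipping bound; the code above queries jnp.finfo


def w8a8_reference(x: np.ndarray, w: np.ndarray, s_in: float, s_w: float) -> np.ndarray:
    x_q = np.clip(x / s_in, -FP8_E4M3_MAX, FP8_E4M3_MAX)  # quantize activations
    w_q = np.clip(w / s_w, -FP8_E4M3_MAX, FP8_E4M3_MAX)   # quantize weights
    acc = x_q @ w_q.T                                     # matmul with wide accumulation
    return acc * s_w * s_in                               # dequantize with both scales


rng = np.random.default_rng(0)
x, w = rng.normal(size=(4, 16)), rng.normal(size=(8, 16))
y = w8a8_reference(x, w, s_in=0.01, s_w=0.02)
print(np.max(np.abs(y - x @ w.T)))  # ~0 when nothing clips: the scales cancel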