tpu-inference 0.12.0.dev20251213__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of tpu-inference might be problematic.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +14 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +14 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +14 -0
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +180 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +406 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +320 -0
- tests/layers/vllm/test_unquantized.py +662 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +25 -8
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +14 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +1 -43
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +14 -9
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +20 -3
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +171 -163
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +20 -26
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +112 -69
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +85 -65
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +374 -194
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +13 -0
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +26 -19
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/fused_moe_gmm.py +506 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +282 -0
- tpu_inference/layers/common/sharding.py +22 -3
- tpu_inference/layers/common/utils.py +94 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +52 -27
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +19 -6
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +100 -455
- tpu_inference/layers/vllm/linear.py +64 -0
- tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
- tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
- tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
- tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
- tpu_inference/layers/vllm/quantization/__init__.py +19 -3
- tpu_inference/layers/vllm/quantization/awq.py +96 -82
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +19 -5
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +119 -132
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
- tpu_inference/layers/vllm/quantization/{common.py → configs.py} +38 -26
- tpu_inference/layers/vllm/quantization/fp8.py +119 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +133 -220
- tpu_inference/layers/vllm/quantization/unquantized.py +154 -253
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +37 -16
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +113 -124
- tpu_inference/models/jax/gpt_oss.py +23 -7
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +14 -0
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +18 -4
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +85 -24
- tpu_inference/models/jax/utils/weight_utils.py +32 -1
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +22 -4
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +27 -29
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +69 -35
- tpu_inference/runner/kv_cache.py +14 -0
- tpu_inference/runner/kv_cache_manager.py +15 -2
- tpu_inference/runner/lora_utils.py +16 -1
- tpu_inference/runner/multimodal_manager.py +16 -2
- tpu_inference/runner/persistent_batch_manager.py +14 -0
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +14 -0
- tpu_inference/runner/tpu_runner.py +30 -10
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +13 -0
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +31 -30
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +23 -7
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +1 -1
- tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
- tpu_inference/layers/vllm/linear_common.py +0 -208
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.12.0.dev20251213.dist-info/RECORD +0 -175
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
CHANGED
@@ -1,21 +1,41 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import Optional

 import jax
 import jax.numpy as jnp
 import torch
 from compressed_tensors.quantization import QuantizationStrategy
-from jax.sharding import
+from jax.sharding import PartitionSpec
+from torch.nn.parameter import Parameter
 from torchax.interop import jax_view, torch_view
-from
+from torchax.ops.mappings import t2j
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_int8 import \
     CompressedTensorsW8A8Int8
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import \
     convert_to_channelwise

-from tpu_inference.layers.
-
-
-from tpu_inference.layers.vllm.
+from tpu_inference.layers.common.utils import \
+    slice_sharded_tensor_for_concatenation
+from tpu_inference.layers.vllm.linear import sharded_quantized_matmul
+from tpu_inference.layers.vllm.process_weights.linear_weights import (
+    LinearWeights, process_lienar_weights, shard_linear_weights,
+    to_parameter_list)
+from tpu_inference.layers.vllm.quantization.configs import \
+    VllmQuantLinearConfig
+from tpu_inference.logger import init_logger

 P = PartitionSpec
 logger = init_logger(__name__)
@@ -24,23 +44,15 @@ logger = init_logger(__name__)
 class VllmCompressedTensorsW8A8Int8(CompressedTensorsW8A8Int8):

     def __init__(self, strategy: str, is_static_input_scheme: bool,
-                 input_symmetric: bool,
+                 input_symmetric: bool, linear_config: VllmQuantLinearConfig):
         super().__init__(strategy, is_static_input_scheme, input_symmetric)

-        self.
-        self.is_channelwise = (self.strategy == QuantizationStrategy.CHANNEL)
+        self.linear_config = linear_config
+        self.is_channelwise = (self.strategy == QuantizationStrategy.CHANNEL)

     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        weight =
-            layer.weight,
-            NamedSharding(self.jax_config.mesh,
-                          self.jax_config.weight_sharding),
-            self.jax_config.output_sizes,
-            self.jax_config.n_shards,
-            self.jax_config.fuse_matmuls,
-        )
+        weight = t2j(layer.weight, use_dlpack=False)
         delattr(layer, "weight")
-        layer.weight = weight

         weight_scale = layer.weight_scale
         is_fused_module = len(layer.logical_widths) > 1
@@ -48,31 +60,55 @@ class VllmCompressedTensorsW8A8Int8(CompressedTensorsW8A8Int8):
             weight_scale = convert_to_channelwise(weight_scale,
                                                   layer.logical_widths)
             weight_scale = weight_scale.squeeze(-1)
-
-        weight_scale = torch_to_jax_param(
-            weight_scale,
-            NamedSharding(self.jax_config.mesh, self.jax_config.bias_sharding),
-            self.jax_config.output_sizes,
-            self.jax_config.n_shards,
-            self.jax_config.fuse_matmuls,
-        )
+        weight_scale = t2j(weight_scale, use_dlpack=False)
         delattr(layer, "weight_scale")
-        layer.weight_scale = weight_scale

         if layer.bias is not None and not layer.skip_bias_add:
             if layer.return_bias:
                 logger.warning_once("Bias might return incorrect value.")
-
-            bias = torch_to_jax_param(
-                layer.bias,
-                NamedSharding(self.jax_config.mesh,
-                              self.jax_config.bias_sharding),
-                self.jax_config.output_sizes,
-                self.jax_config.n_shards,
-                self.jax_config.fuse_matmuls,
-            )
+            bias = t2j(layer.bias, use_dlpack=False)
             delattr(layer, "bias")
-
+        else:
+            bias = None
+
+        @jax.jit
+        def process_int8_linear_weights(
+            weight: jax.Array,
+            weight_scale: jax.Array,
+            bias: jax.Array | None,
+        ) -> LinearWeights:
+            return process_lienar_weights(
+                LinearWeights(
+                    weight=weight,
+                    weight_scale=weight_scale,
+                    zero_point=None,
+                    bias=bias,
+                ),
+                fused=self.linear_config.fuse_matmuls,
+                output_sizes=self.linear_config.output_sizes,
+                reorder_size=self.linear_config.n_shards,
+            )
+
+        weights = process_int8_linear_weights(weight, weight_scale, bias)
+        weights = torch_view(
+            shard_linear_weights(
+                weights,
+                mesh=self.linear_config.mesh,
+                weight_p_spec=self.linear_config.weight_sharding,
+                bias_p_spec=self.linear_config.bias_sharding,
+            ))
+
+        if self.linear_config.fuse_matmuls:
+            layer.weight = Parameter(weights.weight, requires_grad=False)
+            layer.weight_scale = Parameter(weights.weight_scale,
+                                           requires_grad=False)
+            if bias is not None:
+                layer.bias = Parameter(weights.bias, requires_grad=False)
+        else:
+            layer.weight = to_parameter_list(weights.weight)
+            layer.weight_scale = to_parameter_list(weights.weight_scale)
+            if bias is not None:
+                layer.bias = to_parameter_list(weights.bias)

         # TODO(kyuyeunk): Support static range input quantization.
         assert getattr(layer, "input_scale", None) is None
@@ -82,7 +118,7 @@ class VllmCompressedTensorsW8A8Int8(CompressedTensorsW8A8Int8):
     def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
                       bias: Optional[torch.Tensor]) -> torch.Tensor:
         with jax.named_scope(layer._get_name()):
-            if self.
+            if self.linear_config.fuse_matmuls:
                 out = self._apply_fused(layer, x, bias)
             else:
                 out = self._apply_split(layer, x, bias)
@@ -99,14 +135,14 @@ class VllmCompressedTensorsW8A8Int8(CompressedTensorsW8A8Int8):
             x_jax,
             weight_jax,
             weight_scale_jax,
-            self.
-            self.
+            self.linear_config.mesh,
+            self.linear_config.weight_sharding,
         )
         if bias is not None and not layer.skip_bias_add:
             outs += jax_view(bias)

         outs = slice_sharded_tensor_for_concatenation(
-            outs, self.
+            outs, self.linear_config.output_sizes, self.linear_config.n_shards)
         out = jnp.concatenate(outs, axis=-1)
         return torch_view(out)

@@ -125,8 +161,8 @@ class VllmCompressedTensorsW8A8Int8(CompressedTensorsW8A8Int8):
                 x_jax,
                 weight_jax,
                 weight_scale_jax,
-                self.
-                self.
+                self.linear_config.mesh,
+                self.linear_config.weight_sharding,
             )
             if bias is not None and not layer.skip_bias_add:
                 out += jax_view(bias[i])
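The rewritten `process_weights_after_loading` above converts the torch int8 weights to JAX arrays with `t2j`, packs them into a `LinearWeights` bundle, and shards them over the mesh before `apply_weights` dispatches to the fused or split matmul path. For orientation, here is a minimal, self-contained sketch of the per-channel W8A8 int8 matmul that `sharded_quantized_matmul` conceptually performs; the function name and the dynamic per-token activation quantization shown here are illustrative assumptions, not the kernel shipped in this wheel.

```python
import jax
import jax.numpy as jnp


def w8a8_int8_matmul_sketch(x: jax.Array, w_q: jax.Array,
                            w_scale: jax.Array) -> jax.Array:
    """Conceptual per-channel W8A8 int8 matmul (illustrative only).

    x:       (tokens, in_features) activations in bf16/fp32.
    w_q:     (out_features, in_features) int8 quantized weights.
    w_scale: (out_features,) per-output-channel weight scales.
    """
    # Dynamically quantize activations per token to symmetric int8.
    x_absmax = jnp.max(jnp.abs(x), axis=-1, keepdims=True)
    x_scale = jnp.maximum(x_absmax, 1e-6) / 127.0
    x_q = jnp.clip(jnp.round(x / x_scale), -128, 127).astype(jnp.int8)

    # Integer matmul accumulated in int32, contracting the in_features dims.
    acc = jax.lax.dot_general(x_q, w_q, (((1,), (1,)), ((), ())),
                              preferred_element_type=jnp.int32)

    # Dequantize with both the activation and the weight scales.
    return acc.astype(jnp.float32) * x_scale * w_scale
```

The real path additionally handles fused QKV/gate-up layers by splitting the result with `slice_sharded_tensor_for_concatenation` according to `linear_config.output_sizes`, as the `_apply_fused` hunk above shows.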
tpu_inference/layers/vllm/quantization/{common.py → configs.py}
RENAMED
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import torchax
 from jax.sharding import Mesh, PartitionSpec
 from vllm.config import VllmConfig
@@ -11,9 +25,10 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                ReplicatedLinear,
                                                RowParallelLinear)

-from tpu_inference.layers.
+from tpu_inference.layers.common.sharding import ShardingAxisName
+from tpu_inference.layers.vllm.process_weights.linear_weights import \
     get_model_matmul_fusion_assignment
-from tpu_inference.utils import TPU_SECOND_LAST_MINOR
+from tpu_inference.utils import TPU_SECOND_LAST_MINOR, get_mesh_shape_product

 # yapf: enable

@@ -22,7 +37,7 @@ P = PartitionSpec
 logger = init_logger(__name__)


-class
+class VllmQuantLinearConfig:

     def __init__(self, vllm_config: VllmConfig, mesh: Mesh, layer: LinearBase):
         assert isinstance(layer, LinearBase)
@@ -35,14 +50,18 @@ class JaxCommonLinearConfig:
         self.input_sharding = None
         self.output_sharding = None

+        self.tp_size = get_mesh_shape_product(self.mesh,
+                                              ShardingAxisName.MLP_TENSOR)
+
         if isinstance(layer, RowParallelLinear):
-            self.weight_sharding = P(None,
+            self.weight_sharding = P(None, ShardingAxisName.ATTN_HEAD)
             if self.enable_sp:
-                self.output_sharding = P(
+                self.output_sharding = P(ShardingAxisName.MLP_TENSOR, None)
         elif isinstance(layer, ColumnParallelLinear):
-            self.weight_sharding = P(
+            self.weight_sharding = P(ShardingAxisName.ATTN_HEAD, None)
+
             if self.enable_sp:
-                self.input_sharding = P(
+                self.input_sharding = P(ShardingAxisName.MLP_TENSOR, None)

         if isinstance(layer, MergedColumnParallelLinear) or isinstance(
                 layer, QKVParallelLinear):
@@ -61,35 +80,28 @@ class JaxCommonLinearConfig:
                 " bad performance.", type(layer))

         self.bias_sharding = P(self.weight_sharding[0])
-
-
-            for axis in self.weight_sharding[0]:
-                self.n_shards *= self.mesh.shape.get(axis, 1)
-        else:
-            self.n_shards = self.mesh.shape.get(self.weight_sharding[0], 1)
+        self.n_shards = get_mesh_shape_product(self.mesh,
+                                               self.weight_sharding[0])

     def get_input_sharding(self, x: torchax.tensor.Tensor):
-        if self.enable_sp:
-
-
-
-
-
-        return None
+        if not self.enable_sp:
+            return None
+        token_num = x.shape[0]
+        # NOTE(chengjiyao): make sure the sharded token_num is larger than TPU_SECOND_LAST_MINOR
+        if token_num // self.tp_size < TPU_SECOND_LAST_MINOR:
+            return None
         return self.input_sharding

     def get_output_sharding(self, x: torchax.tensor.Tensor):
         if self.enable_sp:
             token_num = x.shape[0]
             # NOTE(chengjiyao): make sure the sharded token_num is larger than TPU_SECOND_LAST_MINOR
-            if token_num // self.
-                return self.output_sharding
-            else:
+            if token_num // self.tp_size < TPU_SECOND_LAST_MINOR:
                 return None
         return self.output_sharding


-class
+class VllmQuantConfig:
     vllm_config: VllmConfig
     mesh: Mesh

@@ -98,9 +110,9 @@ class JaxCommonConfig:
         cls.vllm_config = vllm_config
         cls.mesh = mesh

-    def get_linear_config(self, layer: LinearBase) ->
+    def get_linear_config(self, layer: LinearBase) -> VllmQuantLinearConfig:
         assert isinstance(layer, LinearBase)
-        return
+        return VllmQuantLinearConfig(self.vllm_config, self.mesh, layer)

     def get_moe_config(self, layer: FusedMoE) -> FusedMoEConfig:
         assert isinstance(layer, FusedMoE)
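Alongside the rename, the hand-rolled axis-size loop is replaced by a `get_mesh_shape_product` helper imported from `tpu_inference.utils` (see the `tpu_inference/utils.py +31 -30` entry in the file list). Its implementation is not shown in this diff; a minimal sketch consistent with the removed inline loop might look like the following, where the handling of a single axis name versus a tuple of names is an assumption:

```python
from typing import Sequence, Union

from jax.sharding import Mesh


def get_mesh_shape_product(
        mesh: Mesh, axis_names: Union[str, Sequence[str], None]) -> int:
    """Product of the mesh axis sizes named by `axis_names` (sketch only).

    Mirrors the removed inline loop: axes missing from the mesh, or a
    None spec, contribute a factor of 1.
    """
    if axis_names is None:
        return 1
    if isinstance(axis_names, str):
        axis_names = (axis_names,)
    size = 1
    for axis in axis_names:
        size *= mesh.shape.get(axis, 1)  # Mesh.shape maps axis name -> size
    return size
```

With such a helper, `n_shards` and the new `tp_size` are computed uniformly whether the sharding spec names a single mesh axis or a tuple of axes.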
tpu_inference/layers/vllm/quantization/fp8.py
ADDED
@@ -0,0 +1,119 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Union
+
+import jax
+import torch
+from jax.sharding import PartitionSpec
+from vllm.model_executor.layers.fused_moe.layer import FusedMoE
+from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+from vllm.model_executor.layers.quantization import \
+    register_quantization_config
+from vllm.model_executor.layers.quantization.base_config import \
+    QuantizeMethodBase
+from vllm.model_executor.layers.quantization.fp8 import (Fp8Config,
+                                                         Fp8LinearMethod)
+from vllm.model_executor.layers.quantization.utils.quant_utils import \
+    is_layer_skipped
+
+from tpu_inference.layers.common.quant_methods import FP8, get_tpu_quant_method
+from tpu_inference.layers.vllm.quantization.configs import (
+    VllmQuantConfig, VllmQuantLinearConfig)
+from tpu_inference.layers.vllm.quantization.unquantized import \
+    VllmUnquantizedLinearMethod
+from tpu_inference.logger import init_logger
+
+P = PartitionSpec
+
+logger = init_logger(__name__)
+
+
+@register_quantization_config(get_tpu_quant_method(FP8))
+class VllmFp8Config(Fp8Config, VllmQuantConfig):
+
+    @classmethod
+    def get_name(cls):
+        return FP8
+
+    def get_supported_act_dtypes(self) -> list[torch.dtype]:
+        return [torch.bfloat16]
+
+    def get_quant_method(
+            self, layer: torch.nn.Module, prefix: str
+    ) -> Optional[Union["LinearMethodBase", "QuantizeMethodBase"]]:
+        if isinstance(layer, LinearBase):
+            linear_config = self.get_linear_config(layer)
+            if is_layer_skipped(prefix, self.ignored_layers):
+                return VllmUnquantizedLinearMethod(linear_config)
+            return VllmFp8LinearMethod(self, linear_config)
+        elif isinstance(layer, FusedMoE):
+            raise NotImplementedError(
+                "FP8 FusedMoE is currently not supported in torchax-jax")
+        return None
+
+
+class VllmFp8LinearMethod(Fp8LinearMethod):
+
+    def __init__(self, quant_config: VllmFp8Config,
+                 jax_config: VllmQuantLinearConfig):
+        super().__init__(quant_config)
+        self.jax_config = jax_config
+        self._configure_sharding()
+
+    def _configure_sharding(self) -> None:
+
+        raise NotImplementedError(
+            "Configure PartitionSpec for weight_sharding and scale_sharding "
+            "based on layer type (RowParallel/ColumnParallel)")
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+
+        raise NotImplementedError(
+            "Convert layer.weight, layer.weight_scale, and optionally "
+            "layer.input_scale and layer.bias from torch tensors to JAX arrays "
+            "using torch_to_jax_param() with appropriate sharding")
+
+    def apply(self,
+              layer: torch.nn.Module,
+              x: torch.Tensor,
+              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+
+        with jax.named_scope(layer._get_name()):
+            if self.jax_config.fuse_matmuls:
+                out = self._apply_fused(layer, x, bias)
+            else:
+                out = self._apply_split(layer, x, bias)
+
+        return out
+
+    def _apply_fused(self,
+                     layer: torch.nn.Module,
+                     x: torch.Tensor,
+                     bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+
+        raise NotImplementedError(
+            "Implement single matmul for fused outputs: "
+            "quantize input to fp8, perform fp8 matmul with weight and scales, "
+            "dequantize output, and add bias if present")
+
+    def _apply_split(self,
+                     layer: torch.nn.Module,
+                     x: torch.Tensor,
+                     bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+
+        raise NotImplementedError(
+            "Implement separate matmuls per output partition: "
+            "split weight/scale by output_sizes, perform fp8 matmul for each, "
+            "concatenate results, and add bias if present")
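The new `VllmFp8LinearMethod` registers an FP8 path for the TPU backend but leaves `_configure_sharding`, `process_weights_after_loading`, and both apply paths as `NotImplementedError` stubs whose messages describe the intended behavior. As a rough illustration of what the described `_apply_split` path would compute, here is a hedged sketch that simply dequantizes fp8 weights to bfloat16 before the matmul; the helper names and the upcast are assumptions made for clarity, and a production kernel would keep fp8 operands for the TPU matrix units rather than upcasting:

```python
from typing import Optional, Sequence

import jax.numpy as jnp


def fp8_dequant_matmul(x: jnp.ndarray, w_fp8: jnp.ndarray,
                       w_scale: jnp.ndarray) -> jnp.ndarray:
    # Sketch: upcast float8_e4m3 weights and apply per-output-channel scales;
    # w_fp8 is (out_features, in_features), w_scale is (out_features,).
    w = w_fp8.astype(jnp.bfloat16) * w_scale[:, None].astype(jnp.bfloat16)
    return jnp.dot(x.astype(jnp.bfloat16), w.T)


def apply_split_sketch(x: jnp.ndarray,
                       weights: Sequence[jnp.ndarray],
                       scales: Sequence[jnp.ndarray],
                       bias: Optional[jnp.ndarray] = None) -> jnp.ndarray:
    # One matmul per output partition, concatenated along the output dim,
    # plus an optional bias, following the _apply_split stub's description.
    outs = [fp8_dequant_matmul(x, w, s) for w, s in zip(weights, scales)]
    out = jnp.concatenate(outs, axis=-1)
    if bias is not None:
        out = out + bias
    return out
```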