tpu-inference 0.11.1.dev202512030818__py3-none-any.whl → 0.13.0rc2.post7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of tpu-inference has been flagged as potentially problematic.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +14 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +312 -0
- tests/layers/vllm/test_unquantized.py +651 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +21 -3
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +78 -1
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +1 -43
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +14 -9
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +38 -7
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +17 -0
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +95 -78
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +26 -19
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +28 -5
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +210 -260
- tpu_inference/layers/vllm/linear_common.py +57 -22
- tpu_inference/layers/vllm/quantization/__init__.py +16 -0
- tpu_inference/layers/vllm/quantization/awq.py +15 -1
- tpu_inference/layers/vllm/quantization/common.py +33 -18
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +278 -209
- tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
- tpu_inference/layers/vllm/sharding.py +21 -4
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +74 -35
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +14 -0
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +18 -4
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +88 -25
- tpu_inference/models/jax/utils/weight_utils.py +39 -2
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +20 -3
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +47 -64
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +72 -37
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +45 -15
- tpu_inference/runner/lora_utils.py +14 -0
- tpu_inference/runner/multimodal_manager.py +15 -1
- tpu_inference/runner/persistent_batch_manager.py +14 -0
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +14 -0
- tpu_inference/runner/tpu_runner.py +41 -16
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +13 -0
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +42 -36
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +63 -50
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/METADATA +11 -9
- tpu_inference-0.13.0rc2.post7.dist-info/RECORD +261 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202512030818.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/top_level.txt +0 -0
@@ -1,38 +1,73 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import Optional, Union
 
 import jax
 import jax.numpy as jnp
 import torch
-from jax.experimental.shard_map import shard_map
 from jax.sharding import Mesh, NamedSharding
 from jax.sharding import PartitionSpec as P
 from torchax.interop import torch_view
 from torchax.ops.mappings import t2j
 
-from tpu_inference …
-
+from tpu_inference import envs
+from tpu_inference.kernels.quantized_matmul.kernel import (
+    quantized_matmul_kernel, xla_quantized_matmul)
 
 
 def sharded_quantized_matmul(x: jax.Array, w_q: jax.Array, w_s: jax.Array,
-                             mesh: Mesh, weight_sharding: P):
-    … (removed implementation, old lines 18-35, not rendered in this view)
+                             mesh: Mesh, weight_sharding: P) -> jax.Array:
+    """
+    Wrapper around the quantized matmul kernel.
+
+    Args:
+        x: Activation.
+        w_q: Weight quantized array. [n_output_features, n_input_features]
+        w_s: Weight quantization scale. [n_output_features]
+        mesh: Mesh to shard on.
+        weight_sharding: PartitionSpec for the weight tensor.
+
+    Returns:
+        Output of the quantized matmul.
+    """
+
+    # NOTE (jacobplatin/kyuyeunk) there have been numeric issues (concerning) NaNs
+    # with the kernel and thus we disable it for now.
+    if envs.ENABLE_QUANTIZED_MATMUL_KERNEL:
+        out_axis, in_axis = weight_sharding
+        x_sharding = P(None, in_axis)
+        scale_sharding = P(out_axis, )
+        out_sharding = P(None, out_axis)
+
+        x = jax.lax.with_sharding_constraint(x,
+                                             NamedSharding(mesh, x_sharding))
+
+        def wrapper(x, w_q, w_s):
+            output = quantized_matmul_kernel(x, w_q, w_s, x_q_dtype=w_q.dtype)
+            if in_axis:
+                output = jax.lax.psum(output, axis_name=in_axis)
+            return output
+
+        return jax.shard_map(wrapper,
+                             mesh=mesh,
+                             in_specs=(x_sharding, weight_sharding,
+                                       scale_sharding),
+                             out_specs=(out_sharding),
+                             check_vma=False)(x, w_q, w_s)
+    else:
+        return xla_quantized_matmul(x, w_q, w_s)
 
 
 def reorder_concatenated_tensor_for_sharding(concatenated_tensor: jax.Array,
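The rewritten sharded_quantized_matmul above shards both the activation and the weight along the contracted dimension, runs the quantized matmul on each shard, and sums the partial results with psum over the contracted mesh axis. The sketch below reproduces that pattern in plain JAX so it can run standalone: it dequantizes and uses jnp.dot in place of the Pallas quantized_matmul_kernel, and the mesh axis name "model" and the toy per-channel int8 quantizer are illustrative assumptions, not tpu-inference code.

```python
# Minimal sketch of the shard-the-contraction + psum pattern, assuming a
# 1-D mesh named "model" and a stand-in dequantize-then-dot kernel.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ("model",))


def quantize_per_channel(w: jax.Array):
    # w is [n_out, n_in]; the scale is per output channel, like w_s in the diff.
    w_s = jnp.max(jnp.abs(w), axis=1) / 127.0
    w_q = jnp.round(w / w_s[:, None]).astype(jnp.int8)
    return w_q, w_s


def sharded_matmul(x, w_q, w_s, weight_sharding=P(None, "model")):
    out_axis, in_axis = weight_sharding        # weight layout is [out, in]
    x_sharding = P(None, in_axis)              # shard x along the contracted dim
    scale_sharding = P(out_axis)
    out_sharding = P(None, out_axis)

    def wrapper(x, w_q, w_s):
        # Each device multiplies its slice of the contracted dimension...
        w = w_q.astype(jnp.bfloat16) * w_s[:, None]
        partial = jnp.dot(x, w.T)
        # ...and the partial products are summed across the contracted mesh axis.
        if in_axis:
            partial = jax.lax.psum(partial, axis_name=in_axis)
        return partial

    # On older JAX releases this entry point is jax.experimental.shard_map.shard_map.
    return jax.shard_map(wrapper,
                         mesh=mesh,
                         in_specs=(x_sharding, weight_sharding, scale_sharding),
                         out_specs=out_sharding)(x, w_q, w_s)


x = jnp.ones((8, 256), dtype=jnp.bfloat16)
w_q, w_s = quantize_per_channel(jnp.ones((512, 256), dtype=jnp.bfloat16))
print(sharded_matmul(x, w_q, w_s).shape)  # (8, 512)
```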
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import copy
 
 from jax.sharding import Mesh
@@ -10,6 +24,7 @@ from tpu_inference.layers.vllm.quantization.awq import VllmAWQConfig
 from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
 from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors import \
     VllmCompressedTensorsConfig  # noqa: E501
+from tpu_inference.layers.vllm.quantization.fp8 import VllmFp8Config
 from tpu_inference.layers.vllm.quantization.mxfp4 import VllmMxfp4Config
 from tpu_inference.layers.vllm.quantization.unquantized import \
     VllmUnquantizedConfig
@@ -23,6 +38,7 @@ def get_tpu_quantization_config(vllm_config: VllmConfig,
         None: VllmUnquantizedConfig,
         quant_methods.COMPRESSED_TENSORS: VllmCompressedTensorsConfig,
         quant_methods.AWQ: VllmAWQConfig,
+        quant_methods.FP8: VllmFp8Config,
         quant_methods.MXFP4: VllmMxfp4Config,
     }
     if model_config.quantization not in method_to_config:
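These hunks wire the new FP8 path into the TPU quantization dispatch: VllmFp8Config is imported and registered in the method-to-config table that get_tpu_quantization_config consults before rejecting unknown methods. The sketch below shows that table-dispatch pattern in isolation; the stub classes, string keys, and error message are illustrative stand-ins, not tpu-inference's actual constants.

```python
# Illustrative stand-ins for the config classes registered in the diff above;
# the real table is keyed by quant_methods constants, not raw strings.
from typing import Optional, Type


class VllmUnquantizedConfig: ...
class VllmFp8Config: ...
class VllmMxfp4Config: ...


METHOD_TO_CONFIG: dict[Optional[str], Type] = {
    None: VllmUnquantizedConfig,   # no quantization configured
    "fp8": VllmFp8Config,          # newly registered in this release
    "mxfp4": VllmMxfp4Config,
}


def resolve_quant_config(method: Optional[str]) -> Type:
    # Unknown methods fail fast instead of silently falling back to unquantized.
    if method not in METHOD_TO_CONFIG:
        raise NotImplementedError(
            f"Quantization method {method!r} is not supported on TPU.")
    return METHOD_TO_CONFIG[method]


print(resolve_quant_config("fp8").__name__)  # VllmFp8Config
```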
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import Optional, Union
 
 import jax
@@ -39,7 +53,7 @@ class VllmAWQConfig(AWQConfig, JaxCommonConfig):
 
     def get_supported_act_dtypes(self) -> list[torch.dtype]:
         # NOTE: AWQ checkpoint was quantized with float16. But on TPUs, using
-        # bfloat16 is …
+        # bfloat16 is significantly preferred over float16. This might lead to
         # some numeric output change.
         return [torch.bfloat16]
 
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import torchax
 from jax.sharding import Mesh, PartitionSpec
 from vllm.config import VllmConfig
@@ -11,9 +25,10 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                ReplicatedLinear,
                                                RowParallelLinear)
 
+from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.layers.vllm.linear_common import \
     get_model_matmul_fusion_assignment
-from tpu_inference.utils import TPU_SECOND_LAST_MINOR
+from tpu_inference.utils import TPU_SECOND_LAST_MINOR, get_mesh_shape_product
 
 # yapf: enable
 
@@ -31,18 +46,22 @@ class JaxCommonLinearConfig:
             self.output_sizes = [layer.output_size]
         self.weight_sharding = P(None, None)
         self.fuse_matmuls = True
-        self.…
+        self.enable_sp = vllm_config.compilation_config.pass_config.enable_sp
         self.input_sharding = None
         self.output_sharding = None
 
+        self.tp_size = get_mesh_shape_product(self.mesh,
+                                              ShardingAxisName.MLP_TENSOR)
+
         if isinstance(layer, RowParallelLinear):
-            self.weight_sharding = P(None, …
-            if self.…
-                self.output_sharding = P(…
+            self.weight_sharding = P(None, ShardingAxisName.ATTN_HEAD)
+            if self.enable_sp:
+                self.output_sharding = P(ShardingAxisName.MLP_TENSOR, None)
         elif isinstance(layer, ColumnParallelLinear):
-            self.weight_sharding = P(…
-            … (removed lines, old 44-45, not rendered in this view)
+            self.weight_sharding = P(ShardingAxisName.ATTN_HEAD, None)
+
+            if self.enable_sp:
+                self.input_sharding = P(ShardingAxisName.MLP_TENSOR, None)
 
         if isinstance(layer, MergedColumnParallelLinear) or isinstance(
                 layer, QKVParallelLinear):
@@ -61,28 +80,24 @@ class JaxCommonLinearConfig:
                 " bad performance.", type(layer))
 
         self.bias_sharding = P(self.weight_sharding[0])
-        … (removed lines, old 64-65, not rendered in this view)
-            for axis in self.weight_sharding[0]:
-                self.n_shards *= self.mesh.shape.get(axis, 1)
-        else:
-            self.n_shards = self.mesh.shape.get(self.weight_sharding[0], 1)
+        self.n_shards = get_mesh_shape_product(self.mesh,
+                                               self.weight_sharding[0])
 
     def get_input_sharding(self, x: torchax.tensor.Tensor):
-        if self.…
+        if self.enable_sp:
             token_num = x.shape[0]
             # NOTE(chengjiyao): make sure the sharded token_num is larger than TPU_SECOND_LAST_MINOR
-            if token_num // self.…
+            if token_num // self.tp_size >= TPU_SECOND_LAST_MINOR:
                 return self.input_sharding
             else:
                 return None
         return self.input_sharding
 
     def get_output_sharding(self, x: torchax.tensor.Tensor):
-        if self.…
+        if self.enable_sp:
             token_num = x.shape[0]
             # NOTE(chengjiyao): make sure the sharded token_num is larger than TPU_SECOND_LAST_MINOR
-            if token_num // self.…
+            if token_num // self.tp_size >= TPU_SECOND_LAST_MINOR:
                 return self.output_sharding
             else:
                 return None
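The JaxCommonLinearConfig changes above swap the two hand-rolled branches that multiplied mesh axis sizes for a single get_mesh_shape_product call, and use the resulting tp_size to decide whether sequence-parallel input/output shardings are worthwhile. A hedged sketch of what such a helper plausibly looks like (the real implementation lives in tpu_inference/utils.py and may differ):

```python
# Sketch only: mirrors the two removed branches in the diff (a loop for tuple
# shardings, a single lookup otherwise); tpu_inference's real helper may differ.
import math
from typing import Sequence, Union

from jax.sharding import Mesh


def get_mesh_shape_product(mesh: Mesh,
                           axes: Union[None, str, Sequence[str]]) -> int:
    """Product of the mesh sizes of the given axis name(s); 1 if unsharded."""
    if axes is None:
        return 1
    if isinstance(axes, str):
        return mesh.shape.get(axes, 1)
    return math.prod(mesh.shape.get(axis, 1) for axis in axes)
```

With tp_size computed this way, get_input_sharding and get_output_sharding only return the sequence-parallel spec when token_num // tp_size reaches TPU_SECOND_LAST_MINOR, so small batches stay unsharded instead of producing sub-tile shards.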
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import Optional
 
 import torch
@@ -20,7 +34,7 @@ from tpu_inference.layers.common.quant_methods import (COMPRESSED_TENSORS,
                                                         get_tpu_quant_method)
 from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
 from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors_moe import \
-    …
+    VllmCompressedTensorsMoEMethod
 from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_fp8 import \
     VllmCompressedTensorsW8A8Fp8
 from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_int8 import \
@@ -113,8 +127,9 @@ class VllmCompressedTensorsConfig(CompressedTensorsConfig, JaxCommonConfig):
             layer.scheme = scheme
             return CompressedTensorsLinearMethod(self)
         if isinstance(layer, FusedMoE):
-            … (removed lines, old 116-117, not rendered in this view)
+            layer.moe_config = self.get_moe_config(layer)
+            return VllmCompressedTensorsMoEMethod.get_moe_method(
+                self, layer, layer_name=prefix)
         if isinstance(layer, Attention):
             return CompressedTensorsKVCacheMethod(self)
         return None