tpu-inference 0.12.0.dev20251222__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_core_tpu.py +513 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +67 -0
- tests/core/test_dp_scheduler.py +724 -0
- tests/core/test_init.py +63 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +393 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +291 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +388 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +498 -0
- tests/kernels/quantized_matmul_kernel_test.py +159 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/layers/jax/test_qwix.py +969 -0
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +297 -0
- tests/layers/vllm/test_unquantized.py +621 -0
- tests/layers/vllm/utils.py +72 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +46 -0
- tests/lora/test_bgmv.py +57 -0
- tests/lora/test_layers.py +666 -0
- tests/lora/test_lora.py +147 -0
- tests/lora/test_lora_perf.py +67 -0
- tests/lora/utils.py +88 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +606 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +202 -0
- tests/runner/test_tpu_runner_dp.py +1033 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +215 -0
- tests/test_envs.py +280 -0
- tests/test_tpu_info.py +134 -0
- tests/test_utils.py +193 -0
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +67 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/core_tpu.py +786 -0
- tpu_inference/core/disagg_executor.py +118 -0
- tpu_inference/core/disagg_utils.py +49 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +814 -0
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +81 -0
- tpu_inference/distributed/tpu_connector.py +732 -0
- tpu_inference/distributed/utils.py +112 -0
- tpu_inference/env_override.py +9 -0
- tpu_inference/envs.py +191 -0
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +399 -0
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +272 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +1340 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +403 -0
- tpu_inference/layers/common/attention_metadata.py +48 -0
- tpu_inference/layers/common/binary_search.py +295 -0
- tpu_inference/layers/common/quant_methods.py +23 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +600 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +268 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
- tpu_inference/layers/jax/base.py +165 -0
- tpu_inference/layers/jax/constants.py +101 -0
- tpu_inference/layers/jax/layers.py +315 -0
- tpu_inference/layers/jax/misc.py +30 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
- tpu_inference/layers/jax/moe/moe.py +249 -0
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +294 -0
- tpu_inference/layers/jax/rope_interface.py +228 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
- tpu_inference/layers/jax/sample/sampling.py +110 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
- tpu_inference/layers/jax/transformer_block.py +121 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +221 -0
- tpu_inference/layers/vllm/fused_moe.py +502 -0
- tpu_inference/layers/vllm/linear_common.py +221 -0
- tpu_inference/layers/vllm/quantization/__init__.py +55 -0
- tpu_inference/layers/vllm/quantization/awq.py +221 -0
- tpu_inference/layers/vllm/quantization/common.py +124 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
- tpu_inference/layers/vllm/sharding.py +244 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +98 -0
- tpu_inference/lora/torch_punica_tpu.py +310 -0
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +520 -0
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +978 -0
- tpu_inference/models/jax/gpt_oss.py +508 -0
- tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
- tpu_inference/models/jax/llama3.py +436 -0
- tpu_inference/models/jax/llama4.py +643 -0
- tpu_inference/models/jax/llama_eagle3.py +350 -0
- tpu_inference/models/jax/llama_guard_4.py +375 -0
- tpu_inference/models/jax/qwen2.py +390 -0
- tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
- tpu_inference/models/jax/qwen3.py +318 -0
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +110 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
- tpu_inference/models/jax/utils/weight_utils.py +621 -0
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
- tpu_inference/platforms/__init__.py +16 -0
- tpu_inference/platforms/tpu_platform.py +258 -0
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/block_table.py +122 -0
- tpu_inference/runner/compilation_manager.py +890 -0
- tpu_inference/runner/input_batch.py +435 -0
- tpu_inference/runner/kv_cache.py +166 -0
- tpu_inference/runner/kv_cache_manager.py +508 -0
- tpu_inference/runner/lora_utils.py +106 -0
- tpu_inference/runner/multimodal_manager.py +231 -0
- tpu_inference/runner/persistent_batch_manager.py +296 -0
- tpu_inference/runner/speculative_decoding_manager.py +262 -0
- tpu_inference/runner/structured_decoding_manager.py +101 -0
- tpu_inference/runner/tpu_runner.py +1768 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +430 -0
- tpu_inference/tpu_info.py +92 -0
- tpu_inference/utils.py +345 -0
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +468 -0
- tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
- tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
- tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
- tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
tpu_inference/layers/vllm/linear_common.py

@@ -0,0 +1,221 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Union

import jax
import jax.numpy as jnp
import torch
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P
from torchax.interop import torch_view
from torchax.ops.mappings import t2j

from tpu_inference import envs
from tpu_inference.kernels.quantized_matmul.kernel import (
    quantized_matmul_kernel, xla_quantized_matmul)


def sharded_quantized_matmul(x: jax.Array, w_q: jax.Array, w_s: jax.Array,
                             mesh: Mesh, weight_sharding: P) -> jax.Array:
    """
    Wrapper around the quantized matmul kernel.

    Args:
        x: Activation.
        w_q: Quantized weight array. [n_output_features, n_input_features]
        w_s: Weight quantization scale. [n_output_features]
        mesh: Mesh to shard on.
        weight_sharding: PartitionSpec for the weight tensor.

    Returns:
        Output of the quantized matmul.
    """

    # NOTE (jacobplatin/kyuyeunk): there have been numeric issues (NaNs) with
    # the kernel, so we disable it for now.
    if envs.ENABLE_QUANTIZED_MATMUL_KERNEL:
        out_axis, in_axis = weight_sharding
        x_sharding = P(None, in_axis)
        scale_sharding = P(out_axis, )
        out_sharding = P(None, out_axis)

        x = jax.lax.with_sharding_constraint(x,
                                             NamedSharding(mesh, x_sharding))

        def wrapper(x, w_q, w_s):
            output = quantized_matmul_kernel(x, w_q, w_s, x_q_dtype=w_q.dtype)
            if in_axis:
                output = jax.lax.psum(output, axis_name=in_axis)
            return output

        return jax.shard_map(wrapper,
                             mesh=mesh,
                             in_specs=(x_sharding, weight_sharding,
                                       scale_sharding),
                             out_specs=(out_sharding),
                             check_vma=False)(x, w_q, w_s)
    else:
        return xla_quantized_matmul(x, w_q, w_s)


def reorder_concatenated_tensor_for_sharding(concatenated_tensor: jax.Array,
                                             split_sizes: list[int],
                                             n_shards: int, dim: int):
    """
    Reorder a replicated concatenated tensor such that, when sharded across
    multiple chips, each shard is a concatenation of the shards of the
    individual tensors.

    For example, let the concatenated_tensor be:
        AAAAAAAAAAAABBBBBBBBCCCC
        (12 As, 8 Bs, 4 Cs)
    and let split_sizes = [12, 8, 4] and n_shards = 4.
    The output is:
        AAABBCAAABBCAAABBCAAABBC
    In other words, the input tensor is reordered into 4 segments, each
    segment corresponding to one shard and containing AAABBC.

    Args:
        concatenated_tensor: the tensor, concatenated along the dimension
            specified by `dim`.
        split_sizes: each individual tensor's size along the dimension
            specified by `dim`.
        n_shards: number of shards.
        dim: the dimension along which concatenated_tensor is concatenated.
    """
    # Split the concatenated tensor into individual tensors.
    split_tensors = []
    start_offset = 0
    old_shape = concatenated_tensor.shape
    # The new shape ensures each split_tensor[i] maps to a tensor in the ith shard.
    new_shape = old_shape[:dim] + (n_shards, -1) + old_shape[dim + 1:]
    for split_size in split_sizes:
        split_tensor = jax.lax.slice_in_dim(concatenated_tensor,
                                            start_offset,
                                            start_offset + split_size,
                                            axis=dim)
        split_tensors.append(split_tensor.reshape(new_shape))
        start_offset += split_size
    # While keeping dim `dim` as the shard dim, concatenate along dim + 1 so
    # that the result, once flattened back, maps dim `dim` to the shard dim.
    reordered_tensor = jnp.concatenate(split_tensors, axis=dim + 1)
    return reordered_tensor.reshape(old_shape)


def slice_sharded_tensor_for_concatenation(sharded_tensor: jax.Array,
                                           split_sizes: list[int],
                                           n_shards: int):
    """
    Slice the input tensor, which is sharded across multiple chips (on the
    last dim), into individual tensors with the same sharding.

    For example, let the sharded_tensor be:
        AAABBC | AAABBC | AAABBC | AAABBC
        Shard0   Shard1   Shard2   Shard3
    and let split_sizes = [12, 8, 4] and n_shards = 4.
    The output is a list of 3 tensors:
        AAA | AAA | AAA | AAA
        BB  | BB  | BB  | BB
        C   | C   | C   | C
        Shard0 Shard1 Shard2 Shard3
    In other words, each individual tensor is a slice of the input tensor
    with the same sharding.

    Args:
        sharded_tensor: the input tensor, sharded on the last dim.
        split_sizes: each individual tensor's size on the last dim.
        n_shards: number of shards.
    """
    new_shape = sharded_tensor.shape[:-1] + (n_shards, -1)
    # The new shape ensures each sharded_tensor[:, i] maps to a tensor in the ith shard.
    sharded_tensor = sharded_tensor.reshape(new_shape)

    split_tensors = []
    start_offset = 0
    for split_size in split_sizes:
        assert split_size % n_shards == 0
        sz = split_size // n_shards  # size of this split tensor per shard
        end_offset = start_offset + sz
        # Because we slice over the last dim, the sharding dim remains intact.
        # Therefore, the splitting happens locally.
        split_tensor = sharded_tensor[..., start_offset:end_offset]
        split_tensors.append(split_tensor.reshape(new_shape[:-2] + (-1, )))
        start_offset = end_offset

    return split_tensors


def torch_to_jax_param(
    tensor: torch.Tensor,
    sharding: NamedSharding,
    output_sizes: Optional[list[int]],
    n_shards: int,
    fused: bool,
    dim: int = 0,
    jax_dtype: Optional[jnp.dtype] = None,
) -> Union[torch.nn.Parameter, torch.nn.ParameterList]:
    if output_sizes is None:
        output_sizes = [tensor.shape[0]]

    tensor = t2j(tensor, use_dlpack=False)
    if jax_dtype:
        tensor = tensor.astype(jax_dtype)

    if fused:
        tensor = reorder_concatenated_tensor_for_sharding(
            tensor, output_sizes, n_shards, dim)
        tensor = jax.device_put(tensor, sharding)
        param = torch.nn.Parameter(torch_view(tensor), requires_grad=False)
    else:
        tensors = []
        start_offset = 0
        for size in output_sizes:
            end_offset = start_offset + size

            tensor_split = jax.lax.slice_in_dim(tensor,
                                                start_offset,
                                                end_offset,
                                                axis=dim)
            tensor_split = jax.device_put(tensor_split, sharding)
            tensor_split = torch.nn.Parameter(torch_view(tensor_split),
                                              requires_grad=False)
            tensors.append(tensor_split)

            start_offset = end_offset
        param = torch.nn.ParameterList(tensors)
    return param


MODEL_MATMUL_FUSION_TRUTH_TABLE = {
    ("Qwen/Qwen2.5-7B-Instruct", 1024, 1, "QKVParallelLinear"):
    True,
    ("Qwen/Qwen2.5-7B-Instruct", 1024, 1, "MergedColumnParallelLinear"):
    False,
    ("Qwen/Qwen2.5-7B-Instruct", 2048, 1, "QKVParallelLinear"):
    False,
    ("Qwen/Qwen2.5-7B-Instruct", 2048, 1, "MergedColumnParallelLinear"):
    False,
    ("meta-llama/Llama-3.1-8B-Instruct", 1024, 1, "QKVParallelLinear"):
    False,
    ("meta-llama/Llama-3.1-8B-Instruct", 1024, 1, "MergedColumnParallelLinear"):
    False,
    ("meta-llama/Llama-3.1-8B-Instruct", 2048, 1, "QKVParallelLinear"):
    False,
    ("meta-llama/Llama-3.1-8B-Instruct", 2048, 1, "MergedColumnParallelLinear"):
    False,
    ("RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", 1024, 1, "QKVParallelLinear"):
    False,
    ("RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", 1024, 1, "MergedColumnParallelLinear"):
    False,
    ("RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", 2048, 1, "QKVParallelLinear"):
    False,
    ("RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", 2048, 1, "MergedColumnParallelLinear"):
    False,
}


def get_model_matmul_fusion_assignment(model_name: str, batch_size: int,
                                       tp_size: int, layer_name: str):
    key = (model_name, batch_size, tp_size, layer_name)
    return MODEL_MATMUL_FUSION_TRUTH_TABLE.get(key, True)
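The reorder/slice pair above is easiest to see on a concrete array. Below is a minimal, standalone illustration (plain JAX on CPU, no tpu_inference import needed); the helper names are local to this sketch and simply mirror the two functions for a 1-D tensor with the documented [12, 8, 4] / 4-shard example.

import jax
import jax.numpy as jnp

def reorder_for_sharding(concatenated, split_sizes, n_shards, dim=0):
    # Mirrors reorder_concatenated_tensor_for_sharding: split each tensor,
    # expose a shard axis, concatenate per shard, then flatten back.
    old_shape = concatenated.shape
    new_shape = old_shape[:dim] + (n_shards, -1) + old_shape[dim + 1:]
    parts, start = [], 0
    for size in split_sizes:
        parts.append(
            jax.lax.slice_in_dim(concatenated, start, start + size,
                                 axis=dim).reshape(new_shape))
        start += size
    return jnp.concatenate(parts, axis=dim + 1).reshape(old_shape)

def slice_for_concatenation(sharded, split_sizes, n_shards):
    # Mirrors slice_sharded_tensor_for_concatenation for a 1-D tensor.
    sharded = sharded.reshape((n_shards, -1))
    outs, start = [], 0
    for size in split_sizes:
        per_shard = size // n_shards
        outs.append(sharded[..., start:start + per_shard].reshape(-1))
        start += per_shard
    return outs

# 12 "A"s (0), 8 "B"s (1) and 4 "C"s (2), concatenated on dim 0.
x = jnp.concatenate([jnp.full(12, 0), jnp.full(8, 1), jnp.full(4, 2)])
reordered = reorder_for_sharding(x, [12, 8, 4], n_shards=4)
print(reordered)  # [0 0 0 1 1 2] repeated 4 times, i.e. AAABBC per shard
print(slice_for_concatenation(reordered, [12, 8, 4], n_shards=4))
# -> 12 zeros, 8 ones, 4 twos: each original block recovered, per-shard layout kept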
tpu_inference/layers/vllm/quantization/__init__.py

@@ -0,0 +1,55 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
from typing import Optional

from jax.sharding import Mesh
from vllm.config import VllmConfig
from vllm.model_executor.layers.quantization.base_config import \
    QuantizationConfig

from tpu_inference.layers.common import quant_methods
from tpu_inference.layers.vllm.quantization.awq import VllmAWQConfig
from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors import \
    VllmCompressedTensorsConfig  # noqa: E501
from tpu_inference.layers.vllm.quantization.fp8 import VllmFp8Config
from tpu_inference.layers.vllm.quantization.mxfp4 import VllmMxfp4Config
from tpu_inference.layers.vllm.quantization.unquantized import \
    VllmUnquantizedConfig


def get_tpu_quantization_config(vllm_config: VllmConfig,
                                mesh: Mesh) -> QuantizationConfig:
    model_config = copy.deepcopy(vllm_config.model_config)
    # TODO(kyuyeunk): Add support for "tpu_int8".
    method_to_config: dict[Optional[str], type[QuantizationConfig]] = {
        None: VllmUnquantizedConfig,
        quant_methods.COMPRESSED_TENSORS: VllmCompressedTensorsConfig,
        quant_methods.AWQ: VllmAWQConfig,
        quant_methods.FP8: VllmFp8Config,
        quant_methods.MXFP4: VllmMxfp4Config,
    }
    if model_config.quantization not in method_to_config:
        raise NotImplementedError(
            f"{model_config.quantization} quantization method not supported."
            f" Supported methods are {method_to_config.keys()}")
    quant_config = method_to_config[model_config.quantization]
    assert issubclass(quant_config, JaxCommonConfig)
    quant_config.set_configs(vllm_config, mesh)

    model_config.quantization = quant_methods.get_tpu_quant_method(
        quant_config.get_name())
    return VllmConfig.get_quantization_config(model_config,
                                              vllm_config.load_config)
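Downstream callers are expected to hand this function a fully built VllmConfig together with the JAX device mesh the model will run on. A hedged usage sketch follows; the single-host mesh construction and the ("data", "model") axis names are illustrative assumptions, not something this module prescribes.

import jax
import numpy as np
from jax.sharding import Mesh
from vllm.config import VllmConfig
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig

from tpu_inference.layers.vllm.quantization import get_tpu_quantization_config


def build_tpu_quant_config(vllm_config: VllmConfig) -> QuantizationConfig:
    # Put every local device on the "model" (tensor-parallel) axis.
    devices = np.array(jax.devices()).reshape(1, -1)
    mesh = Mesh(devices, ("data", "model"))
    # Dispatches on vllm_config.model_config.quantization: None maps to
    # VllmUnquantizedConfig, quant_methods.AWQ to VllmAWQConfig, and so on;
    # unknown methods raise NotImplementedError.
    return get_tpu_quantization_config(vllm_config, mesh)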
tpu_inference/layers/vllm/quantization/awq.py

@@ -0,0 +1,221 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Union

import jax
import jax.numpy as jnp
import torch
from jax.sharding import NamedSharding, PartitionSpec
from torchax.interop import jax_view, torch_view
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import \
    register_quantization_config
from vllm.model_executor.layers.quantization.awq import (AWQConfig,
                                                         AWQLinearMethod)
from vllm.model_executor.layers.quantization.base_config import \
    QuantizeMethodBase
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    is_layer_skipped, unpack_quantized_values_into_int32)
from vllm.scalar_type import scalar_types

from tpu_inference.layers.common.quant_methods import AWQ, get_tpu_quant_method
from tpu_inference.layers.vllm.linear_common import (
    slice_sharded_tensor_for_concatenation, torch_to_jax_param)
from tpu_inference.layers.vllm.quantization.common import (
    JaxCommonConfig, JaxCommonLinearConfig)
from tpu_inference.layers.vllm.quantization.unquantized import \
    VllmUnquantizedLinearMethod

P = PartitionSpec
logger = init_logger(__name__)


@register_quantization_config(get_tpu_quant_method(AWQ))
class VllmAWQConfig(AWQConfig, JaxCommonConfig):

    @classmethod
    def get_name(cls):
        return AWQ

    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        # NOTE: AWQ checkpoints are quantized with float16, but on TPUs using
        # bfloat16 is significantly preferred over float16. This might lead to
        # small changes in the numeric output.
        return [torch.bfloat16]

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional[Union["LinearMethodBase", "QuantizeMethodBase"]]:
        if isinstance(layer, LinearBase):
            linear_config = self.get_linear_config(layer)
            if is_layer_skipped(prefix, self.modules_to_not_convert):
                return VllmUnquantizedLinearMethod(linear_config)
            return VllmAWQLinearMethod(self, linear_config)
        elif isinstance(layer, FusedMoE):
            raise NotImplementedError(
                "AWQ FusedMoE is currently not supported in torchax-jax")
        return None


class VllmAWQLinearMethod(AWQLinearMethod):

    def __init__(self, quant_config: VllmAWQConfig,
                 jax_config: JaxCommonLinearConfig):
        super().__init__(quant_config)
        self.jax_config = jax_config

        out_sharding, in_sharding = self.jax_config.weight_sharding[:]
        self.jax_config.weight_sharding = P(in_sharding, None, out_sharding)
        self.jax_config.scale_sharding = P(in_sharding, out_sharding)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        qweight = layer.qweight
        qweight = unpack_awq_weight(qweight, qweight.packed_dim)

        group_size = self.quant_config.group_size
        # Reshape so that each qweight[i] was quantized with the same scales[i].
        qweight = qweight.reshape((-1, group_size, layer.output_size))
        qweight = torch_to_jax_param(qweight,
                                     NamedSharding(
                                         self.jax_config.mesh,
                                         self.jax_config.weight_sharding),
                                     self.jax_config.output_sizes,
                                     self.jax_config.n_shards,
                                     self.jax_config.fuse_matmuls,
                                     dim=2,
                                     jax_dtype=jnp.uint4)
        delattr(layer, "qweight")
        layer.qweight = qweight

        qzeros = layer.qzeros
        qzeros = unpack_awq_weight(qzeros, qzeros.packed_dim)
        qzeros = torch_to_jax_param(qzeros,
                                    NamedSharding(
                                        self.jax_config.mesh,
                                        self.jax_config.scale_sharding),
                                    self.jax_config.output_sizes,
                                    self.jax_config.n_shards,
                                    self.jax_config.fuse_matmuls,
                                    dim=1,
                                    jax_dtype=jnp.uint4)
        delattr(layer, "qzeros")
        layer.qzeros = qzeros

        scales = torch_to_jax_param(layer.scales,
                                    NamedSharding(
                                        self.jax_config.mesh,
                                        self.jax_config.scale_sharding),
                                    self.jax_config.output_sizes,
                                    self.jax_config.n_shards,
                                    self.jax_config.fuse_matmuls,
                                    dim=1)
        delattr(layer, "scales")
        layer.scales = scales

        if layer.bias is not None and not layer.skip_bias_add:
            if layer.return_bias:
                logger.warning_once("Bias might return incorrect value.")

            bias = torch_to_jax_param(
                layer.bias,
                NamedSharding(self.jax_config.mesh,
                              self.jax_config.bias_sharding),
                self.jax_config.output_sizes,
                self.jax_config.n_shards,
                self.jax_config.fuse_matmuls,
            )
            delattr(layer, "bias")
            layer.bias = bias

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:

        with jax.named_scope(layer._get_name()):
            if self.jax_config.fuse_matmuls:
                out = self._apply_fused(layer, x, bias)
            else:
                out = self._apply_split(layer, x, bias)

        return out

    def _apply_fused(self,
                     layer: torch.nn.Module,
                     x: torch.Tensor,
                     bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        x_jax = jax_view(x)

        qweight = jax_view(layer.qweight)
        qzeros = jnp.expand_dims(jax_view(layer.qzeros), 1)
        scales = jnp.expand_dims(jax_view(layer.scales), 1)

        qweight = qweight.astype(jnp.int8)
        qzeros = qzeros.astype(jnp.int8)

        weight = (qweight - qzeros) * scales
        weight = weight.reshape((-1, weight.shape[-1]))
        outs = jnp.einsum("bd,df->bf", x_jax, weight)

        if bias is not None and not layer.skip_bias_add:
            outs += bias.jax()

        outs = slice_sharded_tensor_for_concatenation(
            outs, self.jax_config.output_sizes, self.jax_config.n_shards)
        out = jnp.concatenate(outs, axis=-1)
        return torch_view(out)

    def _apply_split(self,
                     layer: torch.nn.Module,
                     x: torch.Tensor,
                     bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        assert isinstance(layer.qweight, torch.nn.ParameterList)

        x_jax = jax_view(x)
        params = zip(layer.qweight, layer.qzeros, layer.scales)
        outs = []
        for i, (qweight, qzeros, scales) in enumerate(params):
            qweight = jax_view(qweight)
            scales = jnp.expand_dims(jax_view(scales), 1)
            qzeros = jnp.expand_dims(jax_view(qzeros), 1)

            qweight = qweight.astype(jnp.int8)
            qzeros = qzeros.astype(jnp.int8)

            weight = (qweight - qzeros) * scales
            weight = weight.reshape((-1, weight.shape[-1]))
            out = jnp.einsum("bd,df->bf", x_jax, weight)

            if bias is not None and not layer.skip_bias_add:
                out += jax_view(bias[i])

            outs.append(out)
        out = jnp.concatenate(outs, axis=-1)
        return torch_view(out)


def unpack_awq_weight(weight: torch.Tensor, packed_dim: int):
    weight = unpack_quantized_values_into_int32(weight, scalar_types.uint4,
                                                packed_dim)

    # AWQ packs 8 uint4 values into 32 bits in this order: (0, 2, 4, 6, 1, 3, 5, 7).
    # The following index list maps the AWQ order back into ascending order.
    reverse_awq_order = (0, 4, 1, 5, 2, 6, 3, 7)

    orig_shape = weight.shape
    weight = weight.reshape(orig_shape[:-1] + (-1, 8))
    return weight[..., reverse_awq_order].reshape(orig_shape)
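The dequantization arithmetic in _apply_fused/_apply_split is easiest to check with toy shapes. Below is a minimal standalone sketch (plain jax.numpy; the group size, feature counts, and values are made up for illustration) of the group-wise (q - zero) * scale expansion followed by the same einsum contraction.

import jax.numpy as jnp

group_size, in_features, out_features = 4, 8, 2
# Unpacked uint4 weights, reshaped the way process_weights_after_loading does:
# (num_groups, group_size, out_features), so group g shares qzeros[g]/scales[g].
qweight = jnp.arange(in_features * out_features, dtype=jnp.int8).reshape(
    in_features // group_size, group_size, out_features)
qzeros = jnp.array([[3, 5], [2, 7]], dtype=jnp.int8)   # (num_groups, out_features)
scales = jnp.array([[0.10, 0.20], [0.05, 0.25]])       # (num_groups, out_features)

# Same math as above: broadcast the per-group zero point and scale over the
# group dimension, then flatten back to an (in_features, out_features) matrix.
weight = (qweight - qzeros[:, None, :]) * scales[:, None, :]
weight = weight.reshape(-1, out_features)

x = jnp.ones((1, in_features))
print(jnp.einsum("bd,df->bf", x, weight))  # one dequantized matmul output row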
tpu_inference/layers/vllm/quantization/common.py

@@ -0,0 +1,124 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torchax
from jax.sharding import Mesh, PartitionSpec
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEConfig
# yapf: disable
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               LinearBase,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)

from tpu_inference.layers.vllm.linear_common import \
    get_model_matmul_fusion_assignment
from tpu_inference.utils import TPU_SECOND_LAST_MINOR

# yapf: enable

P = PartitionSpec

logger = init_logger(__name__)


class JaxCommonLinearConfig:

    def __init__(self, vllm_config: VllmConfig, mesh: Mesh, layer: LinearBase):
        assert isinstance(layer, LinearBase)

        self.mesh = mesh
        self.output_sizes = [layer.output_size]
        self.weight_sharding = P(None, None)
        self.fuse_matmuls = True
        self.enable_sp = vllm_config.compilation_config.pass_config.enable_sp
        self.input_sharding = None
        self.output_sharding = None

        if isinstance(layer, RowParallelLinear):
            self.weight_sharding = P(None, "model")
            if self.enable_sp:
                self.output_sharding = P("model", None)
        elif isinstance(layer, ColumnParallelLinear):
            self.weight_sharding = P("model", None)
            if self.enable_sp:
                self.input_sharding = P("model", None)

            if isinstance(layer, MergedColumnParallelLinear) or isinstance(
                    layer, QKVParallelLinear):
                self.output_sizes = layer.output_sizes

                self.fuse_matmuls = get_model_matmul_fusion_assignment(
                    vllm_config.model_config.model,
                    vllm_config.scheduler_config.max_num_batched_tokens,
                    vllm_config.parallel_config.tensor_parallel_size,
                    layer._get_name())
        elif isinstance(layer, ReplicatedLinear):
            self.weight_sharding = P(None, None)
        else:
            logger.warning(
                "Unsupported linear layer type %s. This can potentially "
                "yield bad performance.", type(layer))

        self.bias_sharding = P(self.weight_sharding[0])
        if isinstance(self.weight_sharding[0], tuple):
            self.n_shards = 1
            for axis in self.weight_sharding[0]:
                self.n_shards *= self.mesh.shape.get(axis, 1)
        else:
            self.n_shards = self.mesh.shape.get(self.weight_sharding[0], 1)

    def get_input_sharding(self, x: torchax.tensor.Tensor):
        if self.enable_sp:
            token_num = x.shape[0]
            # NOTE(chengjiyao): make sure the sharded token_num is at least
            # TPU_SECOND_LAST_MINOR.
            if token_num // self.mesh.shape["model"] >= TPU_SECOND_LAST_MINOR:
                return self.input_sharding
            else:
                return None
        return self.input_sharding

    def get_output_sharding(self, x: torchax.tensor.Tensor):
        if self.enable_sp:
            token_num = x.shape[0]
            # NOTE(chengjiyao): make sure the sharded token_num is at least
            # TPU_SECOND_LAST_MINOR.
            if token_num // self.mesh.shape["model"] >= TPU_SECOND_LAST_MINOR:
                return self.output_sharding
            else:
                return None
        return self.output_sharding


class JaxCommonConfig:
    vllm_config: VllmConfig
    mesh: Mesh

    @classmethod
    def set_configs(cls, vllm_config: VllmConfig, mesh: Mesh):
        cls.vllm_config = vllm_config
        cls.mesh = mesh

    def get_linear_config(self, layer: LinearBase) -> JaxCommonLinearConfig:
        assert isinstance(layer, LinearBase)
        return JaxCommonLinearConfig(self.vllm_config, self.mesh, layer)

    def get_moe_config(self, layer: FusedMoE) -> FusedMoEConfig:
        assert isinstance(layer, FusedMoE)
        moe_config = layer.moe_config
        use_ep = self.vllm_config.parallel_config.enable_expert_parallel
        moe_config.moe_parallel_config.use_ep = use_ep
        return moe_config
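The n_shards computation at the end of JaxCommonLinearConfig.__init__ depends only on the first entry of the weight PartitionSpec and the mesh shape. The sketch below isolates that logic; the 2x4 mesh, its axis names, and the host-device-count flag are assumptions chosen so the demo runs on a CPU-only machine, not values taken from the package.

import os
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=8")  # CPU-only demo

import numpy as np
import jax
from jax.sharding import Mesh, PartitionSpec as P

devices = np.array(jax.devices()).reshape(2, 4)
mesh = Mesh(devices, ("data", "model"))

def n_shards_for(weight_sharding: P, mesh: Mesh) -> int:
    # Mirrors the logic above: the first PartitionSpec entry shards dim 0 of
    # the weight; None or axes missing from the mesh contribute a factor of 1.
    first = weight_sharding[0]
    axes = first if isinstance(first, tuple) else (first,)
    n = 1
    for axis in axes:
        n *= mesh.shape.get(axis, 1)
    return n

print(n_shards_for(P("model", None), mesh))            # 4: column-parallel weight
print(n_shards_for(P(None, "model"), mesh))            # 1: row-parallel weight
print(n_shards_for(P(("data", "model"), None), mesh))  # 8: dim 0 sharded over both axes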
tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py

@@ -0,0 +1,13 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.