tpu-inference 0.11.1.dev202512030818__py3-none-any.whl → 0.13.0rc2.post7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +14 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +312 -0
- tests/layers/vllm/test_unquantized.py +651 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +21 -3
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +78 -1
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +1 -43
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +14 -9
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +38 -7
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +17 -0
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +95 -78
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +26 -19
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +28 -5
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +210 -260
- tpu_inference/layers/vllm/linear_common.py +57 -22
- tpu_inference/layers/vllm/quantization/__init__.py +16 -0
- tpu_inference/layers/vllm/quantization/awq.py +15 -1
- tpu_inference/layers/vllm/quantization/common.py +33 -18
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +278 -209
- tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
- tpu_inference/layers/vllm/sharding.py +21 -4
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +74 -35
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +14 -0
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +18 -4
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +88 -25
- tpu_inference/models/jax/utils/weight_utils.py +39 -2
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +20 -3
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +47 -64
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +72 -37
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +45 -15
- tpu_inference/runner/lora_utils.py +14 -0
- tpu_inference/runner/multimodal_manager.py +15 -1
- tpu_inference/runner/persistent_batch_manager.py +14 -0
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +14 -0
- tpu_inference/runner/tpu_runner.py +41 -16
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +13 -0
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +42 -36
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +63 -50
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/METADATA +11 -9
- tpu_inference-0.13.0rc2.post7.dist-info/RECORD +261 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202512030818.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/top_level.txt +0 -0
tpu_inference/layers/vllm/quantization/unquantized.py
CHANGED

@@ -1,4 +1,18 @@
-
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Optional, Union
 
 import jax
 import jax.numpy as jnp
@@ -25,17 +39,23 @@ from tpu_inference import envs
 from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe
 from tpu_inference.layers.common.quant_methods import (UNQUANTIZED,
                                                        get_tpu_quant_method)
-from tpu_inference.layers.
+from tpu_inference.layers.common.sharding import ShardingAxisName
+from tpu_inference.layers.vllm.fused_moe import fused_moe_func
 from tpu_inference.layers.vllm.linear_common import (
     reorder_concatenated_tensor_for_sharding,
     slice_sharded_tensor_for_concatenation, torch_to_jax_param)
 from tpu_inference.layers.vllm.quantization.common import (
     JaxCommonConfig, JaxCommonLinearConfig)
+from tpu_inference.utils import get_mesh_shape_product
 
 P = PartitionSpec
 logger = init_logger(__name__)
 
 
+def align_to(a, b):
+    return (a + b - 1) // b * b
+
+
 @register_quantization_config(get_tpu_quant_method(UNQUANTIZED))
 class VllmUnquantizedConfig(QuantizationConfig, JaxCommonConfig):
 
@@ -168,7 +188,7 @@ class VllmUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
                  ep_axis_name: str = 'model'):
         super().__init__(moe)
         self.mesh = mesh
-        self.use_kernel = envs.USE_MOE_EP_KERNEL
+        self.use_kernel = envs.USE_MOE_EP_KERNEL and moe.use_ep
         self.ep_axis_name = ep_axis_name
         # TODO: Use autotune table once we have it.
         self.block_size = {
@@ -196,6 +216,8 @@ class VllmUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
         w13_weight = t2j(layer.w13_weight, use_dlpack=False)
         w2_weight = t2j(layer.w2_weight, use_dlpack=False)
 
+        num_experts, hidden_size, intermediate_size = w2_weight.shape
+
         if self.moe.has_bias:
             w13_bias = t2j(layer.w13_bias, use_dlpack=False)
             w2_bias = t2j(layer.w2_bias, use_dlpack=False)
@@ -214,7 +236,7 @@ class VllmUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
             w3_bias = w13_bias[:, 1::2]
             w13_bias = jnp.concat([w1_bias, w3_bias], axis=1)
 
-        if self.use_kernel
+        if self.use_kernel:
             # Kernel expects:
             #   w13: (num_experts, 2, hidden_size, intermediate_size)
             #   w2: (num_experts, intermediate_size, hidden_size)
@@ -225,87 +247,119 @@ class VllmUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
             intermediate_size = w13_weight.shape[1] // 2
             hidden_size = w13_weight.shape[2]
 
-
-
-                intermediate_size, hidden_size)
-            w13_weight_transposed = jnp.transpose(w13_reshaped, (0, 1, 3, 2))
+            padded_intermediate_size = align_to(intermediate_size, 256)
+            padded_hidden_size = align_to(hidden_size, 256)
 
             # Transpose w2_weight to (num_experts, intermediate_size, hidden_size)
-
+            w13_weight = w13_weight.reshape(num_experts, 2, intermediate_size,
+                                            hidden_size)
+            w13_weight = jnp.swapaxes(w13_weight, 3, 2)
+
+            w2_weight = jnp.swapaxes(w2_weight, 2, 1)
+
+            w13_weight = jnp.pad(
+                w13_weight,
+                ((0, 0), (0, 0), (0, padded_hidden_size - hidden_size),
+                 (0, padded_intermediate_size - intermediate_size)),
+                constant_values=0)
+
+            w2_weight = jnp.pad(
+                w2_weight,
+                ((0, 0), (0, padded_intermediate_size - intermediate_size),
+                 (0, padded_hidden_size - hidden_size)),
+                constant_values=0)
 
             # Apply EP sharding
+            ep_sharding = NamedSharding(self.mesh, P("model"))
+
             w13_weight = jax.device_put(
-
+                w13_weight,
                 Format(Layout((0, 1, 2, 3)),
                        NamedSharding(self.mesh, P("model", None, None, None))))
             w2_weight = jax.device_put(
-
+                w2_weight,
                 Format(Layout((0, 1, 2)),
                        NamedSharding(self.mesh, P("model", None, None))))
 
             if self.moe.has_bias:
-                w13_bias = w13_bias.reshape(
+                w13_bias = w13_bias.astype(jnp.float32).reshape(
+                    num_experts, 2, 1, intermediate_size)
+                w2_bias = w2_bias.astype(jnp.float32).reshape(
+                    num_experts, 1, hidden_size)
+
+                w13_bias = jnp.pad(
+                    w13_bias,
+                    ((0, 0), (0, 0), (0, 0),
+                     (0, padded_intermediate_size - intermediate_size)),
+                    constant_values=0)
+
+                w2_bias = jnp.pad(w2_bias,
+                                  ((0, 0), (0, 0),
+                                   (0, padded_hidden_size - hidden_size)),
+                                  constant_values=0)
 
                 # Apply EP sharding
                 w13_bias = jax.device_put(
-                    w13_bias,
-                    Format(Layout((0, 1, 2)),
-                           NamedSharding(self.mesh, P("model", None, None))))
+                    w13_bias, Format(Layout((0, 1, 2, 3)), ep_sharding))
                 w2_bias = jax.device_put(
-                    w2_bias,
-                    Format(Layout((0, 1)),
-                           NamedSharding(self.mesh, P("model", None))))
-
+                    w2_bias, Format(Layout((0, 1, 2)), ep_sharding))
         else:
-
+            if self.moe.has_bias:
+                w13_bias = jnp.expand_dims(w13_bias, 1)
+                w2_bias = jnp.expand_dims(w2_bias, 1)
+
             if layer.use_ep:
+                ep_sharding = NamedSharding(self.mesh,
+                                            P(ShardingAxisName.EXPERT))
                 w13_weight = jax.device_put(
-                    w13_weight,
-                    Format(Layout((0, 1, 2)),
-                           NamedSharding(self.mesh, P("model", None, None))))
+                    w13_weight, Format(Layout((0, 1, 2)), ep_sharding))
                 w2_weight = jax.device_put(
-                    w2_weight,
-                    Format(Layout((0, 1, 2)),
-                           NamedSharding(self.mesh, P("model", None, None))))
+                    w2_weight, Format(Layout((0, 1, 2)), ep_sharding))
 
                 if self.moe.has_bias:
                     w13_bias = jax.device_put(
-                        w13_bias,
-                        Format(Layout((0, 1)),
-                               NamedSharding(self.mesh, P("model", None))))
+                        w13_bias, Format(Layout((0, 1, 2)), ep_sharding))
                     w2_bias = jax.device_put(
-                        w2_bias,
-                        Format(Layout((0, 1)),
-                               NamedSharding(self.mesh, P("model", None))))
+                        w2_bias, Format(Layout((0, 1, 2)), ep_sharding))
 
             else:
-                intermediate_size = w13_weight.shape[1] // 2
-                assert intermediate_size == w2_weight.shape[-1]
                 output_sizes = [intermediate_size, intermediate_size]
-                n_shards = self.mesh
+                n_shards = get_mesh_shape_product(self.mesh,
+                                                  ShardingAxisName.MLP_TENSOR)
                 assert intermediate_size % n_shards == 0
+
                 w13_weight = reorder_concatenated_tensor_for_sharding(
                     w13_weight, output_sizes, n_shards, dim=1)
                 w13_weight = jax.device_put(
                     w13_weight,
-                    Format(
-
+                    Format(
+                        Layout((0, 1, 2)),
+                        NamedSharding(
+                            self.mesh,
+                            P(None, ShardingAxisName.MLP_TENSOR, None))))
                 w2_weight = jax.device_put(
                     w2_weight,
-                    Format(
-
+                    Format(
+                        Layout((0, 1, 2)),
+                        NamedSharding(
+                            self.mesh,
+                            P(None, None, ShardingAxisName.MLP_TENSOR))))
 
                 if self.moe.has_bias:
                     w13_bias = reorder_concatenated_tensor_for_sharding(
-                        w13_bias, output_sizes, n_shards, dim=
+                        w13_bias, output_sizes, n_shards, dim=2)
+
                     w13_bias = jax.device_put(
                         w13_bias,
-                        Format(
-
+                        Format(
+                            Layout((0, 1, 2)),
+                            NamedSharding(
+                                self.mesh,
+                                P(None, None, ShardingAxisName.MLP_TENSOR))))
                     w2_bias = jax.device_put(
                         w2_bias,
-                        Format(Layout((0, 1)),
-                               NamedSharding(self.mesh, P(None, None))))
+                        Format(Layout((0, 1, 2)),
+                               NamedSharding(self.mesh, P(None, None, None))))
 
         layer.w13_weight = Parameter(torch_view(w13_weight),
                                      requires_grad=False)
@@ -321,60 +375,54 @@ class VllmUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
-       top_k: int,
-       renormalize: bool,
-       use_grouped_topk: bool = False,
-       topk_group: Optional[int] = None,
-       num_expert_group: Optional[int] = None,
-       global_num_experts: int = -1,
-       expert_map: Optional[torch.Tensor] = None,
-       custom_routing_function: Optional[Callable] = None,
-       scoring_func: str = "softmax",
-       routed_scaling_factor: float = 1.0,
-       e_score_correction_bias: Optional[torch.Tensor] = None,
-       apply_router_weight_on_input: bool = False,
-       activation: str = "silu",
-       enable_eplb: bool = False,
-       expert_load_view: Optional[torch.Tensor] = None,
-       logical_to_physical_map: Optional[torch.Tensor] = None,
-       logical_replica_count: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        assert isinstance(layer, FusedMoE)
-       if scoring_func != "softmax":
+       if layer.scoring_func != "softmax":
            raise NotImplementedError(
                "Only softmax is supported for scoring_func")
 
-
+       x = jax_view(x)
+       w13_weight = jax_view(layer.w13_weight)
+       w2_weight = jax_view(layer.w2_weight)
+       w13_bias = w2_bias = None
+       if self.moe.has_bias:
+           w13_bias = jax_view(layer.w13_bias)
+           w2_bias = jax_view(layer.w2_bias)
+       gating_output = jax_view(router_logits)
+
+       if self.use_kernel:
+           actual_hidden_size = x.shape[-1]
+           padding_size = w13_weight.shape[-2] - actual_hidden_size
+           x = jnp.pad(x, ((0, 0), (0, padding_size)))
            output = fused_ep_moe(
                mesh=self.mesh,
-               tokens=
-               w1=
-               w2=
-               b1=
-               b2=
-               gating_output=
-               top_k=top_k,
+               tokens=x,
+               w1=w13_weight,
+               w2=w2_weight,
+               b1=w13_bias,
+               b2=w2_bias,
+               gating_output=gating_output,
+               top_k=layer.top_k,
                ep_axis_name=self.ep_axis_name,
-               renormalize_topk_logits=renormalize,
-               act_fn=activation,
+               renormalize_topk_logits=layer.renormalize,
+               act_fn=layer.activation,
                **self.block_size,
-           )
+           )[:, :actual_hidden_size]
        else:
-
-
-
-
-
-
-
-
-
-
-               renormalize=renormalize,
-               reduce_results=layer.reduce_results,
+           output = fused_moe_func(
+               hidden_states=x,
+               w1=w13_weight,
+               w2=w2_weight,
+               w1_scale=None,
+               w2_scale=None,
+               w1_bias=w13_bias,
+               w2_bias=w2_bias,
+               gating_output=gating_output,
+               topk=layer.top_k,
+               renormalize=layer.renormalize,
               mesh=self.mesh,
               use_ep=layer.use_ep,
-              activation=activation,
+              activation=layer.activation,
           )
 
       return torch_view(output)
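The align_to helper added above exists so the fused EP MoE kernel path sees weight dimensions that are multiples of its 256-element block, with jnp.pad filling the difference with zeros. A minimal sketch of the arithmetic (the size below is illustrative, not taken from any real model config):

def align_to(a, b):
    # Smallest multiple of b that is >= a: ceiling-divide, then scale back up.
    return (a + b - 1) // b * b

intermediate_size = 1408                     # illustrative value only
padded = align_to(intermediate_size, 256)    # -> 1536
pad_amount = padded - intermediate_size      # -> 128 zero columns appended by jnp.pad
print(padded, pad_amount)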
tpu_inference/layers/vllm/sharding.py
CHANGED

@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import os
 
 import jax
@@ -20,6 +34,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 
 from tpu_inference import envs
+from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.logger import init_logger
 
 P = PartitionSpec
@@ -109,7 +124,8 @@ def _shard_tensor_to_tpu_replicated(tensor: torch.Tensor,
 def _shard_vocab_parallel_embedding(layer: VocabParallelEmbedding,
                                     mesh: Mesh) -> None:
     weight = _convert_to_torchax_and_shard(
-        layer.weight, NamedSharding(mesh, P(
+        layer.weight, NamedSharding(mesh, P(ShardingAxisName.MLP_TENSOR,
+                                            None)))
     layer.weight = Parameter(weight, requires_grad=False)
 
 
@@ -118,11 +134,12 @@ def _shard_lm_head(layer: ParallelLMHead, mesh: Mesh):
     # if that config is set, then we should not create new weights but reuse the
     # weight from VocabParallelEmbedding
     weight = _convert_to_torchax_and_shard(
-        layer.weight, NamedSharding(mesh, P(
+        layer.weight, NamedSharding(mesh, P(ShardingAxisName.MLP_TENSOR,
+                                            None)))
     layer.weight = Parameter(weight, requires_grad=False)
     if layer.bias is not None:
-        bias = _convert_to_torchax_and_shard(
-
+        bias = _convert_to_torchax_and_shard(
+            layer.bias, NamedSharding(mesh, P(ShardingAxisName.MLP_TENSOR)))
         layer.bias = Parameter(bias, requires_grad=False)
 
 
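The sharding.py change swaps hard-coded partition specs for the ShardingAxisName constants when building the NamedSharding applied to the embedding and LM-head weights. A minimal, self-contained sketch of what such a spec does (single-device mesh and a made-up "model" axis name standing in for ShardingAxisName.MLP_TENSOR; this is not the package's actual code):

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

MLP_TENSOR = "model"  # stand-in for ShardingAxisName.MLP_TENSOR

mesh = Mesh(np.array(jax.devices()[:1]), axis_names=(MLP_TENSOR,))
weight = jnp.zeros((1024, 512))  # (vocab_size, hidden_size), illustrative shape
# Shard the vocab dimension across the "model" axis, replicate the hidden dimension.
sharded = jax.device_put(weight, NamedSharding(mesh, P(MLP_TENSOR, None)))
print(sharded.sharding)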
tpu_inference/lora/__init__.py
CHANGED

@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
tpu_inference/lora/torch_lora_ops.py
CHANGED

@@ -4,7 +4,6 @@
 import jax
 import jax.numpy as jnp
 import torch
-import torch.nn.functional as F
 from torchax.interop import call_jax
 
 
@@ -85,19 +84,15 @@ def bgmv_expand_slice(
         add_inputs (bool): Whether or not to add the input tensor to the output
             tensor.
     """
-    outputs = bgmv_torch(inputs, lora_b_weights,
+    outputs = bgmv_torch(inputs, lora_b_weights,
+                         lora_indices_tensor)  # [num_tokens, out_features]
 
-
-
-
-
-            output_tensor.shape[1] - (slice_offset + slice_size),
-            0,
-            0,
-        ),
-    )
+    # Create a padded tensor manually to avoid issues with F.pad on sharded tensors.
+    # This is a more robust way to handle padding in a distributed environment.
+    outputs_padded = torch.zeros_like(output_tensor)
+    outputs_padded[:, slice_offset:slice_offset + slice_size] = outputs
 
     if add_inputs:
-        return output_tensor +
+        return output_tensor + outputs_padded
     else:
-        return
+        return outputs_padded
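The bgmv_expand_slice rewrite drops F.pad in favour of writing the LoRA slice into a zero tensor of the full output width, which produces the same result but sidesteps the sharded-tensor issue its new comment describes. A small stand-alone sketch of the pattern (shapes and values are illustrative only):

import torch

num_tokens, out_features = 4, 16
slice_offset, slice_size = 4, 8

output_tensor = torch.randn(num_tokens, out_features)
outputs = torch.ones(num_tokens, slice_size)  # stands in for the bgmv_torch(...) result

# Same effect as F.pad(outputs, (slice_offset, out_features - slice_offset - slice_size)),
# but built explicitly so it also behaves when output_tensor is sharded.
outputs_padded = torch.zeros_like(output_tensor)
outputs_padded[:, slice_offset:slice_offset + slice_size] = outputs

add_inputs = True
result = output_tensor + outputs_padded if add_inputs else outputs_padded
print(result.shape)  # torch.Size([4, 16])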
tpu_inference/models/__init__.py
CHANGED

@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
(another file, name not captured in this diff)
CHANGED

@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.