tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +317 -34
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +406 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +320 -0
- tests/layers/vllm/test_unquantized.py +662 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +26 -6
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +110 -12
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +2 -45
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +15 -10
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +92 -8
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +25 -4
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +807 -230
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +218 -137
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +25 -12
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/fused_moe_gmm.py +506 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +282 -0
- tpu_inference/layers/common/sharding.py +32 -9
- tpu_inference/layers/common/utils.py +94 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +101 -494
- tpu_inference/layers/vllm/linear.py +64 -0
- tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
- tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
- tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
- tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
- tpu_inference/layers/vllm/quantization/__init__.py +19 -3
- tpu_inference/layers/vllm/quantization/awq.py +96 -82
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +23 -8
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +172 -176
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
- tpu_inference/layers/vllm/quantization/{common.py → configs.py} +42 -25
- tpu_inference/layers/vllm/quantization/fp8.py +119 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +137 -178
- tpu_inference/layers/vllm/quantization/unquantized.py +157 -233
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +112 -35
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +18 -5
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +179 -51
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +92 -32
- tpu_inference/models/jax/utils/weight_utils.py +234 -155
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +32 -8
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +51 -72
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +180 -80
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +55 -33
- tpu_inference/runner/lora_utils.py +16 -1
- tpu_inference/runner/multimodal_manager.py +16 -2
- tpu_inference/runner/persistent_batch_manager.py +54 -2
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +16 -3
- tpu_inference/runner/tpu_runner.py +124 -61
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +84 -22
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +72 -44
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +66 -52
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +8 -9
- tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
- tpu_inference/layers/vllm/linear_common.py +0 -186
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202511220812.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
tpu_inference/layers/vllm/quantization/unquantized.py CHANGED
@@ -1,19 +1,29 @@
-
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Optional

 import jax
 import jax.numpy as jnp
 import torch
-from jax.experimental.layout import Format, Layout
 from jax.sharding import Mesh, NamedSharding, PartitionSpec
 from torch.nn.parameter import Parameter
 from torchax.interop import jax_view, torch_view
 from torchax.ops.mappings import t2j
 from vllm.attention.layer import Attention
-from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE, FusedMoEConfig, UnquantizedFusedMoEMethod)
-from vllm.model_executor.layers.fused_moe.modular_kernel import (
-    FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize)
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization import \
@@ -21,23 +31,31 @@ from vllm.model_executor.layers.quantization import \
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)

-from tpu_inference import envs
-from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe
 from tpu_inference.layers.common.quant_methods import (UNQUANTIZED,
                                                        get_tpu_quant_method)
-from tpu_inference.layers.
-from tpu_inference.layers.
-
-
-
-
+from tpu_inference.layers.common.sharding import ShardingAxisName
+from tpu_inference.layers.common.utils import \
+    slice_sharded_tensor_for_concatenation
+from tpu_inference.layers.vllm.fused_moe import (FusedMoEBackend,
+                                                 fused_moe_apply,
+                                                 select_moe_backend)
+from tpu_inference.layers.vllm.process_weights.fused_moe_weights import (
+    FusedMoEWeights, process_moe_weights, shard_moe_weights)
+from tpu_inference.layers.vllm.process_weights.linear_weights import (
+    LinearWeights, process_lienar_weights, shard_linear_weights,
+    to_parameter_list)
+from tpu_inference.layers.vllm.quantization.configs import (
+    VllmQuantConfig, VllmQuantLinearConfig)
+from tpu_inference.logger import init_logger
+from tpu_inference.utils import get_mesh_shape_product

 P = PartitionSpec
+
 logger = init_logger(__name__)


 @register_quantization_config(get_tpu_quant_method(UNQUANTIZED))
-class VllmUnquantizedConfig(QuantizationConfig, JaxCommonConfig):
+class VllmUnquantizedConfig(QuantizationConfig, VllmQuantConfig):

     @classmethod
     def get_name(cls) -> str:
@@ -74,51 +92,73 @@ class VllmUnquantizedConfig(QuantizationConfig, JaxCommonConfig):

 class VllmUnquantizedLinearMethod(UnquantizedLinearMethod):

-    def __init__(self,
-        self.
+    def __init__(self, linear_config: VllmQuantLinearConfig):
+        self.linear_config = linear_config

     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        weight =
-            layer.weight,
-            NamedSharding(self.jax_config.mesh,
-                          self.jax_config.weight_sharding),
-            self.jax_config.output_sizes,
-            self.jax_config.n_shards,
-            self.jax_config.fuse_matmuls,
-        )
+        weight = t2j(layer.weight, use_dlpack=False)
         delattr(layer, "weight")
-        layer.weight = weight
-
         if layer.bias is not None and not layer.skip_bias_add:
             if layer.return_bias:
                 logger.warning_once("Bias might return incorrect value.")
-
-            bias = torch_to_jax_param(
-                layer.bias,
-                NamedSharding(self.jax_config.mesh,
-                              self.jax_config.bias_sharding),
-                self.jax_config.output_sizes,
-                self.jax_config.n_shards,
-                self.jax_config.fuse_matmuls,
-            )
+            bias = t2j(layer.bias, use_dlpack=False)
             delattr(layer, "bias")
-
+        else:
+            bias = None
+
+        @jax.jit
+        def process_unquantized_linear_weights(
+            weight: jax.Array,
+            bias: jax.Array | None,
+        ) -> LinearWeights:
+            return process_lienar_weights(
+                LinearWeights(
+                    weight=weight,
+                    weight_scale=None,
+                    zero_point=None,
+                    bias=bias,
+                ),
+                fused=self.linear_config.fuse_matmuls,
+                output_sizes=self.linear_config.output_sizes,
+                reorder_size=self.linear_config.n_shards,
+            )
+
+        weights = process_unquantized_linear_weights(weight, bias)
+        weights = torch_view(
+            shard_linear_weights(
+                weights,
+                mesh=self.linear_config.mesh,
+                weight_p_spec=self.linear_config.weight_sharding,
+                bias_p_spec=self.linear_config.bias_sharding,
+            ))
+
+        if self.linear_config.fuse_matmuls:
+            layer.weight = Parameter(weights.weight, requires_grad=False)
+            if bias is not None:
+                layer.bias = Parameter(weights.bias, requires_grad=False)
+        else:
+            layer.weight = to_parameter_list(weights.weight)
+            if bias is not None:
+                layer.bias = to_parameter_list(weights.bias)

     def apply(self,
               layer: torch.nn.Module,
               x: torch.Tensor,
               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        assert isinstance(layer, LinearBase)
+
         with jax.named_scope(layer._get_name()):
-            if in_sharding := self.
-                x.shard_(NamedSharding(self.
+            if in_sharding := self.linear_config.get_input_sharding(x):
+                x.shard_(NamedSharding(self.linear_config.mesh, in_sharding))

-            if self.
+            if self.linear_config.fuse_matmuls:
                 out = self._apply_fused(layer, x, bias)
             else:
                 out = self._apply_split(layer, x, bias)

-            if out_sharding := self.
-                out.shard_(NamedSharding(self.
+            if out_sharding := self.linear_config.get_output_sharding(out):
+                out.shard_(NamedSharding(self.linear_config.mesh,
+                                         out_sharding))

         return out

@@ -134,7 +174,7 @@ class VllmUnquantizedLinearMethod(UnquantizedLinearMethod):
             outs += bias.jax()

         outs = slice_sharded_tensor_for_concatenation(
-            outs, self.
+            outs, self.linear_config.output_sizes, self.linear_config.n_shards)
         out = jnp.concatenate(outs, axis=-1)
         return torch_view(out)

@@ -160,215 +200,99 @@ class VllmUnquantizedLinearMethod(UnquantizedLinearMethod):

 class VllmUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):

-    def __init__(
-        moe: FusedMoEConfig,
-        mesh: Mesh,
-        ep_axis_name: str = 'model'):
-        super().__init__(moe)
-        self.mesh = mesh
-        self.use_kernel = envs.USE_MOE_EP_KERNEL
-        self.ep_axis_name = ep_axis_name
-        # TODO: Use autotune table once we have it.
-        self.block_size = {
-            "bt": 16,
-            "bf": 384,
-            "bd1": 512,
-            "bd2": 512,
-            "btc": 16,
-            "bfc": 384,
-            "bd1c": 256,
-            "bd2c": 256,
-        }
-
-    def select_gemm_impl(
+    def __init__(
         self,
-        prepare_finalize: FusedMoEPrepareAndFinalize,
         moe: FusedMoEConfig,
-
-
-
-
+        mesh: Mesh,
+        ep_axis_name: str = "model",
+    ):
+        super().__init__(moe)
+        self.mesh = mesh
+        self.moe_backend = select_moe_backend(self.moe)
+
+        self.extra_backend_kwargs = {}
+        if self.moe_backend == FusedMoEBackend.FUSED_MOE:
+            # When fused moe kernle is used, we pass extra arguments like
+            # tuned block sizes to the kernel.
+            self.extra_backend_kwargs = dict(
+                ep_axis_name=ep_axis_name,
+                # TODO: Use autotune table once we have it.
+                bt=64,
+                bf=1024,
+                bd1=1536,
+                bd2=1536,
+                btc=64,
+                bfc=1024,
+                bd1c=1536,
+                bd2c=1536,
+            )

     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         assert isinstance(layer, FusedMoE)
+
         w13_weight = t2j(layer.w13_weight, use_dlpack=False)
         w2_weight = t2j(layer.w2_weight, use_dlpack=False)

         if self.moe.has_bias:
             w13_bias = t2j(layer.w13_bias, use_dlpack=False)
             w2_bias = t2j(layer.w2_bias, use_dlpack=False)
-
-        if layer.activation == "swigluoai":
-            # When using swigluoai, vLLM splits gmm output in a interleaved way.
-            # However, interleaved split is not performant on TPU. Therefore,
-            # we preprocess the weight so that splitting gmm output by middle
-            # can still get the same result.
-            w1_weight = w13_weight[:, ::2, :]
-            w3_weight = w13_weight[:, 1::2, :]
-            w13_weight = jnp.concat([w1_weight, w3_weight], axis=1)
-
-            if self.moe.has_bias:
-                w1_bias = w13_bias[:, ::2]
-                w3_bias = w13_bias[:, 1::2]
-                w13_bias = jnp.concat([w1_bias, w3_bias], axis=1)
-
-        if self.use_kernel and layer.use_ep:
-            # Kernel expects:
-            # w13: (num_experts, 2, hidden_size, intermediate_size)
-            # w2: (num_experts, intermediate_size, hidden_size)
-            # Current format:
-            # w13_weight: (num_experts, 2*intermediate_size, hidden_size)
-            # w2_weight: (num_experts, hidden_size, intermediate_size)
-            num_experts = w13_weight.shape[0]
-            intermediate_size = w13_weight.shape[1] // 2
-            hidden_size = w13_weight.shape[2]
-
-            # Reshape and transpose w13_weight to (num_experts, 2, hidden_size, intermediate_size)
-            w13_reshaped = w13_weight.reshape(num_experts, 2,
-                                              intermediate_size, hidden_size)
-            w13_weight_transposed = jnp.transpose(w13_reshaped, (0, 1, 3, 2))
-
-            # Transpose w2_weight to (num_experts, intermediate_size, hidden_size)
-            w2_weight_transposed = jnp.transpose(w2_weight, (0, 2, 1))
-
-            # Apply EP sharding
-            w13_weight = jax.device_put(
-                w13_weight_transposed,
-                Format(Layout((0, 1, 2, 3)),
-                       NamedSharding(self.mesh, P("model", None, None, None))))
-            w2_weight = jax.device_put(
-                w2_weight_transposed,
-                Format(Layout((0, 1, 2)),
-                       NamedSharding(self.mesh, P("model", None, None))))
-
-            if self.moe.has_bias:
-                w13_bias = w13_bias.reshape(num_experts, 2, intermediate_size)
-
-                # Apply EP sharding
-                w13_bias = jax.device_put(
-                    w13_bias,
-                    Format(Layout((0, 1, 2)),
-                           NamedSharding(self.mesh, P("model", None, None))))
-                w2_bias = jax.device_put(
-                    w2_bias,
-                    Format(Layout((0, 1)),
-                           NamedSharding(self.mesh, P("model", None))))
-
         else:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            w13_bias = w2_bias = None
+
+        @jax.jit
+        def process_unquantized_moe_weights(
+            w13_weight: jax.Array,
+            w13_bias: jax.Array | None,
+            w2_weight: jax.Array,
+            w2_bias: jax.Array | None,
+        ) -> FusedMoEWeights:
+
+            w13_interleave = layer.activation == "swigluoai"
+            w13_reorder_size = get_mesh_shape_product(
+                self.mesh, ShardingAxisName.MLP_TENSOR)
+
+            return process_moe_weights(
+                FusedMoEWeights(
+                    w13_weight=w13_weight,
+                    w13_weight_scale=None,
+                    w13_bias=w13_bias,
+                    w2_weight=w2_weight,
+                    w2_weight_scale=None,
+                    w2_bias=w2_bias,
+                ),
+                moe_backend=self.moe_backend,
+                w13_reorder_size=w13_reorder_size,
+                w13_interleave=w13_interleave,
+            )

-
-
-
-
-
-
-
-
-
-
-
-                       NamedSharding(self.mesh, P(None, "model", None))))
-            w2_weight = jax.device_put(
-                w2_weight,
-                Format(Layout((0, 1, 2)),
-                       NamedSharding(self.mesh, P(None, None, "model"))))
-
-            if self.moe.has_bias:
-                w13_bias = reorder_concatenated_tensor_for_sharding(
-                    w13_bias, output_sizes, n_shards, dim=1)
-                w13_bias = jax.device_put(
-                    w13_bias,
-                    Format(Layout((0, 1)),
-                           NamedSharding(self.mesh, P(None, "model"))))
-                w2_bias = jax.device_put(
-                    w2_bias,
-                    Format(Layout((0, 1)),
-                           NamedSharding(self.mesh, P(None, None))))
-
-        layer.w13_weight = Parameter(torch_view(w13_weight),
-                                     requires_grad=False)
-        layer.w2_weight = Parameter(torch_view(w2_weight), requires_grad=False)
+        weights = process_unquantized_moe_weights(
+            w13_weight,
+            w13_bias,
+            w2_weight,
+            w2_bias,
+        )
+        weights = torch_view(
+            shard_moe_weights(weights, self.moe_backend, self.mesh))
+
+        layer.w13_weight = Parameter(weights.w13_weight, requires_grad=False)
+        layer.w2_weight = Parameter(weights.w2_weight, requires_grad=False)

         if self.moe.has_bias:
-            layer.w13_bias = Parameter(
-
-            layer.w2_bias = Parameter(torch_view(w2_bias), requires_grad=False)
+            layer.w13_bias = Parameter(weights.w13_bias, requires_grad=False)
+            layer.w2_bias = Parameter(weights.w2_bias, requires_grad=False)

     def apply(
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
         router_logits: torch.Tensor,
-
-
-
-
-
-
-
-
-
-
-        e_score_correction_bias: Optional[torch.Tensor] = None,
-        apply_router_weight_on_input: bool = False,
-        activation: str = "silu",
-        enable_eplb: bool = False,
-        expert_load_view: Optional[torch.Tensor] = None,
-        logical_to_physical_map: Optional[torch.Tensor] = None,
-        logical_replica_count: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
-        assert isinstance(layer, FusedMoE)
-        if scoring_func != "softmax":
-            raise NotImplementedError(
-                "Only softmax is supported for scoring_func")
-
-        if self.use_kernel and layer.use_ep:
-            output = fused_ep_moe(
-                mesh=self.mesh,
-                tokens=jax_view(x),
-                w1=jax_view(layer.w13_weight),
-                w2=jax_view(layer.w2_weight),
-                gating_output=jax_view(router_logits),
-                top_k=top_k,
-                ep_axis_name=self.ep_axis_name,
-                **self.block_size,
-            )
-        else:
-            # Use the original implementation
-            output = fused_moe_func_padded(
-                jax_view(x),
-                jax_view(layer.w13_weight),
-                jax_view(layer.w2_weight),
-                jax_view(layer.w13_bias) if self.moe.has_bias else None,
-                jax_view(layer.w2_bias) if self.moe.has_bias else None,
-                jax_view(router_logits),
-                topk=top_k,
-                global_num_experts=global_num_experts,
-                renormalize=renormalize,
-                reduce_results=layer.reduce_results,
-                mesh=self.mesh,
-                use_ep=layer.use_ep,
-                activation=activation,
-            )
-
-        return torch_view(output)
+    ) -> torch.Tensor:
+
+        return fused_moe_apply(
+            layer,
+            x,
+            router_logits,
+            self.moe_backend,
+            self.mesh,
+            self.extra_backend_kwargs,
+        )

tpu_inference/lora/__init__.py CHANGED
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

tpu_inference/lora/torch_lora_ops.py CHANGED
@@ -4,7 +4,6 @@
 import jax
 import jax.numpy as jnp
 import torch
-import torch.nn.functional as F
 from torchax.interop import call_jax


@@ -85,19 +84,15 @@ def bgmv_expand_slice(
         add_inputs (bool): Whether or not to add the input tensor to the output
             tensor.
     """
-    outputs = bgmv_torch(inputs, lora_b_weights,
+    outputs = bgmv_torch(inputs, lora_b_weights,
+                         lora_indices_tensor)  # [num_tokens, out_features]

-
-
-
-
-            output_tensor.shape[1] - (slice_offset + slice_size),
-            0,
-            0,
-        ),
-    )
+    # Create a padded tensor manually to avoid issues with F.pad on sharded tensors.
+    # This is a more robust way to handle padding in a distributed environment.
+    outputs_padded = torch.zeros_like(output_tensor)
+    outputs_padded[:, slice_offset:slice_offset + slice_size] = outputs

     if add_inputs:
-        return output_tensor +
+        return output_tensor + outputs_padded
     else:
-        return
+        return outputs_padded

tpu_inference/models/__init__.py CHANGED
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.