tpu-inference 0.12.0.dev20251213__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +14 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +14 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +14 -0
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +180 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +406 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +320 -0
- tests/layers/vllm/test_unquantized.py +662 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +25 -8
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +14 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +1 -43
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +14 -9
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +20 -3
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +171 -163
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +20 -26
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +112 -69
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +85 -65
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +374 -194
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +13 -0
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +26 -19
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/fused_moe_gmm.py +506 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +282 -0
- tpu_inference/layers/common/sharding.py +22 -3
- tpu_inference/layers/common/utils.py +94 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +52 -27
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +19 -6
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +100 -455
- tpu_inference/layers/vllm/linear.py +64 -0
- tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
- tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
- tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
- tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
- tpu_inference/layers/vllm/quantization/__init__.py +19 -3
- tpu_inference/layers/vllm/quantization/awq.py +96 -82
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +19 -5
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +119 -132
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
- tpu_inference/layers/vllm/quantization/{common.py → configs.py} +38 -26
- tpu_inference/layers/vllm/quantization/fp8.py +119 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +133 -220
- tpu_inference/layers/vllm/quantization/unquantized.py +154 -253
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +37 -16
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +113 -124
- tpu_inference/models/jax/gpt_oss.py +23 -7
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +14 -0
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +18 -4
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +85 -24
- tpu_inference/models/jax/utils/weight_utils.py +32 -1
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +22 -4
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +27 -29
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +69 -35
- tpu_inference/runner/kv_cache.py +14 -0
- tpu_inference/runner/kv_cache_manager.py +15 -2
- tpu_inference/runner/lora_utils.py +16 -1
- tpu_inference/runner/multimodal_manager.py +16 -2
- tpu_inference/runner/persistent_batch_manager.py +14 -0
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +14 -0
- tpu_inference/runner/tpu_runner.py +30 -10
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +13 -0
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +31 -30
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +23 -7
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +1 -1
- tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
- tpu_inference/layers/vllm/linear_common.py +0 -208
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.12.0.dev20251213.dist-info/RECORD +0 -175
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
tpu_inference/layers/vllm/quantization/unquantized.py CHANGED

@@ -1,19 +1,29 @@
-
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Optional
 
 import jax
 import jax.numpy as jnp
 import torch
-from jax.experimental.layout import Format, Layout
 from jax.sharding import Mesh, NamedSharding, PartitionSpec
 from torch.nn.parameter import Parameter
 from torchax.interop import jax_view, torch_view
 from torchax.ops.mappings import t2j
 from vllm.attention.layer import Attention
-from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE, FusedMoEConfig, UnquantizedFusedMoEMethod)
-from vllm.model_executor.layers.fused_moe.modular_kernel import (
-    FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize)
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization import \

@@ -21,27 +31,31 @@ from vllm.model_executor.layers.quantization import \
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 
-from tpu_inference import envs
-from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe
 from tpu_inference.layers.common.quant_methods import (UNQUANTIZED,
                                                        get_tpu_quant_method)
-from tpu_inference.layers.
-from tpu_inference.layers.
-
-
-
-
+from tpu_inference.layers.common.sharding import ShardingAxisName
+from tpu_inference.layers.common.utils import \
+    slice_sharded_tensor_for_concatenation
+from tpu_inference.layers.vllm.fused_moe import (FusedMoEBackend,
+                                                 fused_moe_apply,
+                                                 select_moe_backend)
+from tpu_inference.layers.vllm.process_weights.fused_moe_weights import (
+    FusedMoEWeights, process_moe_weights, shard_moe_weights)
+from tpu_inference.layers.vllm.process_weights.linear_weights import (
+    LinearWeights, process_lienar_weights, shard_linear_weights,
+    to_parameter_list)
+from tpu_inference.layers.vllm.quantization.configs import (
+    VllmQuantConfig, VllmQuantLinearConfig)
+from tpu_inference.logger import init_logger
+from tpu_inference.utils import get_mesh_shape_product
 
 P = PartitionSpec
-logger = init_logger(__name__)
-
 
-
-    return (a + b - 1) // b * b
+logger = init_logger(__name__)
 
 
 @register_quantization_config(get_tpu_quant_method(UNQUANTIZED))
-class VllmUnquantizedConfig(QuantizationConfig,
+class VllmUnquantizedConfig(QuantizationConfig, VllmQuantConfig):
 
     @classmethod
     def get_name(cls) -> str:

@@ -78,35 +92,54 @@ class VllmUnquantizedConfig(QuantizationConfig, JaxCommonConfig):
 
 class VllmUnquantizedLinearMethod(UnquantizedLinearMethod):
 
-    def __init__(self,
-        self.
+    def __init__(self, linear_config: VllmQuantLinearConfig):
+        self.linear_config = linear_config
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        weight =
-            layer.weight,
-            NamedSharding(self.jax_config.mesh,
-                          self.jax_config.weight_sharding),
-            self.jax_config.output_sizes,
-            self.jax_config.n_shards,
-            self.jax_config.fuse_matmuls,
-        )
+        weight = t2j(layer.weight, use_dlpack=False)
         delattr(layer, "weight")
-        layer.weight = weight
-
         if layer.bias is not None and not layer.skip_bias_add:
             if layer.return_bias:
                 logger.warning_once("Bias might return incorrect value.")
-
-            bias = torch_to_jax_param(
-                layer.bias,
-                NamedSharding(self.jax_config.mesh,
-                              self.jax_config.bias_sharding),
-                self.jax_config.output_sizes,
-                self.jax_config.n_shards,
-                self.jax_config.fuse_matmuls,
-            )
+            bias = t2j(layer.bias, use_dlpack=False)
             delattr(layer, "bias")
-
+        else:
+            bias = None
+
+        @jax.jit
+        def process_unquantized_linear_weights(
+            weight: jax.Array,
+            bias: jax.Array | None,
+        ) -> LinearWeights:
+            return process_lienar_weights(
+                LinearWeights(
+                    weight=weight,
+                    weight_scale=None,
+                    zero_point=None,
+                    bias=bias,
+                ),
+                fused=self.linear_config.fuse_matmuls,
+                output_sizes=self.linear_config.output_sizes,
+                reorder_size=self.linear_config.n_shards,
+            )
+
+        weights = process_unquantized_linear_weights(weight, bias)
+        weights = torch_view(
+            shard_linear_weights(
+                weights,
+                mesh=self.linear_config.mesh,
+                weight_p_spec=self.linear_config.weight_sharding,
+                bias_p_spec=self.linear_config.bias_sharding,
+            ))
+
+        if self.linear_config.fuse_matmuls:
+            layer.weight = Parameter(weights.weight, requires_grad=False)
+            if bias is not None:
+                layer.bias = Parameter(weights.bias, requires_grad=False)
+        else:
+            layer.weight = to_parameter_list(weights.weight)
+            if bias is not None:
+                layer.bias = to_parameter_list(weights.bias)
 
     def apply(self,
               layer: torch.nn.Module,

@@ -115,16 +148,17 @@ class VllmUnquantizedLinearMethod(UnquantizedLinearMethod):
         assert isinstance(layer, LinearBase)
 
         with jax.named_scope(layer._get_name()):
-            if in_sharding := self.
-                x.shard_(NamedSharding(self.
+            if in_sharding := self.linear_config.get_input_sharding(x):
+                x.shard_(NamedSharding(self.linear_config.mesh, in_sharding))
 
-            if self.
+            if self.linear_config.fuse_matmuls:
                 out = self._apply_fused(layer, x, bias)
             else:
                 out = self._apply_split(layer, x, bias)
 
-            if out_sharding := self.
-                out.shard_(NamedSharding(self.
+            if out_sharding := self.linear_config.get_output_sharding(out):
+                out.shard_(NamedSharding(self.linear_config.mesh,
+                                         out_sharding))
 
         return out
 

@@ -140,7 +174,7 @@ class VllmUnquantizedLinearMethod(UnquantizedLinearMethod):
             outs += bias.jax()
 
         outs = slice_sharded_tensor_for_concatenation(
-            outs, self.
+            outs, self.linear_config.output_sizes, self.linear_config.n_shards)
         out = jnp.concatenate(outs, axis=-1)
         return torch_view(out)
 
@@ -166,232 +200,99 @@ class VllmUnquantizedLinearMethod(UnquantizedLinearMethod):
 
 class VllmUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
 
-    def __init__(
-        moe: FusedMoEConfig,
-        mesh: Mesh,
-        ep_axis_name: str = 'model'):
-        super().__init__(moe)
-        self.mesh = mesh
-        self.use_kernel = envs.USE_MOE_EP_KERNEL and moe.use_ep
-        self.ep_axis_name = ep_axis_name
-        # TODO: Use autotune table once we have it.
-        self.block_size = {
-            "bt": 64,
-            "bf": 1024,
-            "bd1": 1536,
-            "bd2": 1536,
-            "btc": 64,
-            "bfc": 1024,
-            "bd1c": 1536,
-            "bd2c": 1536,
-        }
-
-    def select_gemm_impl(
+    def __init__(
         self,
-        prepare_finalize: FusedMoEPrepareAndFinalize,
         moe: FusedMoEConfig,
-
-
-
-
+        mesh: Mesh,
+        ep_axis_name: str = "model",
+    ):
+        super().__init__(moe)
+        self.mesh = mesh
+        self.moe_backend = select_moe_backend(self.moe)
+
+        self.extra_backend_kwargs = {}
+        if self.moe_backend == FusedMoEBackend.FUSED_MOE:
+            # When fused moe kernle is used, we pass extra arguments like
+            # tuned block sizes to the kernel.
+            self.extra_backend_kwargs = dict(
+                ep_axis_name=ep_axis_name,
+                # TODO: Use autotune table once we have it.
+                bt=64,
+                bf=1024,
+                bd1=1536,
+                bd2=1536,
+                btc=64,
+                bfc=1024,
+                bd1c=1536,
+                bd2c=1536,
+            )
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         assert isinstance(layer, FusedMoE)
+
         w13_weight = t2j(layer.w13_weight, use_dlpack=False)
         w2_weight = t2j(layer.w2_weight, use_dlpack=False)
 
-        num_experts, hidden_size, intermediate_size = w2_weight.shape
-
         if self.moe.has_bias:
             w13_bias = t2j(layer.w13_bias, use_dlpack=False)
             w2_bias = t2j(layer.w2_bias, use_dlpack=False)
-
-        if layer.activation == "swigluoai":
-            # When using swigluoai, vLLM splits gmm output in a interleaved way.
-            # However, interleaved split is not performant on TPU. Therefore,
-            # we preprocess the weight so that splitting gmm output by middle
-            # can still get the same result.
-            w1_weight = w13_weight[:, ::2, :]
-            w3_weight = w13_weight[:, 1::2, :]
-            w13_weight = jnp.concat([w1_weight, w3_weight], axis=1)
-
-            if self.moe.has_bias:
-                w1_bias = w13_bias[:, ::2]
-                w3_bias = w13_bias[:, 1::2]
-                w13_bias = jnp.concat([w1_bias, w3_bias], axis=1)
-
-        if self.use_kernel:
-            # Kernel expects:
-            # w13: (num_experts, 2, hidden_size, intermediate_size)
-            # w2: (num_experts, intermediate_size, hidden_size)
-            # Current format:
-            # w13_weight: (num_experts, 2*intermediate_size, hidden_size)
-            # w2_weight: (num_experts, hidden_size, intermediate_size)
-            num_experts = w13_weight.shape[0]
-            intermediate_size = w13_weight.shape[1] // 2
-            hidden_size = w13_weight.shape[2]
-
-            padded_intermediate_size = align_to(intermediate_size, 256)
-            padded_hidden_size = align_to(hidden_size, 256)
-
-            w13_weight = w13_weight.reshape(num_experts, 2, intermediate_size,
-                                            hidden_size)
-            w13_weight = jnp.transpose(w13_weight, (0, 1, 3, 2))
-
-            # Transpose w2_weight to (num_experts, intermediate_size, hidden_size)
-            w2_weight = jnp.transpose(w2_weight, (0, 2, 1))
-
-            w13_weight = jnp.pad(
-                w13_weight,
-                ((0, 0), (0, 0), (0, padded_hidden_size - hidden_size),
-                 (0, padded_intermediate_size - intermediate_size)),
-                constant_values=0)
-
-            w2_weight = jnp.pad(
-                w2_weight,
-                ((0, 0), (0, padded_intermediate_size - intermediate_size),
-                 (0, padded_hidden_size - hidden_size)),
-                constant_values=0)
-
-            # Apply EP sharding
-            ep_sharding = NamedSharding(self.mesh, P("model"))
-
-            w13_weight = jax.device_put(
-                w13_weight,
-                Format(Layout((0, 1, 2, 3)),
-                       NamedSharding(self.mesh, P("model", None, None, None))))
-            w2_weight = jax.device_put(
-                w2_weight,
-                Format(Layout((0, 1, 2)),
-                       NamedSharding(self.mesh, P("model", None, None))))
-
-            if self.moe.has_bias:
-                w13_bias = w13_bias.astype(jnp.float32).reshape(
-                    num_experts, 2, 1, intermediate_size)
-                w2_bias = w2_bias.astype(jnp.float32).reshape(
-                    num_experts, 1, hidden_size)
-
-                w13_bias = jnp.pad(
-                    w13_bias,
-                    ((0, 0), (0, 0), (0, 0),
-                     (0, padded_intermediate_size - intermediate_size)),
-                    constant_values=0)
-
-                w2_bias = jnp.pad(w2_bias,
-                                  ((0, 0), (0, 0),
-                                   (0, padded_hidden_size - hidden_size)),
-                                  constant_values=0)
-
-                # Apply EP sharding
-                w13_bias = jax.device_put(
-                    w13_bias, Format(Layout((0, 1, 2, 3)), ep_sharding))
-                w2_bias = jax.device_put(
-                    w2_bias, Format(Layout((0, 1, 2)), ep_sharding))
         else:
+            w13_bias = w2_bias = None
+
+        @jax.jit
+        def process_unquantized_moe_weights(
+            w13_weight: jax.Array,
+            w13_bias: jax.Array | None,
+            w2_weight: jax.Array,
+            w2_bias: jax.Array | None,
+        ) -> FusedMoEWeights:
+
+            w13_interleave = layer.activation == "swigluoai"
+            w13_reorder_size = get_mesh_shape_product(
+                self.mesh, ShardingAxisName.MLP_TENSOR)
+
+            return process_moe_weights(
+                FusedMoEWeights(
+                    w13_weight=w13_weight,
+                    w13_weight_scale=None,
+                    w13_bias=w13_bias,
+                    w2_weight=w2_weight,
+                    w2_weight_scale=None,
+                    w2_bias=w2_bias,
+                ),
+                moe_backend=self.moe_backend,
+                w13_reorder_size=w13_reorder_size,
+                w13_interleave=w13_interleave,
+            )
 
-
-
-
-
-
-
+        weights = process_unquantized_moe_weights(
+            w13_weight,
+            w13_bias,
+            w2_weight,
+            w2_bias,
+        )
+        weights = torch_view(
+            shard_moe_weights(weights, self.moe_backend, self.mesh))
 
-
-
-                    w13_bias, Format(Layout((0, 1)), ep_sharding))
-                w2_bias = jax.device_put(
-                    w2_bias, Format(Layout((0, 1)), ep_sharding))
-
-        else:
-            output_sizes = [intermediate_size, intermediate_size]
-            n_shards = self.mesh.shape["model"]
-            assert intermediate_size % n_shards == 0
-
-            w13_weight = reorder_concatenated_tensor_for_sharding(
-                w13_weight, output_sizes, n_shards, dim=1)
-            w13_weight = jax.device_put(
-                w13_weight,
-                Format(Layout((0, 1, 2)),
-                       NamedSharding(self.mesh, P(None, "model", None))))
-            w2_weight = jax.device_put(
-                w2_weight,
-                Format(Layout((0, 1, 2)),
-                       NamedSharding(self.mesh, P(None, None, "model"))))
-
-            if self.moe.has_bias:
-                w13_bias = reorder_concatenated_tensor_for_sharding(
-                    w13_bias, output_sizes, n_shards, dim=1)
-                w13_bias = jax.device_put(
-                    w13_bias,
-                    Format(Layout((0, 1)),
-                           NamedSharding(self.mesh, P(None, "model"))))
-                w2_bias = jax.device_put(
-                    w2_bias,
-                    Format(Layout((0, 1)),
-                           NamedSharding(self.mesh, P(None, None))))
-
-        layer.w13_weight = Parameter(torch_view(w13_weight),
-                                     requires_grad=False)
-        layer.w2_weight = Parameter(torch_view(w2_weight), requires_grad=False)
+        layer.w13_weight = Parameter(weights.w13_weight, requires_grad=False)
+        layer.w2_weight = Parameter(weights.w2_weight, requires_grad=False)
 
         if self.moe.has_bias:
-            layer.w13_bias = Parameter(
-
-            layer.w2_bias = Parameter(torch_view(w2_bias), requires_grad=False)
+            layer.w13_bias = Parameter(weights.w13_bias, requires_grad=False)
+            layer.w2_bias = Parameter(weights.w2_bias, requires_grad=False)
 
     def apply(
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
         router_logits: torch.Tensor,
-    ) ->
-
-
-
-
-
-
-
-
-
-        if self.moe.has_bias:
-            w13_bias = jax_view(layer.w13_bias)
-            w2_bias = jax_view(layer.w2_bias)
-        gating_output = jax_view(router_logits)
-
-        if self.use_kernel:
-            actual_hidden_size = x.shape[-1]
-            padded_hidden_size = align_to(actual_hidden_size, 256)
-            x = jnp.pad(x,
-                        ((0, 0), (0, padded_hidden_size - actual_hidden_size)),
-                        constant_values=0)
-            output = fused_ep_moe(
-                mesh=self.mesh,
-                tokens=x,
-                w1=w13_weight,
-                w2=w2_weight,
-                b1=w13_bias,
-                b2=w2_bias,
-                gating_output=gating_output,
-                top_k=layer.top_k,
-                ep_axis_name=self.ep_axis_name,
-                renormalize_topk_logits=layer.renormalize,
-                act_fn=layer.activation,
-                **self.block_size,
-            )[:, :actual_hidden_size]
-        else:
-            output = fused_moe_func(
-                hidden_states=x,
-                w1=w13_weight,
-                w2=w2_weight,
-                w1_bias=w13_bias,
-                w2_bias=w2_bias,
-                gating_output=gating_output,
-                topk=layer.top_k,
-                renormalize=layer.renormalize,
-                mesh=self.mesh,
-                use_ep=layer.use_ep,
-                activation=layer.activation,
-            )
-
-        return torch_view(output)
+    ) -> torch.Tensor:
+
+        return fused_moe_apply(
+            layer,
+            x,
+            router_logits,
+            self.moe_backend,
+            self.mesh,
+            self.extra_backend_kwargs,
+        )
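The removed swigluoai branch above de-interleaves the fused gate/up projection so the gmm output can be split down the middle instead of by stride; in the new code this appears to be delegated to process_moe_weights via the w13_interleave flag. A minimal standalone sketch of that reordering idea in plain JAX, with illustrative shapes and names that are not taken from the package:

import jax.numpy as jnp

# Illustrative sizes: E experts, intermediate size I (so 2*I fused rows), hidden size H.
E, I, H = 2, 4, 8
w13 = jnp.arange(E * 2 * I * H, dtype=jnp.float32).reshape(E, 2 * I, H)

# Interleaved layout: even rows feed the gate (w1) matmul, odd rows the up (w3) matmul.
w1_interleaved = w13[:, ::2, :]
w3_interleaved = w13[:, 1::2, :]

# Reorder once at weight-processing time so a cheap contiguous middle split
# recovers the same two halves at runtime.
w13_reordered = jnp.concatenate([w1_interleaved, w3_interleaved], axis=1)
w1_half, w3_half = jnp.split(w13_reordered, 2, axis=1)

assert jnp.array_equal(w1_half, w1_interleaved)
assert jnp.array_equal(w3_half, w3_interleaved)

The one-time reorder trades a single host-side shuffle of the weights for a strided gather on every forward pass, which is the performance concern the removed comment describes.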
tpu_inference/lora/__init__.py CHANGED

@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

tpu_inference/lora/torch_lora_ops.py CHANGED

@@ -4,7 +4,6 @@
 import jax
 import jax.numpy as jnp
 import torch
-import torch.nn.functional as F
 from torchax.interop import call_jax
 
 

@@ -85,19 +84,15 @@ def bgmv_expand_slice(
         add_inputs (bool): Whether or not to add the input tensor to the output
             tensor.
     """
-    outputs = bgmv_torch(inputs, lora_b_weights,
+    outputs = bgmv_torch(inputs, lora_b_weights,
+                         lora_indices_tensor)  # [num_tokens, out_features]
 
-
-
-
-
-            output_tensor.shape[1] - (slice_offset + slice_size),
-            0,
-            0,
-        ),
-    )
+    # Create a padded tensor manually to avoid issues with F.pad on sharded tensors.
+    # This is a more robust way to handle padding in a distributed environment.
+    outputs_padded = torch.zeros_like(output_tensor)
+    outputs_padded[:, slice_offset:slice_offset + slice_size] = outputs
 
     if add_inputs:
-        return output_tensor +
+        return output_tensor + outputs_padded
     else:
-        return
+        return outputs_padded
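The bgmv_expand_slice hunk above swaps F.pad for an explicit zero tensor plus a slice assignment, which per the added comments avoids calling F.pad on sharded tensors. A small sketch with illustrative sizes (not from the package) showing the two forms agree on an ordinary tensor:

import torch
import torch.nn.functional as F

num_tokens, out_features, slice_offset, slice_size = 3, 16, 4, 8
output_tensor = torch.randn(num_tokens, out_features)
outputs = torch.randn(num_tokens, slice_size)

# Manual padding: allocate zeros of the full output shape and write the slice in place.
outputs_padded = torch.zeros_like(output_tensor)
outputs_padded[:, slice_offset:slice_offset + slice_size] = outputs

# Same result via F.pad on the last dimension (left pad, right pad).
expected = F.pad(outputs,
                 (slice_offset, out_features - (slice_offset + slice_size)))
assert torch.equal(outputs_padded, expected)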
tpu_inference/models/__init__.py CHANGED

@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.