tpu-inference 0.12.0.dev20251222__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_core_tpu.py +513 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +67 -0
- tests/core/test_dp_scheduler.py +724 -0
- tests/core/test_init.py +63 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +393 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +291 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +388 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +498 -0
- tests/kernels/quantized_matmul_kernel_test.py +159 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/layers/jax/test_qwix.py +969 -0
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +297 -0
- tests/layers/vllm/test_unquantized.py +621 -0
- tests/layers/vllm/utils.py +72 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +46 -0
- tests/lora/test_bgmv.py +57 -0
- tests/lora/test_layers.py +666 -0
- tests/lora/test_lora.py +147 -0
- tests/lora/test_lora_perf.py +67 -0
- tests/lora/utils.py +88 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +606 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +202 -0
- tests/runner/test_tpu_runner_dp.py +1033 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +215 -0
- tests/test_envs.py +280 -0
- tests/test_tpu_info.py +134 -0
- tests/test_utils.py +193 -0
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +67 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/core_tpu.py +786 -0
- tpu_inference/core/disagg_executor.py +118 -0
- tpu_inference/core/disagg_utils.py +49 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +814 -0
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +81 -0
- tpu_inference/distributed/tpu_connector.py +732 -0
- tpu_inference/distributed/utils.py +112 -0
- tpu_inference/env_override.py +9 -0
- tpu_inference/envs.py +191 -0
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +399 -0
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +272 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +1340 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +403 -0
- tpu_inference/layers/common/attention_metadata.py +48 -0
- tpu_inference/layers/common/binary_search.py +295 -0
- tpu_inference/layers/common/quant_methods.py +23 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +600 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +268 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
- tpu_inference/layers/jax/base.py +165 -0
- tpu_inference/layers/jax/constants.py +101 -0
- tpu_inference/layers/jax/layers.py +315 -0
- tpu_inference/layers/jax/misc.py +30 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
- tpu_inference/layers/jax/moe/moe.py +249 -0
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +294 -0
- tpu_inference/layers/jax/rope_interface.py +228 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
- tpu_inference/layers/jax/sample/sampling.py +110 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
- tpu_inference/layers/jax/transformer_block.py +121 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +221 -0
- tpu_inference/layers/vllm/fused_moe.py +502 -0
- tpu_inference/layers/vllm/linear_common.py +221 -0
- tpu_inference/layers/vllm/quantization/__init__.py +55 -0
- tpu_inference/layers/vllm/quantization/awq.py +221 -0
- tpu_inference/layers/vllm/quantization/common.py +124 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
- tpu_inference/layers/vllm/sharding.py +244 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +98 -0
- tpu_inference/lora/torch_punica_tpu.py +310 -0
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +520 -0
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +978 -0
- tpu_inference/models/jax/gpt_oss.py +508 -0
- tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
- tpu_inference/models/jax/llama3.py +436 -0
- tpu_inference/models/jax/llama4.py +643 -0
- tpu_inference/models/jax/llama_eagle3.py +350 -0
- tpu_inference/models/jax/llama_guard_4.py +375 -0
- tpu_inference/models/jax/qwen2.py +390 -0
- tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
- tpu_inference/models/jax/qwen3.py +318 -0
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +110 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
- tpu_inference/models/jax/utils/weight_utils.py +621 -0
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
- tpu_inference/platforms/__init__.py +16 -0
- tpu_inference/platforms/tpu_platform.py +258 -0
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/block_table.py +122 -0
- tpu_inference/runner/compilation_manager.py +890 -0
- tpu_inference/runner/input_batch.py +435 -0
- tpu_inference/runner/kv_cache.py +166 -0
- tpu_inference/runner/kv_cache_manager.py +508 -0
- tpu_inference/runner/lora_utils.py +106 -0
- tpu_inference/runner/multimodal_manager.py +231 -0
- tpu_inference/runner/persistent_batch_manager.py +296 -0
- tpu_inference/runner/speculative_decoding_manager.py +262 -0
- tpu_inference/runner/structured_decoding_manager.py +101 -0
- tpu_inference/runner/tpu_runner.py +1768 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +430 -0
- tpu_inference/tpu_info.py +92 -0
- tpu_inference/utils.py +345 -0
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +468 -0
- tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
- tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
- tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
- tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
tpu_inference/models/vllm/vllm_model_wrapper.py
@@ -0,0 +1,307 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import functools
from collections.abc import Sequence
from contextlib import nullcontext
from typing import Any, List, Optional, Tuple
from unittest.mock import patch

import jax
import torch
import torch.nn
import torchax
import vllm.envs as vllm_envs
from flax.typing import PRNGKey
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from torchax.interop import jax_view, torch_view
from torchax.ops.mappings import TORCH_DTYPE_TO_JAX
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.lora.layers import BaseLayerWithLoRA
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor.model_loader import get_model as vllm_get_model
from vllm.model_executor.models import supports_lora, supports_multimodal
from vllm.sequence import IntermediateTensors

from tpu_inference.layers.common.attention_metadata import AttentionMetadata
from tpu_inference.layers.vllm.quantization import get_tpu_quantization_config
from tpu_inference.layers.vllm.sharding import shard_model_to_tpu
from tpu_inference.logger import init_logger
from tpu_inference.models.jax.jax_intermediate_tensor import \
    JaxIntermediateTensors
from tpu_inference.models.vllm.vllm_model_wrapper_context import (
    get_vllm_model_wrapper_context, set_vllm_model_wrapper_context)
from tpu_inference.runner.lora_utils import replace_lora_metadata

logger = init_logger(__name__)


class _VllmRunner(torch.nn.Module):

    def __init__(self, vllm_model: torch.nn.Module):
        super().__init__()
        self.vllm_model = vllm_model

    def forward(self, **kwargs) -> torch.Tensor:
        if "hidden_state" in kwargs:
            return self.compute_logits(kwargs["hidden_state"])
        else:
            return self.compute_hidden_state(
                kwargs["input_ids"],
                kwargs["positions"],
                kwargs["intermediate_tensors"],
                kwargs["inputs_embeds"],
            )

    def compute_hidden_state(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor],
    ) -> torch.Tensor:
        hidden_state = self.vllm_model(input_ids, positions,
                                       intermediate_tensors, inputs_embeds)
        return hidden_state

    def compute_logits(self, hidden_state: torch.Tensor) -> torch.Tensor:
        return self.vllm_model.compute_logits(hidden_state)


class VllmModelWrapper:
    """Wraps a vLLM PyTorch model and lets it run on the JAX engine."""

    rng: PRNGKey
    mesh: Mesh
    model: _VllmRunner

    def __init__(self, vllm_config: VllmConfig, rng: PRNGKey, mesh: Mesh):
        self.vllm_config = vllm_config
        self.rng = rng
        self.mesh = mesh

        self.vllm_config.quant_config = get_tpu_quantization_config(
            self.vllm_config, self.mesh)

    def load_weights(self):
        # Set up to load the model onto CPU first.
        # Cache the device slice config since the device config cannot be deepcopied.
        modified_slice_config = False
        if hasattr(
                self.vllm_config.device_config,
                'slice') and self.vllm_config.device_config.slice is not None:
            slice_config = self.vllm_config.device_config.slice
            modified_slice_config = True
            self.vllm_config.device_config.slice = None
        # Clear the cached compilation config, otherwise vLLM model init will fail.
        self.vllm_config.compilation_config.static_forward_context.clear()

        vllm_config_for_load = copy.deepcopy(self.vllm_config)
        if modified_slice_config:
            self.vllm_config.device_config.slice = slice_config
        assert self.vllm_config.model_config.dtype in TORCH_DTYPE_TO_JAX, "The model_config.dtype must be a PyTorch dtype."
        vllm_config_for_load.device_config.device = "cpu"

        # When expert parallelism is enabled, vLLM loads weights in a
        # sharding-aware manner. Since tpu-inference has its own sharding
        # logic, this may cause errors, so we disable it during weight loading.
        vllm_config_for_load.parallel_config.enable_expert_parallel = False

        use_random_weights = (
            vllm_config_for_load.load_config.load_format == "dummy")
        if use_random_weights:
            logger.info(
                "Initializing vLLM model with random weights, weight loading skipped."
            )
        # The DummyModelLoader in vLLM calls torch._sync for the torch_xla path
        # when it detects the TPU platform, but we don't need it and it causes
        # a crash without proper setup.
        load_context = patch(
            "torch._sync",
            return_value=None) if use_random_weights else nullcontext()

        # By default, load weights to the CPU device first. If we are running
        # under Pathways, this would cause weights to be loaded on a CPU-only
        # node, so we'll need to remove this context.
        jax_context = jax.default_device(
            jax.devices("cpu")[0]
        ) if not vllm_envs.VLLM_TPU_USING_PATHWAYS else nullcontext()

        # Load the vLLM model and wrap it into a new model whose forward
        # function can calculate the hidden_state and logits.
        with load_context, jax_context:
            vllm_model = vllm_get_model(vllm_config=vllm_config_for_load)
            lora_manager = None
            if vllm_config_for_load.lora_config is not None:
                # Replace layers in the model with LoRA layers.
                with torchax.default_env():
                    # The "device" argument of load_lora_model sets the device
                    # used by the punica wrapper.
                    lora_manager, vllm_model = load_lora_model(
                        vllm_model, vllm_config_for_load, device="jax")
                    replace_set_lora(vllm_model)

        static_forward_context = vllm_config_for_load.compilation_config.static_forward_context
        self.vllm_config.compilation_config.static_forward_context = static_forward_context

        self.model = _VllmRunner(vllm_model)
        params_and_buffers = shard_model_to_tpu(self.model, self.mesh)

        # Returning to JAX land, so we need to wrap it into a JaxValue.
        return jax_view(params_and_buffers), lora_manager

    def jit_step_func(self):

        @functools.partial(
            jax.jit,
            donate_argnames=("kv_caches", ),
            compiler_options={
                "xla_tpu_all_gather_collective_matmul_mode":
                "post_spmd_conservative",
                "xla_tpu_reduce_scatter_collective_matmul_mode":
                "post_spmd_conservative"
            },
            static_argnames=("layer_name_to_kvcache_index", "is_first_rank",
                             "is_last_rank"),
        )
        def step_fun(
            params_and_buffers,  # already wrapped into a torchax TorchValue
            kv_caches: List[jax.Array],
            input_ids: jax.Array,
            attn_metadata: AttentionMetadata,
            input_embeds: jax.Array,
            input_positions: jax.Array,
            layer_name_to_kvcache_index: Sequence[Tuple[str, int]],
            lora_metadata,
            intermediate_tensors: JaxIntermediateTensors = None,
            is_first_rank: bool = True,
            is_last_rank: bool = True,
            *args,
        ) -> Tuple[List[jax.Array], jax.Array]:
            layer_name_to_kvcache_index = dict(layer_name_to_kvcache_index)
            lora_metadata = torch_view(lora_metadata)
            with torchax.default_env(), set_vllm_model_wrapper_context(
                    kv_caches=kv_caches,
                    mesh=self.mesh,
                    layer_name_to_kvcache_index=layer_name_to_kvcache_index
            ), set_forward_context(attn_metadata=attn_metadata,
                                   vllm_config=self.vllm_config):
                # We need to wrap args from JAX land into TorchValues with
                # torch_view in order to call the Torch function.
                original_lora_metadata = replace_lora_metadata(
                    self.model, lora_metadata, self.vllm_config.lora_config)
                if not is_first_rank:
                    intermediate_tensors = intermediate_tensors.to_torch()
                output_from_torch = torch.func.functional_call(
                    self.model,
                    torch_view(params_and_buffers),
                    kwargs={
                        "input_ids": torch_view(input_ids),
                        "positions": torch_view(input_positions),
                        "intermediate_tensors": intermediate_tensors,
                        "inputs_embeds": None,
                    },
                    tie_weights=False,
                )
                replace_lora_metadata(self.model, original_lora_metadata,
                                      self.vllm_config.lora_config)
                vllm_model_wrapper_context = get_vllm_model_wrapper_context()
                new_kv_caches = vllm_model_wrapper_context.kv_caches
            # Wrap the output (hidden states or intermediate tensors) from
            # torch land into a JaxValue for the JAX code to consume.
            if not is_last_rank:
                output = JaxIntermediateTensors.from_torch(output_from_torch)
            else:
                output = jax_view(output_from_torch)
            return new_kv_caches, output, []

        return step_fun

    def jit_compute_logits_func(self):

        @functools.partial(
            jax.jit,
            out_shardings=(NamedSharding(self.mesh,
                                         PartitionSpec("data", "model"))),
        )
        def compute_logits_func(
            params_and_buffers: Any,
            hidden_states: jax.Array,
            lora_metadata,
        ) -> jax.Array:
            lora_metadata = torch_view(lora_metadata)
            with torchax.default_env(), set_vllm_model_wrapper_context(
                    kv_caches=None, mesh=self.mesh):
                original_lora_metadata = replace_lora_metadata(
                    self.model, lora_metadata, self.vllm_config.lora_config)
                logits = torch.func.functional_call(
                    self.model,
                    torch_view(params_and_buffers),
                    kwargs={
                        "hidden_state": torch_view(hidden_states),
                    },
                    tie_weights=False,
                )
                replace_lora_metadata(self.model, original_lora_metadata,
                                      self.vllm_config.lora_config)
            return jax_view(logits)

        return compute_logits_func


def load_lora_model(model: torch.nn.Module, vllm_config: VllmConfig,
                    device: str) -> torch.nn.Module:
    if not supports_lora(model):
        raise ValueError(
            f"{model.__class__.__name__} does not support LoRA yet.")

    if supports_multimodal(model):
        logger.warning("Regarding multimodal models, vLLM currently "
                       "only supports adding LoRA to the language model.")

    # Add a LoRA manager to the model runner.
    lora_manager = LRUCacheWorkerLoRAManager(
        vllm_config,
        device,
        model.embedding_modules,
    )
    return lora_manager, lora_manager.create_lora_manager(model)


# These methods are replaced because set_lora and reset_lora need to run
# under the torchax env.
def replace_set_lora(model):

    def _tpu_set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
    ):
        with torchax.default_env():
            self._original_set_lora(index, lora_a, lora_b)

    def _tpu_reset_lora(self, index: int):
        with torchax.default_env():
            self._original_reset_lora(index)

    for _, module in model.named_modules():
        if isinstance(module, BaseLayerWithLoRA):
            module._original_set_lora = module.set_lora
            module._original_reset_lora = module.reset_lora
            module.set_lora = _tpu_set_lora.__get__(module, module.__class__)
            module.reset_lora = _tpu_reset_lora.__get__(
                module, module.__class__)
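For orientation, the wrapper above is meant to be driven entirely from JAX code: weights are loaded on CPU, sharded to TPU once, and the two jitted closures are then called with JAX arrays each step. The following driver is a minimal, illustrative sketch and not part of the package; it assumes a populated `VllmConfig`, a device `Mesh`, pre-allocated `kv_caches`, an `AttentionMetadata` instance, token/position arrays, the KV-cache index mapping, and the runner-produced `lora_metadata` already exist (all of those lower-case names are hypothetical placeholders).

```python
# Illustrative driver only; not part of tpu-inference. Assumes vllm_config,
# mesh, kv_caches, attn_metadata, input_ids, positions,
# layer_name_to_kvcache_index and lora_metadata are already defined.
import jax

from tpu_inference.models.vllm.vllm_model_wrapper import VllmModelWrapper

wrapper = VllmModelWrapper(vllm_config=vllm_config,
                           rng=jax.random.PRNGKey(0),
                           mesh=mesh)

# Loads the PyTorch weights on CPU, shards them to TPU, and returns them
# as JAX values together with an optional LoRA manager.
params_and_buffers, lora_manager = wrapper.load_weights()

# Build the jitted closures once and reuse them for every step.
step_fn = wrapper.jit_step_func()
logits_fn = wrapper.jit_compute_logits_func()

kv_caches, hidden_states, _ = step_fn(
    params_and_buffers,
    kv_caches,
    input_ids,
    attn_metadata,
    None,  # input_embeds
    positions,
    tuple(layer_name_to_kvcache_index.items()),  # static arg: must be hashable
    lora_metadata,
)
logits = logits_fn(params_and_buffers, hidden_states, lora_metadata)
```

Note that `layer_name_to_kvcache_index` is listed in `static_argnames` and converted back to a dict inside `step_fun`, so it has to be passed as a hashable tuple of pairs rather than a dict.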
tpu_inference/models/vllm/vllm_model_wrapper_context.py
@@ -0,0 +1,59 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from contextlib import contextmanager
from dataclasses import dataclass
from typing import Dict, List, Optional

import jax
from jax.sharding import Mesh


@dataclass
class VllmModelWrapperContext:
    kv_caches: List[jax.Array]
    mesh: Mesh
    layer_name_to_kvcache_index: Dict[str, int]


_vllm_model_wrapper_context: Optional[VllmModelWrapperContext] = None


def get_vllm_model_wrapper_context() -> VllmModelWrapperContext:
    assert _vllm_model_wrapper_context is not None, (
        "VllmModelWrapperContext is not set. "
        "Please use `set_vllm_model_wrapper_context` to set the VllmModelWrapperContext."
    )
    return _vllm_model_wrapper_context


@contextmanager
def set_vllm_model_wrapper_context(
    *,
    kv_caches: List[jax.Array],
    mesh: Mesh,
    layer_name_to_kvcache_index: Dict[str, int] = None,
):
    global _vllm_model_wrapper_context
    prev_context = _vllm_model_wrapper_context
    _vllm_model_wrapper_context = VllmModelWrapperContext(
        kv_caches=kv_caches,
        mesh=mesh,
        layer_name_to_kvcache_index=layer_name_to_kvcache_index,
    )

    try:
        yield
    finally:
        _vllm_model_wrapper_context = prev_context
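The context here is a plain module-level global guarded by a context manager, so nested `with` blocks restore the previous value on exit, and `get_vllm_model_wrapper_context` asserts if called outside any block. A small self-contained sketch of that behavior (the mesh, cache array, and layer mapping below are placeholder values chosen for illustration):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh

from tpu_inference.models.vllm.vllm_model_wrapper_context import (
    get_vllm_model_wrapper_context, set_vllm_model_wrapper_context)

# Placeholder mesh and "KV cache" purely for illustration.
mesh = Mesh(np.array(jax.devices()), axis_names=("model", ))
dummy_caches = [jnp.zeros((1, 1))]

with set_vllm_model_wrapper_context(kv_caches=dummy_caches,
                                    mesh=mesh,
                                    layer_name_to_kvcache_index={"layer.0": 0}):
    ctx = get_vllm_model_wrapper_context()
    assert ctx.kv_caches is dummy_caches
    assert ctx.layer_name_to_kvcache_index["layer.0"] == 0

# Outside the block the previous (unset) context is restored, so calling
# get_vllm_model_wrapper_context() here would trip its assertion.
```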
tpu_inference/platforms/__init__.py
@@ -0,0 +1,16 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# ruff: noqa
from tpu_inference.platforms.tpu_platform import TpuPlatform
tpu_inference/platforms/tpu_platform.py
@@ -0,0 +1,258 @@
# SPDX-License-Identifier: Apache-2.0

from typing import TYPE_CHECKING, Optional, Tuple, Union, cast

import jax.numpy as jnp
import torch
import vllm.envs as vllm_envs
from tpu_info import device
from vllm.inputs import ProcessorInputs, PromptType
from vllm.platforms.interface import Platform, PlatformEnum

from tpu_inference import envs
from tpu_inference.layers.common.sharding import ShardingConfigManager
from tpu_inference.logger import init_logger

if TYPE_CHECKING:
    from vllm.attention.backends.registry import AttentionBackendEnum
    from vllm.attention.selector import AttentionSelectorConfig
    from vllm.config import BlockSize, ModelConfig, VllmConfig
    from vllm.pooling_params import PoolingParams
    from vllm.sampling_params import SamplingParams, SamplingType
else:
    BlockSize = None
    ModelConfig = None
    VllmConfig = None
    PoolingParams = None
    AttentionBackendEnum = None
    SamplingParams = None
    SamplingType = None

logger = init_logger(__name__)


class TpuPlatform(Platform):
    _enum = PlatformEnum.TPU
    device_name: str = "tpu"
    device_type: str = "tpu"
    dispatch_key: str = "XLA"
    ray_device_key: str = "TPU"
    device_control_env_var: str = "TPU_VISIBLE_CHIPS"
    simple_compile_backend: str = "openxla"

    supported_quantization: list[str] = [
        "tpu_int8", "compressed-tensors", "awq", "fp8", "mxfp4"
    ]

    additional_env_vars: list[str] = [
        "PHASED_PROFILING_DIR", "TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS",
        "TPU_MULTIHOST_BACKEND", "VLLM_MLA_DISABLE", "TPU_BACKEND_TYPE",
        "NEW_MODEL_DESIGN"
    ]

    @classmethod
    def get_attn_backend_cls(cls, selected_backend: "AttentionBackendEnum",
                             attn_selector_config: "AttentionSelectorConfig",
                             **kwargs) -> str:
        from vllm.attention.backends.registry import AttentionBackendEnum

        if selected_backend != AttentionBackendEnum.PALLAS:
            logger.info("Cannot use %s backend on TPU.", selected_backend)

        logger.info("Using Pallas V1 backend.")
        return "tpu_inference.layers.vllm.attention.PallasAttentionBackend"

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        try:
            if vllm_envs.VLLM_TPU_USING_PATHWAYS:
                # Calling jax.devices() under Pathways causes multiprocess
                # access to IFRT, so return a fixed name instead.
                return "TPU v6 lite"
            else:
                chip_type, _ = device.get_local_chips()
                return f"TPU {chip_type.name}"
        except Exception as e:
            logger.warning(f"Error getting device name: {e}")
            return 'TPU'

    @classmethod
    def fp8_dtype(cls) -> torch.dtype:
        if cls.get_device_name().lower() == "tpu v6e":
            logger.info(
                "Automatically using fp8_e5m2 for FP8 KV cache on TPU v6e.")
            return torch.float8_e5m2
        return torch.float8_e4m3fn

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        raise NotImplementedError

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        return False

    @classmethod
    def get_punica_wrapper(cls) -> str:
        return "tpu_inference.lora.torch_punica_tpu.PunicaWrapperTPU"

    @classmethod
    def get_infinity_values(cls, dtype: jnp.dtype) -> Tuple[float, float]:
        return jnp.finfo(dtype).min, jnp.finfo(dtype).max

    @classmethod
    def can_update_inplace(cls):
        return False

    @classmethod
    def get_lora_vocab_padding_size(cls) -> int:
        return 1

    @classmethod
    def inference_mode(cls):
        return True

    @classmethod
    def _initialize_sharding_config(cls, vllm_config: VllmConfig) -> None:
        sharding_config = ShardingConfigManager.from_vllm_config(vllm_config)
        vllm_config.sharding_config = sharding_config
        logger.info(f"Initialized sharding configuration: {sharding_config}")

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        if vllm_envs.VLLM_TPU_USING_PATHWAYS:
            assert not vllm_envs.VLLM_ENABLE_V1_MULTIPROCESSING, (
                "VLLM_ENABLE_V1_MULTIPROCESSING must be 0 when using Pathways "
                "(JAX_PLATFORMS=proxy)")
        cls._initialize_sharding_config(vllm_config)

        from vllm.config import CompilationMode

        cache_config = vllm_config.cache_config
        # For v0, the default block size is 16.
        if cache_config and cache_config.block_size is None:
            cache_config.block_size = cast(BlockSize, 16)

        compilation_config = vllm_config.compilation_config

        # TPU only supports the DYNAMO_TRACE_ONCE compilation mode.
        # NOTE(xiang): the compilation_config is not used by JAX.
        if compilation_config.mode != CompilationMode.DYNAMO_TRACE_ONCE:
            compilation_config.mode = CompilationMode.DYNAMO_TRACE_ONCE

        if compilation_config.backend == "":
            compilation_config.backend = "openxla"

        # TODO(cuiq): remove this dependency.
        if vllm_config.model_config:
            from vllm.v1.attention.backends.pallas import \
                PallasAttentionBackend
            cache_config.block_size = PallasAttentionBackend.get_page_size(
                vllm_config)  # type: ignore[assignment]
            min_page_size = PallasAttentionBackend.get_min_page_size(
                vllm_config)
            if min_page_size > cache_config.block_size:
                logger.warning(
                    "Increase the page size from %s to %s to avoid SMEM OOM",
                    cache_config.block_size,
                    min_page_size,
                )
                cache_config.block_size = min_page_size  # type: ignore[assignment]

        parallel_config = vllm_config.parallel_config
        scheduler_config = vllm_config.scheduler_config
        parallel_config.worker_cls = \
            "tpu_inference.worker.tpu_worker.TPUWorker"

        multihost_backend = envs.TPU_MULTIHOST_BACKEND
        if not multihost_backend:  # Single host
            if parallel_config.pipeline_parallel_size == 1:
                logger.info("Force using UniProcExecutor for JAX on a "
                            "single host without pipeline parallelism.")
                parallel_config.distributed_executor_backend = "uni"
            else:
                logger.info("Force using MultiprocExecutor for JAX on a "
                            "single host with pipeline parallelism.")
                parallel_config.distributed_executor_backend = "mp"
        elif multihost_backend == "ray":
            from tpu_inference.executors.ray_distributed_executor import \
                RayDistributedExecutor
            parallel_config.distributed_executor_backend = RayDistributedExecutor
            logger.info(
                "Force using RayDistributedExecutor for JAX on multihost.")
        else:
            logger.warning(
                f"Unknown TPU multihost backend: {multihost_backend}. "
                "Using uniproc_executor.")
            parallel_config.distributed_executor_backend = "uni"

        if scheduler_config.is_multimodal_model and \
                not scheduler_config.disable_chunked_mm_input:
            logger.warning("TPU does not support running multimodal models "
                           "without setting `--disable_chunked_mm_input`. "
                           "Forcing --disable_chunked_mm_input.")
            scheduler_config.disable_chunked_mm_input = True

        kv_transfer_config = vllm_config.kv_transfer_config
        if kv_transfer_config is not None:
            assert kv_transfer_config.kv_connector == "TPUConnector"
        # Late initialization to avoid a circular import.
        # Only perform qwix quantization if it is a JAX model.
        if vllm_config.model_config is not None:
            from tpu_inference.models.jax.utils.qwix.qwix_utils import \
                update_vllm_config_for_qwix_quantization
            if vllm_config.model_config:
                update_vllm_config_for_qwix_quantization(vllm_config)

        from tpu_inference.core.sched.dp_scheduler import \
            update_vllm_config_for_dp_scheduler
        update_vllm_config_for_dp_scheduler(vllm_config)

    @classmethod
    def is_pin_memory_available(cls):
        logger.warning("Pin memory is not supported on TPU.")
        return False

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        return "vllm.distributed.device_communicators.tpu_communicator.TpuCommunicator"  # noqa

    @classmethod
    def use_all_gather(cls) -> bool:
        return True

    @classmethod
    def supports_v1(cls, model_config: ModelConfig) -> bool:
        # V1 support on TPU is experimental.
        return True

    @classmethod
    def validate_request(
        cls,
        prompt: PromptType,
        params: Union["SamplingParams", PoolingParams],
        processed_inputs: ProcessorInputs,
    ) -> None:
        """Raises if this request is unsupported on this platform."""
        from vllm.sampling_params import SamplingParams, SamplingType

        if isinstance(params, SamplingParams):
            if params.sampling_type == SamplingType.RANDOM_SEED:
                raise ValueError("JAX does not support per-request seed.")

    @classmethod
    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str,
                                    model_config: ModelConfig) -> bool:
        return True

    @classmethod
    def use_sync_weight_loader(cls) -> bool:
        """
        Returns whether the current platform needs the sync weight loader.
        """
        return True

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        return True
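Most of these hooks are consulted by vLLM itself during engine construction, but the request-validation path is easy to exercise in isolation: a per-request seed maps to `SamplingType.RANDOM_SEED`, which this platform rejects. A small illustrative sketch (the `prompt` and `processed_inputs` arguments are not inspected by this implementation, so placeholder values are used):

```python
from vllm.sampling_params import SamplingParams

from tpu_inference.platforms.tpu_platform import TpuPlatform

# A per-request seed selects SamplingType.RANDOM_SEED, which is rejected.
try:
    TpuPlatform.validate_request(prompt="Hello",
                                 params=SamplingParams(seed=1234),
                                 processed_inputs=None)
except ValueError as err:
    print(err)  # JAX does not support per-request seed.

# Unseeded sampling passes validation (the method simply returns None).
TpuPlatform.validate_request(prompt="Hello",
                             params=SamplingParams(),
                             processed_inputs=None)
```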
One of the package's license-header-only modules:
@@ -0,0 +1,13 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.