tpu-inference 0.11.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic. Click here for more details.
- tests/__init__.py +0 -0
- tests/core/__init__.py +0 -0
- tests/core/test_adapters.py +83 -0
- tests/core/test_core_tpu.py +523 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +53 -0
- tests/core/test_init.py +49 -0
- tests/kernels/__init__.py +0 -0
- tests/kernels/quantized_matmul_kernel_test.py +191 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
- tests/lora/__init__.py +0 -0
- tests/lora/test_lora.py +123 -0
- tests/test_base.py +201 -0
- tests/test_quantization.py +836 -0
- tests/test_tpu_info.py +120 -0
- tests/test_utils.py +218 -0
- tests/tpu_backend_test.py +59 -0
- tpu_inference/__init__.py +30 -0
- tpu_inference/adapters/__init__.py +0 -0
- tpu_inference/adapters/vllm_adapters.py +42 -0
- tpu_inference/adapters/vllm_config_adapters.py +134 -0
- tpu_inference/backend.py +69 -0
- tpu_inference/core/__init__.py +0 -0
- tpu_inference/core/adapters.py +153 -0
- tpu_inference/core/core_tpu.py +776 -0
- tpu_inference/core/disagg_executor.py +117 -0
- tpu_inference/core/disagg_utils.py +51 -0
- tpu_inference/di/__init__.py +0 -0
- tpu_inference/di/abstracts.py +28 -0
- tpu_inference/di/host.py +76 -0
- tpu_inference/di/interfaces.py +51 -0
- tpu_inference/distributed/__init__.py +0 -0
- tpu_inference/distributed/tpu_connector.py +699 -0
- tpu_inference/distributed/utils.py +59 -0
- tpu_inference/executors/__init__.py +0 -0
- tpu_inference/executors/ray_distributed_executor.py +346 -0
- tpu_inference/experimental/__init__.py +0 -0
- tpu_inference/experimental/llama3_jax_stashed.py +258 -0
- tpu_inference/interfaces/__init__.py +0 -0
- tpu_inference/interfaces/cache.py +31 -0
- tpu_inference/interfaces/config.py +47 -0
- tpu_inference/interfaces/config_parts.py +117 -0
- tpu_inference/interfaces/engine.py +51 -0
- tpu_inference/interfaces/outputs.py +22 -0
- tpu_inference/interfaces/params.py +21 -0
- tpu_inference/interfaces/platform.py +74 -0
- tpu_inference/interfaces/request.py +39 -0
- tpu_inference/interfaces/scheduler.py +31 -0
- tpu_inference/kernels/__init__.py +0 -0
- tpu_inference/kernels/collectives/__init__.py +0 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +0 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1447 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3834 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +47 -0
- tpu_inference/layers/__init__.py +0 -0
- tpu_inference/layers/common/__init__.py +0 -0
- tpu_inference/layers/common/attention_metadata.py +34 -0
- tpu_inference/layers/jax/__init__.py +0 -0
- tpu_inference/layers/jax/attention/__init__.py +0 -0
- tpu_inference/layers/jax/attention/attention.py +254 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
- tpu_inference/layers/jax/attention_interface.py +356 -0
- tpu_inference/layers/jax/base.py +151 -0
- tpu_inference/layers/jax/binary_search.py +295 -0
- tpu_inference/layers/jax/constants.py +88 -0
- tpu_inference/layers/jax/layers.py +301 -0
- tpu_inference/layers/jax/misc.py +16 -0
- tpu_inference/layers/jax/moe/__init__.py +0 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
- tpu_inference/layers/jax/moe/moe.py +209 -0
- tpu_inference/layers/jax/rope.py +172 -0
- tpu_inference/layers/jax/rope_interface.py +214 -0
- tpu_inference/layers/jax/sample/__init__.py +0 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
- tpu_inference/layers/jax/sample/sampling.py +95 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +69 -0
- tpu_inference/layers/jax/sharding.py +406 -0
- tpu_inference/layers/jax/transformer_block.py +76 -0
- tpu_inference/layers/vllm/__init__.py +0 -0
- tpu_inference/layers/vllm/attention.py +184 -0
- tpu_inference/layers/vllm/fused_moe.py +399 -0
- tpu_inference/layers/vllm/linear_common.py +186 -0
- tpu_inference/layers/vllm/quantization/__init__.py +34 -0
- tpu_inference/layers/vllm/quantization/awq.py +207 -0
- tpu_inference/layers/vllm/quantization/common.py +105 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +121 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +263 -0
- tpu_inference/layers/vllm/sharding.py +151 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +0 -0
- tpu_inference/lora/torch_lora_ops.py +103 -0
- tpu_inference/lora/torch_punica_tpu.py +308 -0
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1233 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/__init__.py +0 -0
- tpu_inference/models/common/__init__.py +0 -0
- tpu_inference/models/common/model_loader.py +433 -0
- tpu_inference/models/jax/__init__.py +0 -0
- tpu_inference/models/jax/deepseek_v3.py +868 -0
- tpu_inference/models/jax/llama3.py +366 -0
- tpu_inference/models/jax/llama4.py +473 -0
- tpu_inference/models/jax/llama_eagle3.py +333 -0
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +375 -0
- tpu_inference/models/jax/qwen2_5_vl.py +976 -0
- tpu_inference/models/jax/qwen3.py +302 -0
- tpu_inference/models/jax/utils/__init__.py +0 -0
- tpu_inference/models/jax/utils/file_utils.py +96 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +164 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +588 -0
- tpu_inference/models/jax/utils/weight_utils.py +510 -0
- tpu_inference/models/vllm/__init__.py +0 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +272 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
- tpu_inference/platforms/__init__.py +2 -0
- tpu_inference/platforms/tpu_jax.py +257 -0
- tpu_inference/runner/__init__.py +0 -0
- tpu_inference/runner/block_table_jax.py +122 -0
- tpu_inference/runner/compilation_manager.py +672 -0
- tpu_inference/runner/input_batch_jax.py +435 -0
- tpu_inference/runner/kv_cache.py +119 -0
- tpu_inference/runner/kv_cache_manager.py +460 -0
- tpu_inference/runner/lora_utils.py +92 -0
- tpu_inference/runner/multimodal_manager.py +208 -0
- tpu_inference/runner/persistent_batch_manager.py +244 -0
- tpu_inference/runner/speculative_decoding_manager.py +250 -0
- tpu_inference/runner/structured_decoding_manager.py +89 -0
- tpu_inference/runner/tpu_jax_runner.py +771 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +0 -0
- tpu_inference/spec_decode/jax/__init__.py +0 -0
- tpu_inference/spec_decode/jax/eagle3.py +334 -0
- tpu_inference/tpu_info.py +77 -0
- tpu_inference/utils.py +294 -0
- tpu_inference/worker/__init__.py +0 -0
- tpu_inference/worker/_temporary_vllm_compat.py +129 -0
- tpu_inference/worker/base.py +100 -0
- tpu_inference/worker/tpu_worker_jax.py +321 -0
- tpu_inference-0.11.1.dist-info/METADATA +101 -0
- tpu_inference-0.11.1.dist-info/RECORD +168 -0
- tpu_inference-0.11.1.dist-info/WHEEL +5 -0
- tpu_inference-0.11.1.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.11.1.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,272 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
import functools
|
|
3
|
+
import os
|
|
4
|
+
from collections.abc import Sequence
|
|
5
|
+
from contextlib import nullcontext
|
|
6
|
+
from typing import Any, List, Optional, Tuple
|
|
7
|
+
from unittest.mock import patch
|
|
8
|
+
|
|
9
|
+
import jax
|
|
10
|
+
import torch
|
|
11
|
+
import torch.nn
|
|
12
|
+
import torchax
|
|
13
|
+
from flax.typing import PRNGKey
|
|
14
|
+
from jax.sharding import Mesh, NamedSharding, PartitionSpec
|
|
15
|
+
from torchax.interop import jax_view, torch_view
|
|
16
|
+
from torchax.ops.mappings import TORCH_DTYPE_TO_JAX
|
|
17
|
+
from vllm.config import VllmConfig
|
|
18
|
+
from vllm.forward_context import set_forward_context
|
|
19
|
+
from vllm.lora.layers import BaseLayerWithLoRA
|
|
20
|
+
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
|
|
21
|
+
from vllm.model_executor.model_loader import get_model as vllm_get_model
|
|
22
|
+
from vllm.model_executor.models import supports_lora, supports_multimodal
|
|
23
|
+
from vllm.sequence import IntermediateTensors
|
|
24
|
+
|
|
25
|
+
from tpu_inference.layers.common.attention_metadata import AttentionMetadata
|
|
26
|
+
from tpu_inference.layers.vllm.quantization import get_tpu_quantization_config
|
|
27
|
+
from tpu_inference.layers.vllm.sharding import shard_model_to_tpu
|
|
28
|
+
from tpu_inference.logger import init_logger
|
|
29
|
+
from tpu_inference.models.vllm.vllm_model_wrapper_context import (
|
|
30
|
+
get_vllm_model_wrapper_context, set_vllm_model_wrapper_context)
|
|
31
|
+
from tpu_inference.runner.lora_utils import replace_lora_metadata
|
|
32
|
+
|
|
33
|
+
# Module-level logger for the vLLM model wrapper.
logger = init_logger(__name__)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class _VllmRunner(torch.nn.Module):
    """Adapter exposing both halves of a vLLM model through ``forward``.

    The wrapper drives this module with ``torch.func.functional_call``, which
    invokes ``forward``; the keyword arguments decide whether the backbone
    (hidden states) or the LM head (logits) runs.
    """

    def __init__(self, vllm_model: torch.nn.Module):
        super().__init__()
        self.vllm_model = vllm_model

    def forward(self, **kwargs) -> torch.Tensor:
        """Dispatch on kwargs.

        With ``hidden_state`` present, return logits for it. Otherwise run
        the backbone on ``input_ids``, ``positions``, ``intermediate_tensors``
        and ``inputs_embeds``; all four keys are required in that case.
        """
        # Multimodal input is not supported for Gemma3 here. vLLM's Gemma3
        # `get_multimodal_embeddings` returns an empty list while its caller
        # checks for None, so it is patched to return None.
        gemma3_embeddings_target = ("vllm.model_executor.models.gemma3_mm."
                                    "Gemma3ForConditionalGeneration."
                                    "get_multimodal_embeddings")
        with patch(gemma3_embeddings_target, return_value=None):
            if "hidden_state" not in kwargs:
                return self.compute_hidden_state(
                    kwargs["input_ids"],
                    kwargs["positions"],
                    kwargs["intermediate_tensors"],
                    kwargs["inputs_embeds"],
                )
            return self.compute_logits(kwargs["hidden_state"])

    def compute_hidden_state(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Run the wrapped vLLM model positionally and return its output."""
        return self.vllm_model(input_ids, positions, intermediate_tensors,
                               inputs_embeds)

    def compute_logits(self, hidden_state: torch.Tensor) -> torch.Tensor:
        """Delegate to the wrapped model's ``compute_logits``."""
        return self.vllm_model.compute_logits(hidden_state)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class VllmModelWrapper:
    """Wraps a vLLM PyTorch model and lets it run on the JAX engine.

    The model executes through torchax. JAX arrays are viewed as torch
    tensors (``torch_view``) on the way in, and results are viewed back as
    JAX arrays (``jax_view``) on the way out.
    """

    rng: PRNGKey
    mesh: Mesh
    model: _VllmRunner

    def __init__(self, vllm_config: VllmConfig, rng: PRNGKey, mesh: Mesh):
        """Store the config, RNG key and mesh.

        Note: this mutates the given ``vllm_config`` by replacing its
        ``quant_config`` with the TPU quantization config for ``mesh``.
        """
        self.vllm_config = vllm_config
        self.rng = rng
        self.mesh = mesh

        self.vllm_config.quant_config = get_tpu_quantization_config(
            self.vllm_config, self.mesh)

    def load_weights(self):
        """Load the vLLM model on CPU, optionally add LoRA, and shard it.

        Side effects: sets ``self.model``. Also copies
        ``static_forward_context`` from the load-time config copy back onto
        ``self.vllm_config``.

        Returns:
            ``(params_and_buffers, lora_manager)``. ``params_and_buffers`` is
            the output of ``shard_model_to_tpu`` wrapped with ``jax_view``.
            ``lora_manager`` is ``None`` when no LoRA config is set.
        """
        # Set up to load the model into CPU first.
        vllm_config_for_load = copy.deepcopy(self.vllm_config)
        assert self.vllm_config.model_config.dtype in TORCH_DTYPE_TO_JAX, "The model_config.dtype must be a PyTorch dtype."
        vllm_config_for_load.device_config.device = "cpu"

        # NOTE: any non-empty value of JAX_RANDOM_WEIGHTS (even "0") enables
        # random weights, because os.getenv returns a string.
        if os.getenv("JAX_RANDOM_WEIGHTS", False):
            vllm_config_for_load.load_config.load_format = "dummy"
            use_random_weights = True
        else:
            use_random_weights = (
                vllm_config_for_load.load_config.load_format == "dummy")
        if use_random_weights:
            logger.info(
                "Initializing vLLM model with random weights, weight loading skipped."
            )
        # The DummyModelLoader in vLLM calls torch._sync for the torch_xla path
        # when it detects the TPU platform. We don't need it, and it crashes
        # without proper setup.
        load_context = patch(
            "torch._sync",
            return_value=None) if use_random_weights else nullcontext()

        # Load the vLLM model and wrap it into a new model whose forward
        # function can calculate the hidden_state and logits.
        with load_context, jax.default_device(jax.devices('cpu')[0]):
            vllm_model = vllm_get_model(vllm_config=vllm_config_for_load)
        lora_manager = None
        if vllm_config_for_load.lora_config is not None:
            # Replace layers in the model with LoRA layers.
            with torchax.default_env():
                # The "device" argument of load_lora_model sets the device
                # used by the punica wrapper.
                lora_manager, vllm_model = load_lora_model(
                    vllm_model, vllm_config_for_load, device="jax")
            replace_set_lora(vllm_model)

        static_forward_context = vllm_config_for_load.compilation_config.static_forward_context
        self.vllm_config.compilation_config.static_forward_context = static_forward_context

        self.model = _VllmRunner(vllm_model)
        params_and_buffers = shard_model_to_tpu(self.model, self.mesh)

        # Returning to JAX land, so wrap the result into a JaxValue.
        return jax_view(params_and_buffers), lora_manager

    def jit_step_func(self):
        """Build the jitted forward step.

        In the returned function:
        - ``kv_caches`` is donated.
        - ``layer_name_to_kvcache_index`` is a static argument, so it must be
          hashable (e.g. a tuple of ``(name, index)`` pairs). It is converted
          to a dict inside.

        The function returns ``(new_kv_caches, hidden_states, [])``; the third
        element is always an empty list here.
        """

        @functools.partial(
            jax.jit,
            donate_argnames=("kv_caches", ),
            compiler_options={
                "xla_tpu_all_gather_collective_matmul_mode":
                "post_spmd_conservative",
                "xla_tpu_reduce_scatter_collective_matmul_mode":
                "post_spmd_conservative"
            },
            static_argnames=("layer_name_to_kvcache_index", ),
        )
        def step_fun(
            params_and_buffers,  # This has been wrapped into torchax TorchValue
            kv_caches: List[jax.Array],
            input_ids: jax.Array,
            attn_metadata: AttentionMetadata,
            input_embeds: jax.Array,
            layer_name_to_kvcache_index: Sequence[Tuple[str, int]],
            lora_metadata,
            *args,
        ) -> Tuple[List[jax.Array], jax.Array, List[Any]]:
            layer_name_to_kvcache_index = dict(layer_name_to_kvcache_index)
            lora_metadata = torch_view(lora_metadata)
            with torchax.default_env(), set_vllm_model_wrapper_context(
                    kv_caches=kv_caches,
                    mesh=self.mesh,
                    layer_name_to_kvcache_index=layer_name_to_kvcache_index
            ), set_forward_context(attn_metadata=attn_metadata,
                                   vllm_config=self.vllm_config):
                # Args from JAX land must be wrapped into TorchValue with
                # torch_view in order to call the torch function.
                original_lora_metadata = replace_lora_metadata(
                    self.model, lora_metadata, self.vllm_config.lora_config)
                try:
                    hidden_states = torch.func.functional_call(
                        self.model,
                        torch_view(params_and_buffers),
                        kwargs={
                            "input_ids": torch_view(input_ids),
                            "positions":
                            torch_view(attn_metadata.input_positions),
                            "intermediate_tensors": None,
                            "inputs_embeds": None,
                        },
                        tie_weights=False,
                    )
                finally:
                    # Restore even if the call raises, so the model does not
                    # keep this call's LoRA metadata.
                    replace_lora_metadata(self.model, original_lora_metadata,
                                          self.vllm_config.lora_config)
                vllm_model_wrapper_context = get_vllm_model_wrapper_context()
                new_kv_caches = vllm_model_wrapper_context.kv_caches
                # Wrap hidden_states from torch land into a JaxValue for the
                # JAX code to consume.
                hidden_states = jax_view(hidden_states)

                return new_kv_caches, hidden_states, []

        return step_fun

    def jit_compute_logits_func(self):
        """Build the jitted logits function.

        The output is sharded as ``PartitionSpec(None, "model")`` over
        ``self.mesh``.
        """

        @functools.partial(
            jax.jit,
            out_shardings=(NamedSharding(self.mesh,
                                         PartitionSpec(None, "model"))),
        )
        def compute_logits_func(
            params_and_buffers: Any,
            hidden_states: jax.Array,
            lora_metadata,
        ) -> jax.Array:
            lora_metadata = torch_view(lora_metadata)
            with torchax.default_env(), set_vllm_model_wrapper_context(
                    kv_caches=None, mesh=self.mesh):
                original_lora_metadata = replace_lora_metadata(
                    self.model, lora_metadata, self.vllm_config.lora_config)
                try:
                    logits = torch.func.functional_call(
                        self.model,
                        torch_view(params_and_buffers),
                        kwargs={
                            "hidden_state": torch_view(hidden_states),
                        },
                        tie_weights=False,
                    )
                finally:
                    # Restore even if the call raises.
                    replace_lora_metadata(self.model, original_lora_metadata,
                                          self.vllm_config.lora_config)
            return jax_view(logits)

        return compute_logits_func
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
def load_lora_model(
        model: torch.nn.Module, vllm_config: VllmConfig,
        device: str) -> Tuple["LRUCacheWorkerLoRAManager", torch.nn.Module]:
    """Attach a LoRA worker manager to ``model``.

    Args:
        model: The loaded vLLM model; it must support LoRA.
        vllm_config: Config forwarded to ``LRUCacheWorkerLoRAManager``.
        device: Device string forwarded to the manager.

    Returns:
        ``(lora_manager, lora_model)``. ``lora_model`` is the result of
        ``lora_manager.create_lora_manager(model)``.

    Raises:
        ValueError: if ``model`` does not support LoRA.
    """
    if not supports_lora(model):
        raise ValueError(
            f"{model.__class__.__name__} does not support LoRA yet.")

    if supports_multimodal(model):
        logger.warning("Regarding multimodal models, vLLM currently "
                       "only supports adding LoRA to language model.")

    # Add LoRA Manager to the Model Runner
    lora_manager = LRUCacheWorkerLoRAManager(
        vllm_config,
        device,
        model.embedding_modules,
        model.embedding_padding_modules,
    )
    return lora_manager, lora_manager.create_lora_manager(model)
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
def replace_set_lora(model):
    """Make every LoRA layer's ``set_lora``/``reset_lora`` run under torchax.

    ``set_lora`` and ``reset_lora`` need to run inside the torchax
    environment. For each ``BaseLayerWithLoRA`` submodule, the bound methods
    are kept as ``_original_set_lora`` / ``_original_reset_lora``. They are
    then replaced by wrappers that enter ``torchax.default_env()`` before
    delegating.

    Layers already wrapped by an earlier call are skipped. Re-wrapping would
    make the wrapper delegate to itself and recurse without end.
    """

    def _tpu_set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        with torchax.default_env():
            self._original_set_lora(index, lora_a, lora_b, embeddings_tensor)

    def _tpu_reset_lora(self, index: int):
        with torchax.default_env():
            self._original_reset_lora(index)

    for module in model.modules():
        if not isinstance(module, BaseLayerWithLoRA):
            continue
        if hasattr(module, "_original_set_lora"):
            # Already wrapped; wrapping again would recurse forever.
            continue
        module._original_set_lora = module.set_lora
        module._original_reset_lora = module.reset_lora
        module.set_lora = _tpu_set_lora.__get__(module, module.__class__)
        module.reset_lora = _tpu_reset_lora.__get__(module, module.__class__)
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
from contextlib import contextmanager
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
from typing import Dict, List, Optional
|
|
4
|
+
|
|
5
|
+
import jax
|
|
6
|
+
from jax.sharding import Mesh
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dataclass
class VllmModelWrapperContext:
    """State made available to vLLM layers while a wrapped model call runs.

    It is installed by ``set_vllm_model_wrapper_context`` and read through
    ``get_vllm_model_wrapper_context``.
    """
    # KV cache arrays, or None when the call needs no KV cache
    # (e.g. the logits-only path).
    kv_caches: Optional[List[jax.Array]]
    # Device mesh the computation runs on.
    mesh: Mesh
    # Maps layer names to indices into ``kv_caches``; may be None.
    layer_name_to_kvcache_index: Optional[Dict[str, int]]
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
# Currently active wrapper context. It is None outside a
# `set_vllm_model_wrapper_context` block. This is a module-level global
# (process-wide), not thread-local.
_vllm_model_wrapper_context: Optional[VllmModelWrapperContext] = None
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def get_vllm_model_wrapper_context() -> VllmModelWrapperContext:
    """Return the context installed by the innermost active setter.

    Raises:
        AssertionError: if no context is set (unless asserts are disabled
            with ``-O``, in which case ``None`` is returned).
    """
    active = _vllm_model_wrapper_context
    assert active is not None, (
        "VllmModelWrapperContext is not set. "
        "Please use `set_vllm_model_wrapper_context` to set the VllmModelWrapperContext."
    )
    return active
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@contextmanager
def set_vllm_model_wrapper_context(
    *,
    kv_caches: Optional[List[jax.Array]],
    mesh: Mesh,
    layer_name_to_kvcache_index: Optional[Dict[str, int]] = None,
):
    """Install a ``VllmModelWrapperContext`` for the duration of a ``with`` block.

    The previously active context (possibly None) is saved and restored on
    exit, even when the body raises, so nested uses unwind correctly. The
    context is stored in a module-level global, not per thread.

    Args:
        kv_caches: KV cache arrays, or None when the call needs no KV cache.
        mesh: Device mesh for the computation.
        layer_name_to_kvcache_index: Optional mapping from layer name to an
            index into ``kv_caches``.
    """
    global _vllm_model_wrapper_context
    prev_context = _vllm_model_wrapper_context
    _vllm_model_wrapper_context = VllmModelWrapperContext(
        kv_caches=kv_caches,
        mesh=mesh,
        layer_name_to_kvcache_index=layer_name_to_kvcache_index,
    )

    try:
        yield
    finally:
        _vllm_model_wrapper_context = prev_context
|
|
@@ -0,0 +1,257 @@
|
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
from typing import TYPE_CHECKING, Optional, Tuple, Union, cast
|
|
5
|
+
|
|
6
|
+
import jax.numpy as jnp
|
|
7
|
+
import vllm.envs as envs
|
|
8
|
+
from torchax.ops.mappings import j2t_dtype
|
|
9
|
+
from tpu_info import device
|
|
10
|
+
from vllm.inputs import ProcessorInputs, PromptType
|
|
11
|
+
from vllm.platforms.interface import Platform, PlatformEnum
|
|
12
|
+
from vllm.sampling_params import SamplingParams, SamplingType
|
|
13
|
+
|
|
14
|
+
from tpu_inference.logger import init_logger
|
|
15
|
+
from tpu_inference.models.jax.utils.quantization.quantization_utils import \
|
|
16
|
+
update_vllm_config_for_qwix_quantization
|
|
17
|
+
|
|
18
|
+
if TYPE_CHECKING:
|
|
19
|
+
from vllm.attention.backends.registry import _Backend
|
|
20
|
+
from vllm.config import BlockSize, ModelConfig, VllmConfig
|
|
21
|
+
from vllm.pooling_params import PoolingParams
|
|
22
|
+
else:
|
|
23
|
+
BlockSize = None
|
|
24
|
+
ModelConfig = None
|
|
25
|
+
VllmConfig = None
|
|
26
|
+
PoolingParams = None
|
|
27
|
+
_Backend = None
|
|
28
|
+
|
|
29
|
+
# Module-level logger for the TPU platform.
logger = init_logger(__name__)
|
|
30
|
+
|
|
31
|
+
_DTYPE: dict[str, jnp.dtype] = {
|
|
32
|
+
"bfloat16": jnp.bfloat16,
|
|
33
|
+
"float": jnp.float32,
|
|
34
|
+
"float32": jnp.float32,
|
|
35
|
+
}
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class TpuPlatform(Platform):
    """vLLM platform plugin that runs models on TPU through JAX."""

    _enum = PlatformEnum.TPU
    device_name: str = "tpu"
    device_type: str = "tpu"
    dispatch_key: str = "XLA"
    ray_device_key: str = "TPU"
    device_control_env_var: str = "TPU_VISIBLE_CHIPS"
    simple_compile_backend: str = "openxla"

    supported_quantization: list[str] = [
        "tpu_int8", "compressed-tensors", "awq", "fp8"
    ]

    # Extra environment variables declared to vLLM for this platform
    # (presumably propagated to worker processes; confirm against vLLM's
    # Platform.additional_env_vars).
    additional_env_vars: list[str] = [
        "JAX_RANDOM_WEIGHTS", "PHASED_PROFILING_DIR",
        "TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS",
        "TPU_MULTIHOST_BACKEND", "VLLM_MLA_DISABLE", "NEW_MODEL_DESIGN",
        "TPU_BACKEND_TYPE"
    ]

    @classmethod
    def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int,
                             dtype: jnp.dtype, kv_cache_dtype: Optional[str],
                             block_size: int, use_v1: bool, use_mla: bool,
                             has_sink: bool, use_sparse: bool) -> str:
        """Return the import path of the attention backend class.

        Only Pallas backends are returned. A different ``selected_backend``
        is logged and ignored.
        """
        from vllm.attention.backends.registry import _Backend
        if selected_backend != _Backend.PALLAS:
            logger.info("Cannot use %s backend on TPU.", selected_backend)

        if use_v1:
            logger.info("Using Pallas V1 backend.")
            return "tpu_inference.layers.vllm.attention.PallasAttentionBackend"
        else:
            logger.info("Using Pallas backend.")
            return "vllm.attention.backends.pallas.PallasAttentionBackend"

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return a human-readable TPU name, or ``'TPU'`` if lookup fails."""
        try:
            if envs.VLLM_TPU_USING_PATHWAYS:
                # Querying devices here causes multiple processes to access
                # IFRT when calling jax.devices(), so return a fixed name.
                return "TPU v6 lite"
            else:
                chip_type, _ = device.get_local_chips()
                return f"TPU {chip_type.name}"
        except Exception as e:
            # Best effort: a missing name must not break startup.
            logger.warning("Error getting device name: %s", e)
            return 'TPU'

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        raise NotImplementedError

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        return not envs.VLLM_USE_V1

    @classmethod
    def get_punica_wrapper(cls) -> str:
        return "tpu_inference.lora.torch_punica_tpu.PunicaWrapperTPU"

    @classmethod
    def get_infinity_values(cls, dtype: jnp.dtype) -> Tuple[float, float]:
        """Return the (min, max) finite values of ``dtype``."""
        return jnp.finfo(dtype).min, jnp.finfo(dtype).max

    @classmethod
    def can_update_inplace(cls):
        return False

    @classmethod
    def get_lora_vocab_padding_size(cls) -> int:
        return 1

    @classmethod
    def inference_mode(cls):
        return True

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Validate and adjust ``vllm_config`` in place for the JAX backend.

        Specifically, this method:
        - requires V1 and forces DYNAMO_ONCE compilation;
        - normalizes the model dtype and sizes the KV cache block;
        - selects the worker and executor;
        - applies the multimodal and KV-transfer constraints.

        Raises:
            RuntimeError: if ``VLLM_USE_V1`` is not enabled.
        """
        if not envs.VLLM_USE_V1:
            raise RuntimeError("VLLM_USE_V1=1 must be set for JAX backend.")

        if envs.VLLM_TPU_USING_PATHWAYS:
            assert not envs.VLLM_ENABLE_V1_MULTIPROCESSING, (
                "VLLM_ENABLE_V1_MULTIPROCESSING must be 0 when using Pathways(JAX_PLATFORMS=proxy)"
            )

        from vllm.config import CompilationLevel

        cache_config = vllm_config.cache_config
        # For v0, the default block size is 16.
        if cache_config and cache_config.block_size is None:
            cache_config.block_size = cast(BlockSize, 16)
        compilation_config = vllm_config.compilation_config

        # TPU only supports DYNAMO_ONCE compilation level
        # NOTE(xiang): the compilation_config is not used by jax.
        if compilation_config.level != CompilationLevel.DYNAMO_ONCE:
            compilation_config.level = CompilationLevel.DYNAMO_ONCE

        if compilation_config.backend == "":
            compilation_config.backend = "openxla"

        # If we use vLLM's model implementation in PyTorch, we should set it
        # with the torch version of the dtype.
        impl = os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower()

        # NOTE(xiang): convert dtype to jnp.dtype
        # NOTE(wenlong): skip this logic for mm model preprocessing
        # For mm model preprocessors, it may need the output dtype to be torch.
        # In order to avoid a PR to vLLM, we postpone the dtype checking
        # during tpu_worker initialization.
        if not vllm_config.scheduler_config.is_multimodal_model or impl == "vllm":
            if not isinstance(vllm_config.model_config.dtype, str):
                logger.warning(
                    "The model dtype is not properly set for JAX backend. "
                    "Overwriting it to jnp.bfloat16")
                vllm_config.model_config.dtype = jnp.bfloat16
            else:
                vllm_config.model_config.dtype = _DTYPE.get(
                    vllm_config.model_config.dtype, jnp.bfloat16)

        if impl == "vllm":
            vllm_config.model_config.dtype = j2t_dtype(
                vllm_config.model_config.dtype.dtype)

        if envs.VLLM_USE_V1:
            # TODO(cuiq): remove this dependency.
            from vllm.v1.attention.backends.pallas import \
                PallasAttentionBackend
            cache_config.block_size = PallasAttentionBackend.get_page_size(
                vllm_config)  # type: ignore[assignment]
            min_page_size = PallasAttentionBackend.get_min_page_size(
                vllm_config)
            if min_page_size > cache_config.block_size:
                logger.warning(
                    "Increase the page size from %s to %s to make sure there's "
                    "no SMEM OOM",
                    cache_config.block_size,
                    min_page_size,
                )
                cache_config.block_size = min_page_size  # type: ignore[assignment]

        parallel_config = vllm_config.parallel_config
        scheduler_config = vllm_config.scheduler_config
        parallel_config.worker_cls = \
            "tpu_inference.worker.tpu_worker_jax.TPUWorker"

        multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
        if not multihost_backend:  # Single host
            logger.info("Force using UniProcExecutor for JAX on single host.")
            parallel_config.distributed_executor_backend = "uni"
        elif multihost_backend == "ray":
            from tpu_inference.executors.ray_distributed_executor import \
                RayDistributedExecutor
            parallel_config.distributed_executor_backend = RayDistributedExecutor
            logger.info(
                "Force using RayDistributedExecutor for JAX on multi-host.")
        else:
            logger.warning(
                "Unknown TPU multihost backend: %s. Using uniproc_executor.",
                multihost_backend)
            parallel_config.distributed_executor_backend = "uni"

        if scheduler_config.is_multimodal_model and not \
                scheduler_config.disable_chunked_mm_input:
            logger.warning("TPU does not support running Multimodal models"\
            " without setting `--disable_chunked_mm_input`. " \
            "Forcing --disable_chunked_mm_input.")
            scheduler_config.disable_chunked_mm_input = True

        kv_transfer_config = vllm_config.kv_transfer_config
        if kv_transfer_config is not None:
            assert kv_transfer_config.kv_connector == "TPUConnector", (
                "Only TPUConnector is supported for KV transfer on TPU, got "
                f"{kv_transfer_config.kv_connector!r}")

        update_vllm_config_for_qwix_quantization(vllm_config)

    @classmethod
    def is_pin_memory_available(cls):
        logger.warning("Pin memory is not supported on TPU.")
        return False

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        return "vllm.distributed.device_communicators.tpu_communicator.TpuCommunicator"  # noqa

    @classmethod
    def use_all_gather(cls) -> bool:
        return True

    @classmethod
    def supports_v1(cls, model_config: ModelConfig) -> bool:
        # V1 support on TPU is experimental
        return True

    @classmethod
    def validate_request(
        cls,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        processed_inputs: ProcessorInputs,
    ) -> None:
        """Raises if this request is unsupported on this platform"""

        if isinstance(params, SamplingParams):
            if params.structured_outputs is not None and not envs.VLLM_USE_V1:
                raise ValueError("Structured output is not supported on "
                                 f"{cls.device_name} V0.")
            if params.sampling_type == SamplingType.RANDOM_SEED:
                raise ValueError("JAX does not support per-request seed.")

    @classmethod
    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str,
                                    model_config: ModelConfig) -> bool:
        return True

    @classmethod
    def use_sync_weight_loader(cls) -> bool:
        """
        Returns if the current platform needs to sync weight loader.
        """
        return True
|
|
File without changes
|