tpu_inference-0.11.1.dev202511150811-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- tests/__init__.py +0 -0
- tests/core/__init__.py +0 -0
- tests/core/test_core_tpu.py +513 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +53 -0
- tests/core/test_dp_scheduler.py +899 -0
- tests/core/test_init.py +49 -0
- tests/kernels/__init__.py +0 -0
- tests/kernels/fused_moe_v1_test.py +105 -0
- tests/kernels/mla_v1_test.py +396 -0
- tests/kernels/quantized_matmul_kernel_test.py +191 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
- tests/lora/__init__.py +0 -0
- tests/lora/conftest.py +32 -0
- tests/lora/test_bgmv.py +43 -0
- tests/lora/test_layers.py +654 -0
- tests/lora/test_lora.py +133 -0
- tests/lora/utils.py +96 -0
- tests/test_base.py +201 -0
- tests/test_envs.py +182 -0
- tests/test_quantization.py +836 -0
- tests/test_tpu_info.py +120 -0
- tests/test_utils.py +236 -0
- tpu_inference/__init__.py +34 -0
- tpu_inference/core/__init__.py +0 -0
- tpu_inference/core/core_tpu.py +786 -0
- tpu_inference/core/disagg_executor.py +118 -0
- tpu_inference/core/disagg_utils.py +51 -0
- tpu_inference/core/sched/__init__.py +0 -0
- tpu_inference/core/sched/dp_scheduler.py +523 -0
- tpu_inference/distributed/__init__.py +0 -0
- tpu_inference/distributed/jax_parallel_state.py +67 -0
- tpu_inference/distributed/tpu_connector.py +728 -0
- tpu_inference/distributed/utils.py +59 -0
- tpu_inference/env_override.py +9 -0
- tpu_inference/envs.py +107 -0
- tpu_inference/executors/__init__.py +0 -0
- tpu_inference/executors/ray_distributed_executor.py +362 -0
- tpu_inference/experimental/__init__.py +0 -0
- tpu_inference/experimental/llama3_jax_stashed.py +258 -0
- tpu_inference/kernels/__init__.py +0 -0
- tpu_inference/kernels/collectives/__init__.py +0 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +0 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/fused_moe/__init__.py +0 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +1035 -0
- tpu_inference/kernels/mla/__init__.py +0 -0
- tpu_inference/kernels/mla/v1/__init__.py +0 -0
- tpu_inference/kernels/mla/v1/kernel.py +1349 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1478 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
- tpu_inference/layers/__init__.py +0 -0
- tpu_inference/layers/common/__init__.py +0 -0
- tpu_inference/layers/common/attention_interface.py +390 -0
- tpu_inference/layers/common/attention_metadata.py +34 -0
- tpu_inference/layers/common/binary_search.py +295 -0
- tpu_inference/layers/common/quant_methods.py +8 -0
- tpu_inference/layers/common/sharding.py +582 -0
- tpu_inference/layers/jax/__init__.py +0 -0
- tpu_inference/layers/jax/attention/__init__.py +0 -0
- tpu_inference/layers/jax/attention/attention.py +255 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
- tpu_inference/layers/jax/base.py +151 -0
- tpu_inference/layers/jax/constants.py +88 -0
- tpu_inference/layers/jax/layers.py +301 -0
- tpu_inference/layers/jax/misc.py +16 -0
- tpu_inference/layers/jax/moe/__init__.py +0 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
- tpu_inference/layers/jax/moe/moe.py +209 -0
- tpu_inference/layers/jax/rope.py +280 -0
- tpu_inference/layers/jax/rope_interface.py +214 -0
- tpu_inference/layers/jax/sample/__init__.py +0 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
- tpu_inference/layers/jax/sample/sampling.py +96 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
- tpu_inference/layers/jax/transformer_block.py +107 -0
- tpu_inference/layers/vllm/__init__.py +0 -0
- tpu_inference/layers/vllm/attention.py +221 -0
- tpu_inference/layers/vllm/fused_moe.py +507 -0
- tpu_inference/layers/vllm/linear_common.py +186 -0
- tpu_inference/layers/vllm/quantization/__init__.py +39 -0
- tpu_inference/layers/vllm/quantization/awq.py +207 -0
- tpu_inference/layers/vllm/quantization/common.py +105 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +386 -0
- tpu_inference/layers/vllm/sharding.py +230 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +0 -0
- tpu_inference/lora/torch_lora_ops.py +103 -0
- tpu_inference/lora/torch_punica_tpu.py +311 -0
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1219 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/__init__.py +0 -0
- tpu_inference/models/common/__init__.py +0 -0
- tpu_inference/models/common/model_loader.py +444 -0
- tpu_inference/models/jax/__init__.py +0 -0
- tpu_inference/models/jax/deepseek_v3.py +868 -0
- tpu_inference/models/jax/gpt_oss.py +492 -0
- tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
- tpu_inference/models/jax/llama3.py +375 -0
- tpu_inference/models/jax/llama4.py +629 -0
- tpu_inference/models/jax/llama_eagle3.py +333 -0
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +375 -0
- tpu_inference/models/jax/qwen2_5_vl.py +1103 -0
- tpu_inference/models/jax/qwen3.py +302 -0
- tpu_inference/models/jax/utils/__init__.py +0 -0
- tpu_inference/models/jax/utils/file_utils.py +96 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +653 -0
- tpu_inference/models/jax/utils/weight_utils.py +529 -0
- tpu_inference/models/vllm/__init__.py +0 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +286 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
- tpu_inference/platforms/__init__.py +2 -0
- tpu_inference/platforms/tpu_platform.py +269 -0
- tpu_inference/runner/__init__.py +0 -0
- tpu_inference/runner/block_table.py +122 -0
- tpu_inference/runner/compilation_manager.py +780 -0
- tpu_inference/runner/input_batch.py +435 -0
- tpu_inference/runner/kv_cache.py +132 -0
- tpu_inference/runner/kv_cache_manager.py +479 -0
- tpu_inference/runner/lora_utils.py +92 -0
- tpu_inference/runner/multimodal_manager.py +217 -0
- tpu_inference/runner/persistent_batch_manager.py +244 -0
- tpu_inference/runner/speculative_decoding_manager.py +248 -0
- tpu_inference/runner/structured_decoding_manager.py +88 -0
- tpu_inference/runner/tpu_runner.py +1620 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +0 -0
- tpu_inference/spec_decode/jax/__init__.py +0 -0
- tpu_inference/spec_decode/jax/eagle3.py +367 -0
- tpu_inference/tpu_info.py +77 -0
- tpu_inference/utils.py +317 -0
- tpu_inference/worker/__init__.py +0 -0
- tpu_inference/worker/tpu_worker.py +321 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/METADATA +107 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/RECORD +179 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/WHEEL +5 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/top_level.txt +2 -0
tpu_inference/runner/kv_cache_manager.py
@@ -0,0 +1,479 @@
import functools
import math
from typing import TYPE_CHECKING, Dict, List

import jax
import jax.numpy as jnp
import vllm.envs as envs
from jax.sharding import NamedSharding, PartitionSpec
from torchax.ops.mappings import t2j_dtype
from vllm.attention import Attention
from vllm.attention.backends.abstract import AttentionType
from vllm.config import get_layers_from_vllm_config
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                        KVCacheSpec, MLAAttentionSpec,
                                        SlidingWindowSpec)

from tpu_inference import utils
from tpu_inference import utils as common_utils
from tpu_inference.logger import init_logger
from tpu_inference.runner import utils as runner_utils
from tpu_inference.runner.input_batch import CachedRequestState, InputBatch
from tpu_inference.runner.kv_cache import create_kv_caches

if TYPE_CHECKING:
    from vllm.v1.request import Request

    from tpu_inference.runner.tpu_runner import TPUModelRunner

logger = init_logger(__name__)


class KVCacheManager:

    def __init__(self, runner: "TPUModelRunner"):
        self.runner = runner
        # Layer pairings for cross-layer KV sharing.
        # If an Attention layer `layer_name` is in the keys of this dict, it
        # means this layer will perform attention using the keys and values
        # from the KV cache of `shared_kv_cache_layers[layer_name]`.
        self.shared_kv_cache_layers: dict[str, str] = {}

    def get_kv_cache_spec(self):
        # TODO(xiang): this hack tricks engine core to init successfully
        block_size = self.runner.cache_config.block_size
        use_mla = self.runner.model_config.use_mla
        kv_cache_spec: dict[str, KVCacheSpec] = {}

        # If use pure jax (MODEL_IMPL_TYPE=flax_nnx), we don't register
        # attention into compilation config.
        # Use FullAttentionSpec for each layer
        # TODO(pooyam): Is it possible to merge the logic for vllm and non-vllm models?
        if len(self.runner.vllm_config.compilation_config.
               static_forward_context) == 0:
            model_config = self.runner.model_config
            parallel_config = self.runner.parallel_config
            # Pad num_kv_heads to multiple of TP size.
            num_kv_heads = common_utils.get_padded_num_heads(
                model_config.get_total_num_kv_heads(),
                self.runner.mesh.shape["model"])
            head_size = common_utils.get_padded_head_dim(
                model_config.get_head_size())
            for i in range(model_config.get_num_layers(parallel_config)):
                if use_mla:
                    kv_cache_spec[f"layer.{i}"] = MLAAttentionSpec(
                        block_size=block_size,
                        num_kv_heads=num_kv_heads,
                        head_size=head_size,
                        dtype=self.runner.kv_cache_dtype,
                        cache_dtype_str=self.runner.vllm_config.cache_config.
                        cache_dtype)
                else:
                    kv_cache_spec[f"layer.{i}"] = FullAttentionSpec(
                        block_size=block_size,
                        num_kv_heads=num_kv_heads,
                        head_size=head_size,
                        dtype=self.runner.kv_cache_dtype)
            if self.runner.speculative_config and self.runner.speculative_config.method == "eagle3":
                draft_model_config = self.runner.speculative_config.draft_model_config
                hf_config = draft_model_config.hf_config
                num_kv_heads = common_utils.get_padded_num_heads(
                    hf_config.num_key_value_heads,
                    self.runner.mesh.shape["model"])
                head_size = common_utils.get_padded_head_dim(
                    hf_config.hidden_size // hf_config.num_attention_heads)

                # Eagle3 has only 1 layer
                for i in range(1):
                    if use_mla:
                        kv_cache_spec[f"layer.{i}"] = MLAAttentionSpec(
                            block_size=block_size,
                            num_kv_heads=num_kv_heads,
                            head_size=head_size,
                            dtype=self.runner.kv_cache_dtype,
                            cache_dtype_str=self.runner.vllm_config.
                            cache_config.cache_dtype)
                    else:
                        kv_cache_spec[f"draft_layer.{i}"] = FullAttentionSpec(
                            block_size=block_size,
                            num_kv_heads=num_kv_heads,
                            head_size=head_size,
                            dtype=self.runner.kv_cache_dtype)
        else:
            # Else propagate attention modules from compilation config.
            layers = get_layers_from_vllm_config(self.runner.vllm_config,
                                                 Attention)
            for layer_name, attn_module in layers.items():
                if (kv_tgt_layer :=
                        attn_module.kv_sharing_target_layer_name) is not None:
                    # The layer doesn't need its own KV cache and will use that of
                    # the target layer. We skip creating a KVCacheSpec for it, so
                    # that KV cache management logic will act as this layer does
                    # not exist, and doesn't allocate KV cache for the layer. This
                    # enables the memory saving of cross-layer kv sharing, allowing
                    # a given amount of memory to accommodate longer context lengths
                    # or enable more requests to be processed simultaneously.
                    self.shared_kv_cache_layers[layer_name] = kv_tgt_layer
                    continue
                if attn_module.attn_type == AttentionType.DECODER:
                    if attn_module.sliding_window is not None:
                        kv_cache_spec[layer_name] = SlidingWindowSpec(
                            block_size=block_size,
                            num_kv_heads=common_utils.get_padded_num_heads(
                                attn_module.num_kv_heads,
                                self.runner.mesh.shape["model"]),
                            head_size=common_utils.get_padded_head_dim(
                                attn_module.head_size),
                            dtype=self.runner.kv_cache_dtype,
                            sliding_window=attn_module.sliding_window)
                    elif use_mla:
                        kv_cache_spec[layer_name] = MLAAttentionSpec(
                            block_size=block_size,
                            num_kv_heads=attn_module.num_kv_heads,
                            head_size=attn_module.head_size,
                            dtype=self.runner.kv_cache_dtype,
                            cache_dtype_str=self.runner.vllm_config.
                            cache_config.cache_dtype)
                    else:
                        kv_cache_spec[layer_name] = FullAttentionSpec(
                            block_size=block_size,
                            num_kv_heads=common_utils.get_padded_num_heads(
                                attn_module.num_kv_heads,
                                self.runner.mesh.shape["model"]),
                            head_size=common_utils.get_padded_head_dim(
                                attn_module.head_size),
                            dtype=self.runner.kv_cache_dtype)
                elif attn_module.attn_type in (AttentionType.ENCODER,
                                               AttentionType.ENCODER_ONLY):
                    # encoder-only attention does not need KV cache.
                    continue
                elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
                    raise NotImplementedError
                else:
                    raise ValueError(
                        f"Unknown attention type: {attn_module.attn_type}")
        return kv_cache_spec

    def maybe_reinitialize_input_batch(self,
                                       kv_cache_config: KVCacheConfig) -> None:
        block_sizes = [
            kv_cache_group.kv_cache_spec.block_size
            for kv_cache_group in kv_cache_config.kv_cache_groups
        ]
        if block_sizes != [self.runner.cache_config.block_size]:
            assert self.runner.cache_config.cpu_offload_gb == 0, (
                "Cannot re-initialize the input batch when CPU weight "
                "offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 "  # noqa: E501
                "for more details.")
            new_input_batch = InputBatch(
                max_num_reqs=self.runner.max_num_reqs,
                max_model_len=self.runner.max_model_len,
                max_num_batched_tokens=self.runner.max_num_tokens,
                pin_memory=False,
                vocab_size=self.runner.model_config.get_vocab_size(),
                block_sizes=block_sizes,
            )
            self.runner.input_batch = new_input_batch
            self.runner.persistent_batch_manager.input_batch = new_input_batch

    def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
        self.maybe_reinitialize_input_batch(kv_cache_config)

        # uniform page size.
        representative_spec = kv_cache_config.kv_cache_groups[0].kv_cache_spec
        page_size_bytes = representative_spec.page_size_bytes
        self.runner.layer_name_to_kvcache_index: Dict[str, int] = {}
        kv_caches = self.runner.kv_caches
        num_blocks_list = []
        for i, kv_cache_tensor in enumerate(kv_cache_config.kv_cache_tensors):
            assert kv_cache_tensor.size % page_size_bytes == 0
            num_blocks = kv_cache_tensor.size // page_size_bytes
            dp_size = self.runner.vllm_config.sharding_config.total_dp_size
            # num_blocks must be a multiple of dp_size
            num_blocks = math.ceil(num_blocks / dp_size) * dp_size
            # NOTE: we'll multiply the num_kv_heads by 2 in the function
            kv_cache = create_kv_caches(
                num_blocks=num_blocks,
                block_size=representative_spec.block_size,
                num_kv_heads=representative_spec.num_kv_heads,
                head_size=representative_spec.head_size,
                mesh=self.runner.mesh,
                layer_names=[f'kv_cache_tensor.{i}'],
                cache_dtype=t2j_dtype(representative_spec.dtype),
            )[0]
            kv_caches.append(kv_cache)
            num_blocks_list.append(num_blocks)
            for layer_name in kv_cache_tensor.shared_by:
                self.runner.layer_name_to_kvcache_index[layer_name] = i

        if self.shared_kv_cache_layers:
            for layer_name, target_layer_name in self.shared_kv_cache_layers.items(
            ):
                self.runner.layer_name_to_kvcache_index[
                    layer_name] = self.runner.layer_name_to_kvcache_index[
                        target_layer_name]

        logger.info(
            f"Init kv-cache | "
            f"num_layers={len(kv_caches)} | "
            f"shape=(num_blocks, {kv_caches[0].shape[1:]}) | "
            f"num_blocks={num_blocks_list} | "
            f"sharding={kv_caches[0].sharding} | "
            f"dtype={kv_caches[0].dtype} | "
            f"hbm={utils.hbm_usage_gb(self.runner.mesh.devices.flatten())}Gb")

    @staticmethod
    @functools.partial(jax.jit)
    def _jitted_gather_kv_cache(kv_caches: List[jax.Array],
                                block_ids: jax.Array) -> List[jax.Array]:
        """
        JIT-compiled function to gather KV cache slices for all layers at once.
        This uses jax.tree.map to apply the operation across all layers.
        """

        def gather_and_reshape(layer_kv_cache):
            return layer_kv_cache.at[block_ids].get().reshape(
                -1, *layer_kv_cache.shape[2:])

        return jax.tree.map(gather_and_reshape, kv_caches)

    @staticmethod
    @functools.partial(
        jax.jit,
        static_argnames=("len_block"),
    )
    def _jitted_gather_continuous_kv_cache(kv_caches: List[jax.Array],
                                           start_block,
                                           len_block) -> List[jax.Array]:
        """
        JIT-compiled function to gather KV cache slices for all layers at once.
        This uses jax.tree.map to apply the operation across all layers.
        """

        def gather_and_reshape(layer_kv_cache):
            shape = layer_kv_cache.shape
            return jax.lax.dynamic_slice_in_dim(layer_kv_cache,
                                                start_block,
                                                len_block,
                                                axis=0).reshape(
                                                    -1, *shape[2:])

        return jax.tree.map(gather_and_reshape, kv_caches)

    @staticmethod
    @functools.partial(
        jax.jit,
        static_argnames=("block_size"),
        donate_argnames=(
            "kv_caches",
            "kv_cache_slices",
        ),
    )
    def _jitted_insert_kv_cache(
        block_size,
        kv_caches: List[jax.Array],
        kv_cache_slices: List[jax.Array],
        block_numbers: jax.Array,
    ) -> List[jax.Array]:
        """
        JIT-compiled function to insert KV cache slices into the physical
        cache for all layers at once. This fuses the pad, reshape, and scatter
        operations into a single efficient kernel.
        """

        def _update_layer(cache, slices):
            """The function to apply to each layer's cache and slices."""
            reshaped_slices = slices.reshape(-1, 1, block_size,
                                             *slices.shape[1:])
            for (i, block_idx) in enumerate(block_numbers):
                cache = jax.lax.dynamic_update_slice_in_dim(cache,
                                                            reshaped_slices[i],
                                                            block_idx,
                                                            axis=0)
            return cache

        return jax.tree.map(_update_layer, kv_caches, kv_cache_slices)

    @staticmethod
    @functools.partial(
        jax.jit,
        static_argnames=("block_size"),
        donate_argnames=(
            "kv_caches",
            "kv_cache_slices",
        ),
    )
    def _jitted_insert_continuous_kv_cache(
        block_size,
        kv_caches: List[jax.Array],
        kv_cache_slices: List[jax.Array],
        start_block,
    ) -> List[jax.Array]:
        """
        JIT-compiled function to insert KV cache slices into continuous blocks.
        Makes use of dynamic_update_slice_in_dim.
        """

        def _update_layer(cache, slices):
            """The function to apply to each layer's cache and slices."""
            reshaped_slices = slices.reshape(-1, block_size, *slices.shape[1:])

            return jax.lax.dynamic_update_slice_in_dim(cache,
                                                       reshaped_slices,
                                                       start_block,
                                                       axis=0)

        return jax.tree.map(_update_layer, kv_caches, kv_cache_slices)

    def get_kv_cache_for_block_ids(
        self,
        block_ids: List[int],
    ) -> List[jax.Array]:
        """
        Extracts the KV cache slices for a given list of block IDs.
        This assumes all provided blocks are full.

        Args:
            block_ids: A list of block IDs to extract KV cache for.

        Returns:
            A list of JAX arrays, with each array representing the KV cache
            slices for a layer, concatenated for all blocks.
        """
        if block_ids == list(range(block_ids[0],
                                   block_ids[0] + len(block_ids))):
            with runner_utils.LatencyTracker(
                    "BatchedGatherKVSlices-for-blocks"):
                batched_kv_cache_per_layer = self._jitted_gather_continuous_kv_cache(
                    self.runner.kv_caches, block_ids[0], len(block_ids))

        else:
            with runner_utils.LatencyTracker(
                    "BatchedGatherKVSlices-for-blocks"):
                batched_kv_cache_per_layer = self._jitted_gather_kv_cache(
                    self.runner.kv_caches, jnp.array(block_ids))
        return batched_kv_cache_per_layer

    def transfer_kv_cache(self,
                          kv_cache_slices: List[jax.Array]) -> List[jax.Array]:
        """
        Transfers KV cache slices to the runner's mesh.

        This is used when a KV cache generated on one runner (e.g., a prefill
        runner) needs to be used on another runner (e.g., a decode runner)
        with a different device mesh. The transfer is asynchronous.

        Args:
            kv_cache_slices: A list of JAX arrays, where each array contains
                the KV cache slices for a specific layer. The shape of each
                slice is expected to be (num_tokens, num_kv_heads * 2, head_size).

        Returns:
            A new list of JAX arrays representing the KV cache slices, sharded
            across the runner's device mesh.
        """
        # The KV cache slices have a shape of (num_tokens, num_kv_heads * 2, head_size).
        # We shard along the num_kv_heads dimension (axis=1), which corresponds
        # to the "model" axis of the mesh for tensor parallelism.
        logger.debug(
            f"Transferring kv cache shape {len(kv_cache_slices)} * {kv_cache_slices[0].shape} sharding {kv_cache_slices[0].sharding} size {kv_cache_slices[0].nbytes * len(kv_cache_slices)/1024/1024} Mbytes"
        )
        sharding = NamedSharding(self.runner.mesh,
                                 PartitionSpec(None, "model"))
        if envs.VLLM_TPU_USING_PATHWAYS:
            from pathwaysutils.experimental import \
                reshard as experimental_reshard

            def get_sharding(x):
                return sharding

            sharding_spec_pytree = jax.tree.map(get_sharding, kv_cache_slices)
            transferred_kv_cache = experimental_reshard.reshard(
                kv_cache_slices,
                sharding_spec_pytree,
                donate=False,
            )
        else:
            transferred_kv_cache = jax.device_put(kv_cache_slices, sharding)

        jax.block_until_ready(transferred_kv_cache)
        return transferred_kv_cache

    def insert_request_with_kv_cache(
        self,
        request: "Request",
        kv_cache_slices: List[jax.Array],
        block_ids: List[List[int]],
    ):
        """
        Inserts a request and its KV cache into the runner. This is used to
        transfer a request from a prefill runner to a decode runner.

        The provided KV cache slices are copied into the physical blocks
        allocated for the request. The runner's internal state is then updated
        to include the request.

        Args:
            request: The vLLM request object, containing the state after prefill.
            kv_cache_slices: The KV cache for the request, already transferred
                to this runner's mesh. This is a list of JAX arrays, one per layer.
            block_ids: The physical block numbers allocated for this request on
                this runner. This is a list of lists, for each KV cache group.
        """
        # Assume one KV cache group for now, which is consistent with current setup.
        if len(block_ids) > 1:
            raise NotImplementedError(
                "Inserting KV cache for models with multiple KV cache groups "
                "is not supported yet.")
        block_numbers = block_ids[0]
        if block_numbers == list(
                range(block_numbers[0],
                      block_numbers[0] + len(block_numbers))):
            # For continuous blocks we use slice instead of scatter.
            start_block = block_numbers[0]
            with runner_utils.LatencyTracker(
                    f"JittedInsertContinuousKVCache-b{len(block_numbers)}"):
                logger.debug(f"inserting to continuous blocks {block_numbers}")
                self.runner.kv_caches = KVCacheManager._jitted_insert_continuous_kv_cache(
                    self.runner.block_size,
                    self.runner.kv_caches,
                    kv_cache_slices,
                    start_block,
                )
        else:
            with runner_utils.LatencyTracker(
                    f"JittedInsertKVCache-b{len(block_numbers)}"):
                logger.debug(
                    f"inserting to non continuous blocks {block_numbers}")
                self.runner.kv_caches = KVCacheManager._jitted_insert_kv_cache(
                    self.runner.block_size,
                    self.runner.kv_caches,
                    kv_cache_slices,
                    jnp.array(block_numbers),
                )

        logger.debug(
            f"Updated kv cache entries cnt={len(self.runner.kv_caches)}")

        # Update runner's internal state to track the new request.
        req_id = request.request_id
        if req_id in self.runner.requests:
            logger.warning(
                f"Request {req_id} already exists in the runner. Overwriting.")

        # Create a CachedRequestState object to add to the input batch.
        req_state = CachedRequestState(
            req_id=request.request_id,
            prompt_token_ids=request.prompt_token_ids,
            output_token_ids=[request.all_token_ids[-1]],
            sampling_params=request.sampling_params,
            block_ids=tuple(block_ids),
            num_computed_tokens=request.num_computed_tokens,
            lora_request=request.lora_request,
            mm_features=getattr(request, "mm_features", []),
            pooling_params=getattr(request, "pooling_params", None),
            generator=None,
        )

        self.runner.requests[req_id] = req_state
        self.runner.input_batch.add_request(req_state)
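A note on the two gather paths above: get_kv_cache_for_block_ids checks whether the requested block IDs form one contiguous run and, if so, uses a single dynamic slice instead of a scatter-style gather, applying the operation to every layer's cache via jax.tree.map. The sketch below is illustrative only and is not code from this package; the layer count and the assumed (num_blocks, block_size, num_kv_heads * 2, head_size) cache layout are made-up stand-ins.

# Minimal, self-contained sketch of the contiguous-vs-scattered gather pattern.
# All shapes and names here are hypothetical, not tpu_inference APIs.
import jax
import jax.numpy as jnp

num_blocks, block_size = 8, 4
kv_caches = [
    jnp.arange(num_blocks * block_size * 2 * 3,
               dtype=jnp.float32).reshape(num_blocks, block_size, 2, 3)
    for _ in range(2)  # pretend there are two layers
]

def gather_scattered(caches, block_ids):
    # Gather arbitrary blocks, then fold (blocks, block_size) into one token axis.
    return jax.tree.map(lambda c: c[block_ids].reshape(-1, *c.shape[2:]), caches)

def gather_contiguous(caches, start_block, len_block):
    # Fast path for a consecutive run of blocks: a single dynamic slice.
    return jax.tree.map(
        lambda c: jax.lax.dynamic_slice_in_dim(
            c, start_block, len_block, axis=0).reshape(-1, *c.shape[2:]),
        caches)

block_ids = [2, 3, 4]
if block_ids == list(range(block_ids[0], block_ids[0] + len(block_ids))):
    slices = gather_contiguous(kv_caches, block_ids[0], len(block_ids))
else:
    slices = gather_scattered(kv_caches, jnp.array(block_ids))
print([s.shape for s in slices])  # [(12, 2, 3), (12, 2, 3)] == (tokens, heads*2, head_size)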
tpu_inference/runner/lora_utils.py
@@ -0,0 +1,92 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np
from torchax.interop import jax_view
from vllm.lora.layers.base_linear import BaseLinearLayerWithLoRA
from vllm.lora.request import LoRARequest

from tpu_inference.layers.vllm.sharding import update_lora

if TYPE_CHECKING:
    from tpu_inference.runner.tpu_runner import TPUModelRunner


class LoraUtils:

    def __init__(self, runner: "TPUModelRunner"):
        self.runner = runner

    def set_active_loras(self, num_scheduled_tokens_per_req,
                         total_num_scheduled_tokens,
                         padded_total_num_scheduled_tokens):
        # We need to respect padding when activating LoRA adapters.
        padded_num_scheduled_tokens_per_req = np.copy(
            num_scheduled_tokens_per_req
        )  # Copying to avoid accidental state corruption bugs
        padded_num_scheduled_tokens_per_req[-1] += \
            padded_total_num_scheduled_tokens - total_num_scheduled_tokens

        prompt_lora_mapping: tuple[int, ...]  # of size input_batch.num_reqs
        token_lora_mapping: tuple[int,
                                  ...]  # of size np.sum(num_scheduled_tokens)
        lora_requests: set[LoRARequest]
        prompt_lora_mapping, token_lora_mapping, lora_requests = \
            self.runner.input_batch.make_lora_inputs(padded_num_scheduled_tokens_per_req)
        # One should not put lora_manager.set_active_loras under
        # torchax.default_env() because set_active_loras also loads LoRA from
        # disk, and torchax currently does not support that. Here we load the
        # LoRA and set the LoRA weights on the linear layers.
        self.runner._set_active_loras(prompt_lora_mapping, token_lora_mapping,
                                      lora_requests)

        params_and_buffers = update_lora(
            self.runner.model.model, initial_params_buffers=self.runner.state)
        self.runner.state = jax_view(params_and_buffers)

    def extract_lora_metadata(self):
        if self.runner.lora_config is None:
            return None

        metadata = {}
        punica_wrapper = None
        for _, m in self.runner.model.model.named_modules():
            if isinstance(m, BaseLinearLayerWithLoRA):
                assert getattr(
                    m, 'punica_wrapper', None
                ) is not None, 'A LoRA wrapper should contain a punica_wrapper'
                punica_wrapper = m.punica_wrapper
                break
        assert punica_wrapper is not None, 'Should have been able to find a punica_wrapper from the LoRA wrapper.'

        # vars does not show inherited methods or class attributes, but this is
        # fine because we only care about instance attributes.
        for k in vars(punica_wrapper):
            v = getattr(punica_wrapper, k, None)
            if k == 'device':  # Exclude string as it can't be traced by jax.jit
                continue
            metadata[k] = v
        return jax_view(metadata)


def replace_lora_metadata(model, metadata: dict, lora_config) -> dict:
    if lora_config is None or not metadata:
        return {}

    original_metadata = {}
    punica_wrapper = None
    for _, m in model.named_modules():
        if isinstance(m, BaseLinearLayerWithLoRA):
            assert getattr(
                m, 'punica_wrapper', None
            ) is not None, 'A LoRA wrapper should contain a punica_wrapper'
            punica_wrapper = m.punica_wrapper
            break
    assert punica_wrapper is not None, 'Should have been able to find a punica_wrapper from the LoRA wrapper.'

    for k in vars(punica_wrapper):
        if k in metadata:
            original_metadata[k] = getattr(punica_wrapper, k)
            setattr(punica_wrapper, k, metadata[k])
    return original_metadata
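replace_lora_metadata returns the attribute values it overwrote, so a caller can swap punica metadata onto a model's wrapper for a traced computation and then call it again with the returned dict to restore the originals. Below is a minimal, self-contained sketch of that swap/restore contract on a toy object; ToyPunicaWrapper and swap_attrs are hypothetical stand-ins, not tpu_inference or vLLM APIs.

# Illustrative only: the swap/restore pattern that replace_lora_metadata follows.
class ToyPunicaWrapper:
    def __init__(self):
        self.token_lora_indices = [0, 0, 1]
        self.device = "cpu"  # excluded from metadata in the real code

def swap_attrs(obj, new_values: dict) -> dict:
    """Set attributes from new_values on obj and return the originals,
    so the caller can restore them afterwards."""
    originals = {}
    for k in vars(obj):
        if k in new_values:
            originals[k] = getattr(obj, k)
            setattr(obj, k, new_values[k])
    return originals

wrapper = ToyPunicaWrapper()
saved = swap_attrs(wrapper, {"token_lora_indices": [2, 2, 2]})
# ... run the jitted computation with the swapped-in metadata ...
swap_attrs(wrapper, saved)  # restore the original attributes
assert wrapper.token_lora_indices == [0, 0, 1]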