tpu-inference 0.11.1.dev202511150811__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/__init__.py +0 -0
- tests/core/__init__.py +0 -0
- tests/core/test_core_tpu.py +513 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +53 -0
- tests/core/test_dp_scheduler.py +899 -0
- tests/core/test_init.py +49 -0
- tests/kernels/__init__.py +0 -0
- tests/kernels/fused_moe_v1_test.py +105 -0
- tests/kernels/mla_v1_test.py +396 -0
- tests/kernels/quantized_matmul_kernel_test.py +191 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
- tests/lora/__init__.py +0 -0
- tests/lora/conftest.py +32 -0
- tests/lora/test_bgmv.py +43 -0
- tests/lora/test_layers.py +654 -0
- tests/lora/test_lora.py +133 -0
- tests/lora/utils.py +96 -0
- tests/test_base.py +201 -0
- tests/test_envs.py +182 -0
- tests/test_quantization.py +836 -0
- tests/test_tpu_info.py +120 -0
- tests/test_utils.py +236 -0
- tpu_inference/__init__.py +34 -0
- tpu_inference/core/__init__.py +0 -0
- tpu_inference/core/core_tpu.py +786 -0
- tpu_inference/core/disagg_executor.py +118 -0
- tpu_inference/core/disagg_utils.py +51 -0
- tpu_inference/core/sched/__init__.py +0 -0
- tpu_inference/core/sched/dp_scheduler.py +523 -0
- tpu_inference/distributed/__init__.py +0 -0
- tpu_inference/distributed/jax_parallel_state.py +67 -0
- tpu_inference/distributed/tpu_connector.py +728 -0
- tpu_inference/distributed/utils.py +59 -0
- tpu_inference/env_override.py +9 -0
- tpu_inference/envs.py +107 -0
- tpu_inference/executors/__init__.py +0 -0
- tpu_inference/executors/ray_distributed_executor.py +362 -0
- tpu_inference/experimental/__init__.py +0 -0
- tpu_inference/experimental/llama3_jax_stashed.py +258 -0
- tpu_inference/kernels/__init__.py +0 -0
- tpu_inference/kernels/collectives/__init__.py +0 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +0 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/fused_moe/__init__.py +0 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +1035 -0
- tpu_inference/kernels/mla/__init__.py +0 -0
- tpu_inference/kernels/mla/v1/__init__.py +0 -0
- tpu_inference/kernels/mla/v1/kernel.py +1349 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1478 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
- tpu_inference/layers/__init__.py +0 -0
- tpu_inference/layers/common/__init__.py +0 -0
- tpu_inference/layers/common/attention_interface.py +390 -0
- tpu_inference/layers/common/attention_metadata.py +34 -0
- tpu_inference/layers/common/binary_search.py +295 -0
- tpu_inference/layers/common/quant_methods.py +8 -0
- tpu_inference/layers/common/sharding.py +582 -0
- tpu_inference/layers/jax/__init__.py +0 -0
- tpu_inference/layers/jax/attention/__init__.py +0 -0
- tpu_inference/layers/jax/attention/attention.py +255 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
- tpu_inference/layers/jax/base.py +151 -0
- tpu_inference/layers/jax/constants.py +88 -0
- tpu_inference/layers/jax/layers.py +301 -0
- tpu_inference/layers/jax/misc.py +16 -0
- tpu_inference/layers/jax/moe/__init__.py +0 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
- tpu_inference/layers/jax/moe/moe.py +209 -0
- tpu_inference/layers/jax/rope.py +280 -0
- tpu_inference/layers/jax/rope_interface.py +214 -0
- tpu_inference/layers/jax/sample/__init__.py +0 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
- tpu_inference/layers/jax/sample/sampling.py +96 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
- tpu_inference/layers/jax/transformer_block.py +107 -0
- tpu_inference/layers/vllm/__init__.py +0 -0
- tpu_inference/layers/vllm/attention.py +221 -0
- tpu_inference/layers/vllm/fused_moe.py +507 -0
- tpu_inference/layers/vllm/linear_common.py +186 -0
- tpu_inference/layers/vllm/quantization/__init__.py +39 -0
- tpu_inference/layers/vllm/quantization/awq.py +207 -0
- tpu_inference/layers/vllm/quantization/common.py +105 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +386 -0
- tpu_inference/layers/vllm/sharding.py +230 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +0 -0
- tpu_inference/lora/torch_lora_ops.py +103 -0
- tpu_inference/lora/torch_punica_tpu.py +311 -0
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1219 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/__init__.py +0 -0
- tpu_inference/models/common/__init__.py +0 -0
- tpu_inference/models/common/model_loader.py +444 -0
- tpu_inference/models/jax/__init__.py +0 -0
- tpu_inference/models/jax/deepseek_v3.py +868 -0
- tpu_inference/models/jax/gpt_oss.py +492 -0
- tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
- tpu_inference/models/jax/llama3.py +375 -0
- tpu_inference/models/jax/llama4.py +629 -0
- tpu_inference/models/jax/llama_eagle3.py +333 -0
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +375 -0
- tpu_inference/models/jax/qwen2_5_vl.py +1103 -0
- tpu_inference/models/jax/qwen3.py +302 -0
- tpu_inference/models/jax/utils/__init__.py +0 -0
- tpu_inference/models/jax/utils/file_utils.py +96 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +653 -0
- tpu_inference/models/jax/utils/weight_utils.py +529 -0
- tpu_inference/models/vllm/__init__.py +0 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +286 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
- tpu_inference/platforms/__init__.py +2 -0
- tpu_inference/platforms/tpu_platform.py +269 -0
- tpu_inference/runner/__init__.py +0 -0
- tpu_inference/runner/block_table.py +122 -0
- tpu_inference/runner/compilation_manager.py +780 -0
- tpu_inference/runner/input_batch.py +435 -0
- tpu_inference/runner/kv_cache.py +132 -0
- tpu_inference/runner/kv_cache_manager.py +479 -0
- tpu_inference/runner/lora_utils.py +92 -0
- tpu_inference/runner/multimodal_manager.py +217 -0
- tpu_inference/runner/persistent_batch_manager.py +244 -0
- tpu_inference/runner/speculative_decoding_manager.py +248 -0
- tpu_inference/runner/structured_decoding_manager.py +88 -0
- tpu_inference/runner/tpu_runner.py +1620 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +0 -0
- tpu_inference/spec_decode/jax/__init__.py +0 -0
- tpu_inference/spec_decode/jax/eagle3.py +367 -0
- tpu_inference/tpu_info.py +77 -0
- tpu_inference/utils.py +317 -0
- tpu_inference/worker/__init__.py +0 -0
- tpu_inference/worker/tpu_worker.py +321 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/METADATA +107 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/RECORD +179 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/WHEEL +5 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/top_level.txt +2 -0
tpu_inference/layers/jax/transformer_block.py

@@ -0,0 +1,107 @@
+from dataclasses import dataclass
+from typing import Any, Optional, Tuple
+
+# Flax and JAX sharding imports
+import jax
+from flax import nnx
+
+from tpu_inference.layers.jax.attention.attention import (AttentionMetadata,
+                                                          KVCache)
+from tpu_inference.layers.jax.layers import DenseFFW
+from tpu_inference.layers.jax.moe.moe import MoE
+
+
+@dataclass(kw_only=True)
+class TransformerBlock(nnx.Module):
+    """
+    A heavy weight module which serves as the stateful live blocks in serving
+
+    custom_module can be either a dense module (i.e., DenseFFW) or MoE.
+    """
+    pre_attention_norm: nnx.Module
+    pre_mlp_norm: nnx.Module
+    custom_module: Optional[nnx.Module] = None
+    attn: nnx.Module
+    use_attention_rope: bool = True
+    quant: Any | None = None
+
+    def __call__(
+            self, x_TD: jax.Array, is_prefill: bool, kv_cache: KVCache,
+            attention_metadata: AttentionMetadata
+    ) -> Tuple[KVCache, jax.Array]:
+        # Attn Block
+        attn_residual_TD = x_TD
+        x_TD = self.pre_attention_norm(x_TD)
+        new_cache, attn_output_TD = self.attn(x_TD, is_prefill, kv_cache,
+                                              attention_metadata,
+                                              self.use_attention_rope)
+        attn_output_TD += attn_residual_TD
+
+        # FFW Block
+        ffw_residual_TD = attn_output_TD
+        normed_ffw_input_TD = self.pre_mlp_norm(attn_output_TD)
+        logits_TD = self.custom_module(normed_ffw_input_TD)
+        logits_TD += ffw_residual_TD
+        return new_cache, logits_TD
+
+
+@dataclass(kw_only=True)
+class SharedExpertsTransformerBlock(TransformerBlock):
+    """Create a modified TransformerBlock that sums MoE layer output with shared expert output.
+
+    Users can provide the FFW layer in two ways:
+    1. Pass the module (either `MoE` or `DenseFFW`) to the `custom_module`
+       attribute.
+    2. Specify the `moe_ffw` or `dense_ffw` attributes
+       (e.g., for passing quantized modules).
+
+    Attributes:
+        moe_ffw: Optional MoE layer.
+        dense_ffw: Optional DFF layer.
+        shared_experts: Optional shared experts module, used if MoE is enabled.
+
+    If an `MoE` layer is used (from either path), its output is summed
+    with the `shared_experts` module.
+    """
+
+    moe_ffw: Optional[MoE] = None
+    dense_ffw: Optional[DenseFFW] = None
+    shared_experts: Optional[DenseFFW] = None
+
+    def __call__(self, x_TD, is_prefill, kv_cache, attention_metadata):
+        # Attn Block
+        attn_residual_TD = x_TD
+        x_TD = self.pre_attention_norm(x_TD)
+        new_cache, attn_output_TD = self.attn(x_TD, is_prefill, kv_cache,
+                                              attention_metadata,
+                                              self.use_attention_rope)
+        attn_output_TD += attn_residual_TD
+
+        # FFW Block
+        ffw_residual_TD = attn_output_TD
+        normed_ffw_input_TD = self.pre_mlp_norm(attn_output_TD)
+
+        if isinstance(self.custom_module, MoE):
+            moe_layer = self.custom_module
+        else:
+            moe_layer = self.moe_ffw
+
+        if isinstance(self.custom_module, DenseFFW):
+            dense_layer = self.custom_module
+        else:
+            dense_layer = self.dense_ffw
+
+        if moe_layer is not None:
+            logits_TD = moe_layer(normed_ffw_input_TD)
+            # Add the shared expert outputs to the MoE outputs.
+            shared_expert_output_TD = self.shared_experts(normed_ffw_input_TD)
+            logits_TD += shared_expert_output_TD
+        elif dense_layer is not None:
+            logits_TD = dense_layer(normed_ffw_input_TD)
+        else:
+            raise ValueError(
+                "Neither custom_module, moe_ffw nor dense_ffw attribute is set for this SharedExpertsTransformerBlock!"
+            )
+
+        logits_TD += ffw_residual_TD
+        return new_cache, logits_TD
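For orientation, a minimal, hypothetical sketch of how the SharedExpertsTransformerBlock above could be exercised. It assumes the wheel is installed and importable; the norm, attention, and FFW callables below are throwaway stand-ins, not the package's real DenseFFW, MoE, or attention layers (whose constructors are not part of this diff). With custom_module left unset, the block routes through moe_ffw and sums its output with shared_experts, exactly as in the __call__ shown above.

# Illustrative only: stand-in submodules make the moe_ffw + shared_experts
# summation path visible without pulling in the real attention or MoE layers.
import jax.numpy as jnp
from flax import nnx

from tpu_inference.layers.jax.transformer_block import \
    SharedExpertsTransformerBlock


class _Identity(nnx.Module):  # stand-in for a norm layer
    def __call__(self, x):
        return x


class _EchoAttention(nnx.Module):  # stand-in attention: passes x and the cache through
    def __call__(self, x, is_prefill, kv_cache, attention_metadata, use_rope):
        return kv_cache, x


class _ScaleFFW(nnx.Module):  # stand-in for an MoE / shared-experts FFW
    def __call__(self, x):
        return 2.0 * x


block = SharedExpertsTransformerBlock(
    pre_attention_norm=_Identity(),
    pre_mlp_norm=_Identity(),
    attn=_EchoAttention(),
    moe_ffw=_ScaleFFW(),          # routed experts path
    shared_experts=_ScaleFFW(),   # summed with the routed path
)

x_TD = jnp.ones((4, 16))  # [tokens, hidden]
kv_cache, y_TD = block(x_TD, is_prefill=True, kv_cache=None,
                       attention_metadata=None)
assert y_TD.shape == x_TD.shape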
tpu_inference/layers/vllm/__init__.py: file without changes.
tpu_inference/layers/vllm/attention.py

@@ -0,0 +1,221 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import functools
+from typing import Optional, Tuple
+
+import jax
+import jax.numpy as jnp
+import torch
+from jax.sharding import Mesh
+from torchax.interop import jax_view, torch_view
+from torchax.ops.mappings import t2j
+from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
+                                              AttentionLayer, AttentionType)
+
+from tpu_inference import utils
+from tpu_inference.layers.common.attention_interface import attention
+from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.logger import init_logger
+from tpu_inference.models.vllm.vllm_model_wrapper_context import \
+    get_vllm_model_wrapper_context
+
+logger = init_logger(__name__)
+
+
+class PallasAttentionBackend(AttentionBackend):
+
+    @staticmethod
+    def get_name() -> str:
+        return "PALLAS"
+
+    @staticmethod
+    def get_impl_cls() -> type["PallasAttentionBackendImpl"]:
+        return PallasAttentionBackendImpl
+
+
+class PallasAttentionBackendImpl(AttentionImpl):
+
+    def __init__(
+        self,
+        num_heads: int,
+        head_size: int,
+        scale: float,
+        num_kv_heads: int,
+        alibi_slopes: list[float] | None,
+        sliding_window: int | None,
+        kv_cache_dtype: str,
+        logits_soft_cap: float | None = None,
+        attn_type: AttentionType = AttentionType.DECODER,
+        kv_sharing_target_layer_name: str | None = None,
+        sinks: torch.Tensor | None = None,
+    ) -> None:
+        self.num_heads = num_heads
+        self.head_size = head_size
+        self.scale = float(scale)
+        self.num_kv_heads = num_kv_heads
+        self.sliding_window = sliding_window
+        self.logits_soft_cap = logits_soft_cap
+        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
+
+        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
+        if alibi_slopes is not None:
+            raise NotImplementedError("Alibi slopes is not supported.")
+        self.kv_cache_quantized_dtype = None
+        if kv_cache_dtype != "auto":
+            self.kv_cache_quantized_dtype = utils.get_jax_dtype_from_str_dtype(
+                kv_cache_dtype)
+
+        if attn_type != AttentionType.DECODER:
+            raise NotImplementedError("Encoder self-attention and "
+                                      "encoder/decoder cross-attention "
+                                      "are not implemented for "
+                                      "PallasAttentionBackendImpl")
+
+        self.sinks = sinks
+        if self.sinks is not None:
+            assert self.sinks.shape[0] == num_heads, (
+                "Sinks must have the same number of heads as the number of "
+                "heads in the layer")
+
+    def process_weights_after_loading(self, act_dtype: torch.dtype):
+        #TODO (kyuyeunk): Shard the sinks along num_heads dim
+        if self.sinks is not None:
+            sinks = t2j(self.sinks, use_dlpack=False)
+            sinks = torch_view(sinks.astype(jnp.float32))
+            self.sinks = torch.nn.Parameter(sinks, requires_grad=False)
+
+    def forward(
+        self,
+        layer: AttentionLayer,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+        output: Optional[torch.Tensor] = None,
+        output_scale: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if output_scale is not None:
+            raise NotImplementedError(
+                "fused output quantization is not yet supported for "
+                "PallasAttentionBackendImpl")
+
+        if kv_cache.numel():
+            raise RuntimeError(
+                "KV cache from vLLM Attention layer should be empty but has "
+                "the size of %s.", kv_cache.numel())
+
+        del kv_cache  # Use kv_cache from vllm wrapper context values instead.
+
+        vllm_model_wrapper_context = get_vllm_model_wrapper_context()
+        kv_cache_index = vllm_model_wrapper_context.layer_name_to_kvcache_index[
+            layer.layer_name]
+        kv_cache = vllm_model_wrapper_context.kv_caches[kv_cache_index]
+
+        mesh = vllm_model_wrapper_context.mesh
+
+        query, key, value = jax_view(query), jax_view(key), jax_view(value)
+        q_scale = k_scale = v_scale = None
+        if self.kv_cache_quantized_dtype:
+            key, value = utils.quantize_kv(key, value,
+                                           self.kv_cache_quantized_dtype,
+                                           layer._k_scale_float,
+                                           layer._v_scale_float)
+            # TODO(kyuyeunk): Enable w8a8 when VREG spill issue is resolved.
+            # q_scale = layer._q_scale_float
+            k_scale = layer._k_scale_float
+            v_scale = layer._v_scale_float
+
+        sinks = jax_view(self.sinks)
+
+        new_kv_cache, outputs = _jax_attn_func(
+            kv_cache,
+            query,
+            key,
+            value,
+            sinks,
+            attn_metadata,
+            mesh,
+            self.scale,
+            self.head_size,
+            self.num_heads,
+            self.num_kv_heads,
+            q_scale,
+            k_scale,
+            v_scale,
+            self.sliding_window,
+        )
+        vllm_model_wrapper_context.kv_caches[kv_cache_index] = new_kv_cache
+
+        return torch_view(outputs)
+
+
+@functools.partial(
+    jax.jit,
+    static_argnames=(
+        "mesh",
+        "scale",
+        "head_size",
+        "num_heads",
+        "num_kv_heads",
+        "q_scale",
+        "k_scale",
+        "v_scale",
+        "sliding_window",
+    ),
+    donate_argnames=("kv_cache"),
+)
+def _jax_attn_func(
+    kv_cache: jax.Array,
+    q: jax.Array,
+    k: jax.Array,
+    v: jax.Array,
+    sinks: jax.Array | None,
+    attention_metadata: AttentionMetadata,
+    mesh: Mesh,
+    scale: float,
+    head_size: int,
+    num_heads: int,
+    num_kv_heads: int,
+    q_scale: float | None = None,
+    k_scale: float | None = None,
+    v_scale: float | None = None,
+    sliding_window: int | None = None,
+) -> Tuple[jax.Array, jax.Array]:
+    del scale  # Unused for now, as the attention function applies a default scale.
+
+    # Get shapes from vllm
+    q_len, q_compute_dim = q.shape
+    k_len, k_compute_dim = k.shape
+    assert k.shape == v.shape
+    assert q_compute_dim == head_size * num_heads
+    assert k_compute_dim == head_size * num_kv_heads
+
+    # Convert the shapes from vLLM's convetion to what the attention function expects
+    # bs, num_heads, q_len, head_size
+    q = q.reshape(q_len, num_heads, head_size)
+    # bs, num_kv_heads, k_len, head_size
+    k = k.reshape(k_len, num_kv_heads, head_size)
+    v = v.reshape(k_len, num_kv_heads, head_size)
+
+    new_kv_cache, outputs = attention(
+        kv_cache,
+        q,
+        k,
+        v,
+        attention_metadata,
+        mesh,
+        q_scale=q_scale,
+        k_scale=k_scale,
+        v_scale=v_scale,
+        sinks=sinks,
+        attention_chunk_size=sliding_window,
+    )
+
+    # Convert the shape back to vLLM's convention
+    assert outputs.shape[0] == q_len
+    assert outputs.shape[1] == num_heads
+    assert outputs.shape[2] == head_size
+    outputs = outputs.reshape(q_len, q_compute_dim)

+    return new_kv_cache, outputs
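The forward path above hands _jax_attn_func flattened [num_tokens, num_heads * head_size] tensors in vLLM's layout, reshapes them to the [num_tokens, num_heads, head_size] layout the shared attention helper expects, and flattens the result back before returning it to vLLM. A tiny, self-contained sketch of that shape round trip follows (toy sizes; nothing here calls the real TPU attention kernel):

# Shape-convention sketch only; no call into the attention kernel.
import jax.numpy as jnp

num_tokens, num_heads, head_size = 6, 8, 64

q_flat = jnp.zeros((num_tokens, num_heads * head_size))    # vLLM layout
q = q_flat.reshape(num_tokens, num_heads, head_size)       # kernel layout

out = q                                                    # pretend kernel output
out_flat = out.reshape(num_tokens, num_heads * head_size)  # back to vLLM layout
assert out_flat.shape == q_flat.shape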