tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl
This diff compares the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic. Click here for more details.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +317 -34
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +406 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +320 -0
- tests/layers/vllm/test_unquantized.py +662 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +26 -6
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +110 -12
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +2 -45
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +15 -10
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +92 -8
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +25 -4
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +807 -230
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +218 -137
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +25 -12
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/fused_moe_gmm.py +506 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +282 -0
- tpu_inference/layers/common/sharding.py +32 -9
- tpu_inference/layers/common/utils.py +94 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +101 -494
- tpu_inference/layers/vllm/linear.py +64 -0
- tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
- tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
- tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
- tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
- tpu_inference/layers/vllm/quantization/__init__.py +19 -3
- tpu_inference/layers/vllm/quantization/awq.py +96 -82
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +23 -8
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +172 -176
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
- tpu_inference/layers/vllm/quantization/{common.py → configs.py} +42 -25
- tpu_inference/layers/vllm/quantization/fp8.py +119 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +137 -178
- tpu_inference/layers/vllm/quantization/unquantized.py +157 -233
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +112 -35
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +18 -5
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +179 -51
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +92 -32
- tpu_inference/models/jax/utils/weight_utils.py +234 -155
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +32 -8
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +51 -72
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +180 -80
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +55 -33
- tpu_inference/runner/lora_utils.py +16 -1
- tpu_inference/runner/multimodal_manager.py +16 -2
- tpu_inference/runner/persistent_batch_manager.py +54 -2
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +16 -3
- tpu_inference/runner/tpu_runner.py +124 -61
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +84 -22
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +72 -44
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +66 -52
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +8 -9
- tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
- tpu_inference/layers/vllm/linear_common.py +0 -186
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202511220812.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
|
@@ -1,13 +1,27 @@
|
|
|
1
|
-
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
2
15
|
import time
|
|
3
16
|
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
|
|
4
17
|
|
|
5
18
|
import jax
|
|
6
19
|
import jax.numpy as jnp
|
|
7
20
|
import numpy as np
|
|
8
|
-
import vllm.envs as
|
|
21
|
+
import vllm.envs as vllm_envs
|
|
9
22
|
from jax.sharding import NamedSharding, PartitionSpec
|
|
10
23
|
|
|
24
|
+
import tpu_inference.envs as envs
|
|
11
25
|
from tpu_inference.core.disagg_utils import is_disagg_enabled
|
|
12
26
|
from tpu_inference.layers.common.attention_metadata import AttentionMetadata
|
|
13
27
|
from tpu_inference.layers.common.sharding import ShardingAxisName
|
|
@@ -15,6 +29,8 @@ from tpu_inference.layers.jax.sample.sampling import sample
|
|
|
15
29
|
from tpu_inference.layers.jax.sample.sampling_metadata import \
|
|
16
30
|
TPUSupportedSamplingMetadata
|
|
17
31
|
from tpu_inference.logger import init_logger
|
|
32
|
+
from tpu_inference.models.jax.jax_intermediate_tensor import \
|
|
33
|
+
JaxIntermediateTensors
|
|
18
34
|
from tpu_inference.utils import device_array
|
|
19
35
|
|
|
20
36
|
if TYPE_CHECKING:
|
|
@@ -30,10 +46,12 @@ class CompilationManager:
|
|
|
30
46
|
|
|
31
47
|
def __init__(self, runner: "TPUModelRunner"):
|
|
32
48
|
self.runner = runner
|
|
33
|
-
|
|
49
|
+
self._sampling_precompiled = False
|
|
50
|
+
self._gather_logprobs_precompiled = False
|
|
51
|
+
if not vllm_envs.VLLM_DISABLE_COMPILE_CACHE:
|
|
34
52
|
logger.info("Enabling JAX compile cache.")
|
|
35
53
|
jax.config.update("jax_compilation_cache_dir",
|
|
36
|
-
|
|
54
|
+
vllm_envs.VLLM_XLA_CACHE_PATH)
|
|
37
55
|
|
|
38
56
|
def _create_dummy_tensor(self,
|
|
39
57
|
shape: Tuple[int, ...],
|
|
@@ -67,8 +85,7 @@ class CompilationManager:
|
|
|
67
85
|
logger.info("Compilation finished in %.2f [secs].", end - start)
|
|
68
86
|
|
|
69
87
|
def capture_model(self) -> None:
|
|
70
|
-
if
|
|
71
|
-
False) or self.runner.model_config.enforce_eager:
|
|
88
|
+
if envs.SKIP_JAX_PRECOMPILE or self.runner.model_config.enforce_eager:
|
|
72
89
|
return
|
|
73
90
|
logger.info("Precompile all the subgraphs with possible input shapes.")
|
|
74
91
|
|
|
@@ -81,11 +98,17 @@ class CompilationManager:
|
|
|
81
98
|
self._precompile_backbone_with_inputs_embeds()
|
|
82
99
|
if self.runner.scheduler_config.async_scheduling:
|
|
83
100
|
self._precompile_substitute_placeholder_token()
|
|
101
|
+
if not self.runner.is_last_rank:
|
|
102
|
+
return
|
|
84
103
|
self._precompile_select_from_array()
|
|
85
104
|
self._precompile_compute_logits()
|
|
105
|
+
# Skip sampling if already precompiled before KV cache allocation
|
|
106
|
+
if not self._sampling_precompiled:
|
|
107
|
+
self._precompile_sampling()
|
|
86
108
|
self._precompile_disagg_utils()
|
|
87
|
-
|
|
88
|
-
self.
|
|
109
|
+
# Skip gather_logprobs if already precompiled before KV cache allocation
|
|
110
|
+
if not self._gather_logprobs_precompiled:
|
|
111
|
+
self._precompile_gather_logprobs()
|
|
89
112
|
self._precompile_structured_decoding()
|
|
90
113
|
if self.runner.speculative_config:
|
|
91
114
|
self._precompile_speculative_decoding()
|
|
@@ -104,7 +127,7 @@ class CompilationManager:
|
|
|
104
127
|
|
|
105
128
|
self._run_compilation(
|
|
106
129
|
"input_embeddings_merger",
|
|
107
|
-
self.runner.
|
|
130
|
+
self.runner.embed_input_ids_fn,
|
|
108
131
|
self.runner.state,
|
|
109
132
|
dummy_input_ids,
|
|
110
133
|
dummy_multimodal_embeddings,
|
|
@@ -113,15 +136,22 @@ class CompilationManager:
|
|
|
113
136
|
|
|
114
137
|
self._run_compilation(
|
|
115
138
|
"input_embeddings_merger_text_only",
|
|
116
|
-
self.runner.
|
|
139
|
+
self.runner.embed_input_ids_fn,
|
|
117
140
|
self.runner.state,
|
|
118
141
|
dummy_input_ids,
|
|
119
142
|
None,
|
|
120
143
|
num_tokens=num_tokens,
|
|
121
144
|
)
|
|
122
145
|
|
|
123
|
-
def _precompile_backbone_helper(self,
|
|
124
|
-
|
|
146
|
+
def _precompile_backbone_helper(self,
|
|
147
|
+
name,
|
|
148
|
+
*,
|
|
149
|
+
input_ids,
|
|
150
|
+
positions,
|
|
151
|
+
inputs_embeds,
|
|
152
|
+
intermediate_tensors=None,
|
|
153
|
+
is_first_rank=True,
|
|
154
|
+
is_last_rank=True) -> None:
|
|
125
155
|
num_tokens = None
|
|
126
156
|
if input_ids is not None:
|
|
127
157
|
num_tokens = input_ids.shape[0]
|
|
@@ -181,10 +211,14 @@ class CompilationManager:
|
|
|
181
211
|
inputs_embeds,
|
|
182
212
|
layer_name_to_kvcache_index,
|
|
183
213
|
lora_metadata,
|
|
214
|
+
intermediate_tensors,
|
|
215
|
+
is_first_rank,
|
|
216
|
+
is_last_rank,
|
|
184
217
|
):
|
|
185
218
|
kv_caches, hidden_states, _ = self.runner.model_fn(
|
|
186
219
|
state, kv_caches, input_ids, attention_metadata, inputs_embeds,
|
|
187
|
-
positions, layer_name_to_kvcache_index, lora_metadata
|
|
220
|
+
positions, layer_name_to_kvcache_index, lora_metadata,
|
|
221
|
+
intermediate_tensors, is_first_rank, is_last_rank)
|
|
188
222
|
self.runner.kv_caches = kv_caches
|
|
189
223
|
return hidden_states
|
|
190
224
|
|
|
@@ -207,6 +241,9 @@ class CompilationManager:
|
|
|
207
241
|
inputs_embeds,
|
|
208
242
|
tuple(self.runner.layer_name_to_kvcache_index.items()),
|
|
209
243
|
lora_metadata,
|
|
244
|
+
intermediate_tensors,
|
|
245
|
+
is_first_rank,
|
|
246
|
+
is_last_rank,
|
|
210
247
|
num_tokens=num_tokens,
|
|
211
248
|
)
|
|
212
249
|
|
|
@@ -257,6 +294,7 @@ class CompilationManager:
|
|
|
257
294
|
)
|
|
258
295
|
|
|
259
296
|
def _precompile_backbone_text_only(self) -> None:
|
|
297
|
+
hidden_size = self.runner.model_config.get_hidden_size()
|
|
260
298
|
for num_tokens in self.runner.num_tokens_paddings:
|
|
261
299
|
dp_sharding = NamedSharding(
|
|
262
300
|
self.runner.mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, )
|
|
@@ -266,10 +304,28 @@ class CompilationManager:
|
|
|
266
304
|
dp_sharding)
|
|
267
305
|
positions = self._create_dummy_tensor((num_tokens, ), jnp.int32,
|
|
268
306
|
dp_sharding)
|
|
269
|
-
self.
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
307
|
+
is_first_rank = self.runner.is_first_rank
|
|
308
|
+
is_last_rank = self.runner.is_last_rank
|
|
309
|
+
if is_first_rank:
|
|
310
|
+
intermediate_tensors = None
|
|
311
|
+
else:
|
|
312
|
+
hidden_states = self._create_dummy_tensor(
|
|
313
|
+
(num_tokens, hidden_size), jnp.bfloat16)
|
|
314
|
+
residual = self._create_dummy_tensor((num_tokens, hidden_size),
|
|
315
|
+
jnp.bfloat16)
|
|
316
|
+
intermediate_tensors = JaxIntermediateTensors(
|
|
317
|
+
tensors={
|
|
318
|
+
"hidden_states": hidden_states,
|
|
319
|
+
"residual": residual
|
|
320
|
+
})
|
|
321
|
+
self._precompile_backbone_helper(
|
|
322
|
+
f"worker{self.runner.rank} backbone",
|
|
323
|
+
input_ids=input_ids,
|
|
324
|
+
positions=positions,
|
|
325
|
+
inputs_embeds=None,
|
|
326
|
+
intermediate_tensors=intermediate_tensors,
|
|
327
|
+
is_first_rank=is_first_rank,
|
|
328
|
+
is_last_rank=is_last_rank)
|
|
273
329
|
|
|
274
330
|
def _precompile_backbone_with_inputs_embeds(self) -> None:
|
|
275
331
|
hidden_size = self.runner.model_config.get_hidden_size()
|
|
@@ -283,10 +339,28 @@ class CompilationManager:
|
|
|
283
339
|
else:
|
|
284
340
|
positions = self._create_dummy_tensor((num_tokens, ),
|
|
285
341
|
jnp.int32)
|
|
286
|
-
self.
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
342
|
+
is_first_rank = self.runner.is_first_rank
|
|
343
|
+
is_last_rank = self.runner.is_last_rank
|
|
344
|
+
if not is_first_rank:
|
|
345
|
+
hidden_states = self._create_dummy_tensor(
|
|
346
|
+
(num_tokens, hidden_size), jnp.bfloat16)
|
|
347
|
+
residual = self._create_dummy_tensor((num_tokens, hidden_size),
|
|
348
|
+
jnp.bfloat16)
|
|
349
|
+
intermediate_tensors = JaxIntermediateTensors(
|
|
350
|
+
tensors={
|
|
351
|
+
"hidden_states": hidden_states,
|
|
352
|
+
"residual": residual
|
|
353
|
+
})
|
|
354
|
+
else:
|
|
355
|
+
intermediate_tensors = None
|
|
356
|
+
self._precompile_backbone_helper(
|
|
357
|
+
f"worker{self.runner.rank} backbone with embeds",
|
|
358
|
+
input_ids=None,
|
|
359
|
+
positions=positions,
|
|
360
|
+
inputs_embeds=inputs_embeds,
|
|
361
|
+
intermediate_tensors=intermediate_tensors,
|
|
362
|
+
is_first_rank=is_first_rank,
|
|
363
|
+
is_last_rank=is_last_rank)
|
|
290
364
|
|
|
291
365
|
def _precompile_select_from_array_helper(
|
|
292
366
|
self,
|
|
@@ -354,7 +428,7 @@ class CompilationManager:
|
|
|
354
428
|
self.runner.mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, None))
|
|
355
429
|
dp_size = self.runner.vllm_config.sharding_config.total_dp_size
|
|
356
430
|
self._precompile_select_from_array_helper(
|
|
357
|
-
name="select all logits",
|
|
431
|
+
name=f"worker{self.runner.rank} select all logits",
|
|
358
432
|
source_paddings=self.runner.num_tokens_paddings,
|
|
359
433
|
indices_paddings=index_paddings,
|
|
360
434
|
hidden_dim=hsize,
|
|
@@ -365,7 +439,8 @@ class CompilationManager:
|
|
|
365
439
|
if self.runner.speculative_config:
|
|
366
440
|
vocab_size = self.runner.model_config.get_vocab_size()
|
|
367
441
|
self._precompile_select_from_array_helper(
|
|
368
|
-
name=
|
|
442
|
+
name=
|
|
443
|
+
f"worker{self.runner.rank} select bonus tokens for spec decoding",
|
|
369
444
|
source_paddings=self.runner.num_logits_paddings,
|
|
370
445
|
indices_paddings=self.runner.num_reqs_paddings,
|
|
371
446
|
hidden_dim=vocab_size,
|
|
@@ -373,7 +448,8 @@ class CompilationManager:
|
|
|
373
448
|
PartitionSpec(None, "model")),
|
|
374
449
|
)
|
|
375
450
|
self._precompile_select_from_array_helper(
|
|
376
|
-
name=
|
|
451
|
+
name=
|
|
452
|
+
f"worker{self.runner.rank} select target tokens for spec decoding",
|
|
377
453
|
source_paddings=self.runner.num_logits_paddings,
|
|
378
454
|
indices_paddings=self.runner.num_logits_paddings,
|
|
379
455
|
hidden_dim=vocab_size,
|
|
@@ -396,7 +472,7 @@ class CompilationManager:
|
|
|
396
472
|
np.array([num_reqs], dtype=np.int32)):
|
|
397
473
|
lora_metadata = self.runner.lora_utils.extract_lora_metadata()
|
|
398
474
|
self._run_compilation(
|
|
399
|
-
"compute_logits",
|
|
475
|
+
f"worker{self.runner.rank} compute_logits",
|
|
400
476
|
self.runner.compute_logits_fn,
|
|
401
477
|
self.runner.state,
|
|
402
478
|
hidden_states,
|
|
@@ -410,43 +486,48 @@ class CompilationManager:
|
|
|
410
486
|
for num_reqs in self.runner.num_reqs_paddings:
|
|
411
487
|
logits_sharding = NamedSharding(
|
|
412
488
|
self.runner.mesh,
|
|
413
|
-
PartitionSpec(ShardingAxisName.
|
|
489
|
+
PartitionSpec(ShardingAxisName.MLP_DATA,
|
|
490
|
+
ShardingAxisName.MLP_TENSOR))
|
|
414
491
|
dp_size = self.runner.vllm_config.sharding_config.total_dp_size
|
|
415
492
|
sampling_metadata_sharding = NamedSharding(
|
|
416
493
|
self.runner.mesh, PartitionSpec(
|
|
417
|
-
ShardingAxisName.
|
|
494
|
+
ShardingAxisName.MLP_DATA)) if dp_size > 1 else None
|
|
418
495
|
logits = self._create_dummy_tensor((num_reqs, hsize), jnp.bfloat16,
|
|
419
496
|
logits_sharding)
|
|
420
497
|
for do_sampling in (True, False):
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
498
|
+
for logprobs in (True, False):
|
|
499
|
+
if do_sampling:
|
|
500
|
+
temperature = np.full((num_reqs, ),
|
|
501
|
+
0.7,
|
|
502
|
+
dtype=np.float32)
|
|
503
|
+
top_k = np.full((num_reqs, ), 20, dtype=np.int32)
|
|
504
|
+
top_p = np.full((num_reqs, ), 0.8, dtype=np.float32)
|
|
505
|
+
(temperature, top_k, top_p) = device_array(
|
|
506
|
+
self.runner.mesh, (temperature, top_k, top_p),
|
|
507
|
+
sharding=sampling_metadata_sharding)
|
|
508
|
+
else:
|
|
509
|
+
temperature = None
|
|
510
|
+
top_k = None
|
|
511
|
+
top_p = None
|
|
512
|
+
|
|
513
|
+
sampling_metadata = TPUSupportedSamplingMetadata(
|
|
514
|
+
temperature=temperature,
|
|
515
|
+
top_k=top_k,
|
|
516
|
+
top_p=top_p,
|
|
517
|
+
do_sampling=do_sampling,
|
|
518
|
+
logprobs=logprobs)
|
|
519
|
+
self._run_compilation(
|
|
520
|
+
f"worker{self.runner.rank} sample",
|
|
521
|
+
sample,
|
|
522
|
+
self.runner.rng_params_for_sampling,
|
|
523
|
+
self.runner.mesh,
|
|
524
|
+
logits,
|
|
525
|
+
sampling_metadata,
|
|
526
|
+
num_reqs=num_reqs,
|
|
527
|
+
do_sampling=do_sampling,
|
|
528
|
+
)
|
|
529
|
+
|
|
530
|
+
self._sampling_precompiled = True
|
|
450
531
|
|
|
451
532
|
def _precompile_disagg_utils(self) -> None:
|
|
452
533
|
if not is_disagg_enabled():
|
|
@@ -476,10 +557,18 @@ class CompilationManager:
|
|
|
476
557
|
logger.info("Compiling gather_logprobs with different input shapes.")
|
|
477
558
|
hsize = self.runner.model_config.get_vocab_size()
|
|
478
559
|
for num_reqs in self.runner.num_reqs_paddings:
|
|
479
|
-
|
|
480
|
-
|
|
560
|
+
logits_sharding = NamedSharding(
|
|
561
|
+
self.runner.mesh,
|
|
562
|
+
PartitionSpec(ShardingAxisName.MLP_DATA,
|
|
563
|
+
ShardingAxisName.MLP_TENSOR))
|
|
564
|
+
token_ids_sharding = NamedSharding(
|
|
565
|
+
self.runner.mesh, PartitionSpec(ShardingAxisName.MLP_DATA, ))
|
|
566
|
+
logits = self._create_dummy_tensor((num_reqs, hsize), jnp.bfloat16,
|
|
567
|
+
logits_sharding)
|
|
568
|
+
token_ids = self._create_dummy_tensor((num_reqs, ), jnp.int32,
|
|
569
|
+
token_ids_sharding)
|
|
481
570
|
self._run_compilation(
|
|
482
|
-
"gather_logprobs",
|
|
571
|
+
f"worker{self.runner.rank} gather_logprobs",
|
|
483
572
|
self.runner._compute_and_gather_logprobs,
|
|
484
573
|
logits,
|
|
485
574
|
token_ids,
|
|
@@ -487,6 +576,8 @@ class CompilationManager:
|
|
|
487
576
|
num_reqs=num_reqs,
|
|
488
577
|
)
|
|
489
578
|
|
|
579
|
+
self._gather_logprobs_precompiled = True
|
|
580
|
+
|
|
490
581
|
def _precompile_speculative_decoding(self) -> None:
|
|
491
582
|
logger.info(
|
|
492
583
|
"Compiling speculative_decoding with different input shapes.")
|
|
@@ -531,7 +622,7 @@ class CompilationManager:
|
|
|
531
622
|
do_sampling=do_sampling)
|
|
532
623
|
|
|
533
624
|
self._run_compilation(
|
|
534
|
-
compilation_name,
|
|
625
|
+
f"worker{self.runner.rank} {compilation_name}",
|
|
535
626
|
self.runner.rejection_sampler,
|
|
536
627
|
draft_token_ids,
|
|
537
628
|
num_draft_tokens,
|
|
@@ -548,7 +639,9 @@ class CompilationManager:
|
|
|
548
639
|
def _precompile_eagle3_helpers(self) -> None:
|
|
549
640
|
logger.info(
|
|
550
641
|
"Compiling eagle3 jitted helpers with different input shapes.")
|
|
551
|
-
|
|
642
|
+
target_hidden_size = self.runner.model_config.get_hidden_size()
|
|
643
|
+
draft_hidden_size = self.runner.speculative_config.draft_model_config.get_hidden_size(
|
|
644
|
+
)
|
|
552
645
|
dtype = self.runner.model_config.dtype
|
|
553
646
|
|
|
554
647
|
num_kv_cache_groups = len(self.runner.kv_cache_config.kv_cache_groups)
|
|
@@ -595,10 +688,11 @@ class CompilationManager:
|
|
|
595
688
|
|
|
596
689
|
for num_logits in self.runner.num_logits_paddings:
|
|
597
690
|
hidden_states = self._create_dummy_tensor(
|
|
598
|
-
(num_logits,
|
|
691
|
+
(num_logits, draft_hidden_size), jnp.bfloat16)
|
|
599
692
|
self._run_compilation(
|
|
600
693
|
"eagle3_get_draft_token_ids",
|
|
601
694
|
self.runner.drafter._get_draft_token_ids,
|
|
695
|
+
self.runner.drafter.state,
|
|
602
696
|
hidden_states,
|
|
603
697
|
num_logits=num_logits,
|
|
604
698
|
)
|
|
@@ -606,8 +700,8 @@ class CompilationManager:
|
|
|
606
700
|
input_ids_loop = self._create_dummy_tensor(
|
|
607
701
|
(self.runner.max_num_reqs, ), jnp.int32,
|
|
608
702
|
NamedSharding(self.runner.mesh, PartitionSpec()))
|
|
609
|
-
|
|
610
|
-
(self.runner.max_num_reqs,
|
|
703
|
+
draft_hidden_state_loop = self._create_dummy_tensor(
|
|
704
|
+
(self.runner.max_num_reqs, draft_hidden_size), dtype,
|
|
611
705
|
NamedSharding(self.runner.mesh, PartitionSpec(None, None)))
|
|
612
706
|
next_token_ids = self._create_dummy_tensor(
|
|
613
707
|
(self.runner.max_num_reqs, ), jnp.int32)
|
|
@@ -615,9 +709,12 @@ class CompilationManager:
|
|
|
615
709
|
(self.runner.max_num_reqs, ), jnp.int32)
|
|
616
710
|
for num_tokens in self.runner.num_tokens_paddings:
|
|
617
711
|
aux_hidden_states = [
|
|
618
|
-
self._create_dummy_tensor((num_tokens,
|
|
619
|
-
|
|
620
|
-
self._create_dummy_tensor((num_tokens,
|
|
712
|
+
self._create_dummy_tensor((num_tokens, target_hidden_size),
|
|
713
|
+
dtype),
|
|
714
|
+
self._create_dummy_tensor((num_tokens, target_hidden_size),
|
|
715
|
+
dtype),
|
|
716
|
+
self._create_dummy_tensor((num_tokens, target_hidden_size),
|
|
717
|
+
dtype),
|
|
621
718
|
]
|
|
622
719
|
|
|
623
720
|
positions = self._create_dummy_tensor((num_tokens, ), jnp.int32)
|
|
@@ -640,23 +737,23 @@ class CompilationManager:
|
|
|
640
737
|
num_reqs,
|
|
641
738
|
):
|
|
642
739
|
target_hidden_states, input_ids, last_token_indices, _ = self.runner.drafter._filter_token_and_prepare_initial_inputs(
|
|
643
|
-
token_indices, query_start_loc,
|
|
644
|
-
aux_hidden_states, attention_metadata,
|
|
645
|
-
num_reqs)
|
|
740
|
+
self.runner.drafter.state, token_indices, query_start_loc,
|
|
741
|
+
seq_lens, input_ids, aux_hidden_states, attention_metadata,
|
|
742
|
+
next_token_ids, num_reqs)
|
|
646
743
|
return target_hidden_states, input_ids, last_token_indices
|
|
647
744
|
|
|
648
745
|
input_ids = self._create_dummy_tensor((num_tokens, ), jnp.int32)
|
|
649
746
|
aux_hidden_states = [
|
|
650
747
|
self._create_dummy_tensor(
|
|
651
|
-
(num_tokens,
|
|
748
|
+
(num_tokens, target_hidden_size), jnp.bfloat16,
|
|
652
749
|
NamedSharding(self.runner.mesh, PartitionSpec(None,
|
|
653
750
|
None))),
|
|
654
751
|
self._create_dummy_tensor(
|
|
655
|
-
(num_tokens,
|
|
752
|
+
(num_tokens, target_hidden_size), jnp.bfloat16,
|
|
656
753
|
NamedSharding(self.runner.mesh, PartitionSpec(None,
|
|
657
754
|
None))),
|
|
658
755
|
self._create_dummy_tensor(
|
|
659
|
-
(num_tokens,
|
|
756
|
+
(num_tokens, target_hidden_size), jnp.bfloat16,
|
|
660
757
|
NamedSharding(self.runner.mesh, PartitionSpec(None,
|
|
661
758
|
None))),
|
|
662
759
|
]
|
|
@@ -688,17 +785,17 @@ class CompilationManager:
|
|
|
688
785
|
state,
|
|
689
786
|
kv_caches,
|
|
690
787
|
input_ids,
|
|
691
|
-
|
|
788
|
+
draft_hidden_states,
|
|
692
789
|
attention_metadata,
|
|
693
790
|
):
|
|
694
791
|
kv_caches, hidden_states, _ = self.runner.drafter.model_fn(
|
|
695
|
-
state, kv_caches, input_ids,
|
|
792
|
+
state, kv_caches, input_ids, draft_hidden_states,
|
|
696
793
|
attention_metadata)
|
|
697
794
|
self.runner.kv_caches = kv_caches
|
|
698
795
|
return hidden_states
|
|
699
796
|
|
|
700
|
-
|
|
701
|
-
(num_tokens,
|
|
797
|
+
draft_hidden_states = self._create_dummy_tensor(
|
|
798
|
+
(num_tokens, draft_hidden_size), dtype,
|
|
702
799
|
NamedSharding(self.runner.mesh, PartitionSpec(None, "model")))
|
|
703
800
|
input_ids = self._create_dummy_tensor(
|
|
704
801
|
(num_tokens, ), jnp.int32,
|
|
@@ -709,7 +806,7 @@ class CompilationManager:
|
|
|
709
806
|
self.runner.drafter.state,
|
|
710
807
|
self.runner.kv_caches,
|
|
711
808
|
input_ids,
|
|
712
|
-
|
|
809
|
+
draft_hidden_states,
|
|
713
810
|
attention_metadata,
|
|
714
811
|
num_tokens=num_tokens,
|
|
715
812
|
)
|
|
@@ -719,6 +816,7 @@ class CompilationManager:
|
|
|
719
816
|
self._run_compilation(
|
|
720
817
|
"eagle3_prepare_hidden_states_and_input_ids",
|
|
721
818
|
self.runner.drafter._prepare_hidden_states_and_input_ids,
|
|
819
|
+
self.runner.drafter.state,
|
|
722
820
|
aux_hidden_states,
|
|
723
821
|
query_start_loc,
|
|
724
822
|
target_token_ids,
|
|
@@ -741,18 +839,19 @@ class CompilationManager:
|
|
|
741
839
|
self.runner.drafter.state,
|
|
742
840
|
self.runner.kv_caches,
|
|
743
841
|
input_ids_loop,
|
|
744
|
-
|
|
842
|
+
draft_hidden_state_loop,
|
|
745
843
|
attention_metadata,
|
|
746
844
|
num_tokens=num_tokens,
|
|
747
845
|
)
|
|
748
846
|
|
|
749
847
|
hidden_states = self._create_dummy_tensor(
|
|
750
|
-
(num_tokens,
|
|
848
|
+
(num_tokens, draft_hidden_size), jnp.bfloat16,
|
|
751
849
|
NamedSharding(self.runner.mesh, PartitionSpec(None, None)))
|
|
752
850
|
|
|
753
851
|
self._run_compilation(
|
|
754
852
|
"eagle3_select_inputs_for_loop_speculation",
|
|
755
853
|
self.runner.drafter._select_inputs_for_loop_speculation,
|
|
854
|
+
self.runner.drafter.state,
|
|
756
855
|
positions,
|
|
757
856
|
hidden_states,
|
|
758
857
|
hidden_states,
|
|
@@ -763,6 +862,7 @@ class CompilationManager:
|
|
|
763
862
|
self._run_compilation(
|
|
764
863
|
"eagle3_select_draft_token_ids",
|
|
765
864
|
self.runner.drafter._select_draft_token_ids,
|
|
865
|
+
self.runner.drafter.state,
|
|
766
866
|
hidden_states,
|
|
767
867
|
last_token_indices,
|
|
768
868
|
num_tokens=num_tokens,
|
tpu_inference/runner/kv_cache.py
CHANGED
|
@@ -1,3 +1,17 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
1
15
|
from typing import Any, List
|
|
2
16
|
|
|
3
17
|
import jax
|
|
@@ -7,6 +21,7 @@ from jax._src import dtypes
|
|
|
7
21
|
from jax.sharding import Mesh, NamedSharding, PartitionSpec
|
|
8
22
|
from torchax.ops.mappings import t2j_dtype
|
|
9
23
|
|
|
24
|
+
import tpu_inference.kernels.mla.v1.kernel as mla
|
|
10
25
|
import tpu_inference.kernels.ragged_paged_attention.v3.kernel as rpa
|
|
11
26
|
import tpu_inference.kernels.ragged_paged_attention.v3.kernel_hd64 as rpa_hd64
|
|
12
27
|
from tpu_inference.layers.common.sharding import ShardingAxisName
|
|
@@ -17,9 +32,13 @@ logger = init_logger(__name__)
|
|
|
17
32
|
DEFAULT_KV_CACHE_DTYPE = jnp.bfloat16
|
|
18
33
|
|
|
19
34
|
|
|
20
|
-
def get_kv_cache_shape_with_mesh(mesh: Mesh,
|
|
21
|
-
|
|
22
|
-
|
|
35
|
+
def get_kv_cache_shape_with_mesh(mesh: Mesh,
|
|
36
|
+
total_num_pages: int,
|
|
37
|
+
page_size: int,
|
|
38
|
+
actual_num_kv_heads: int,
|
|
39
|
+
actual_head_dim: int,
|
|
40
|
+
kv_dtype: any,
|
|
41
|
+
use_mla: bool = False):
|
|
23
42
|
"""Gets the KV cache shape based on the mesh configuration."""
|
|
24
43
|
|
|
25
44
|
model_cnt = mesh.shape["model"]
|
|
@@ -28,15 +47,21 @@ def get_kv_cache_shape_with_mesh(mesh: Mesh, total_num_pages: int,
|
|
|
28
47
|
# specific model, rather than being determined by the head_dim. If new
|
|
29
48
|
# models are introduced with a head_dim of 64, this will require additional
|
|
30
49
|
# model-specific adjustments.
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
50
|
+
if use_mla:
|
|
51
|
+
get_kv_cache_shape_fn = mla.get_kv_cache_shape
|
|
52
|
+
shape = list(
|
|
53
|
+
get_kv_cache_shape_fn(total_num_pages, page_size, actual_head_dim,
|
|
54
|
+
kv_dtype))
|
|
55
|
+
else:
|
|
56
|
+
get_kv_cache_shape_fn = (
|
|
57
|
+
rpa_hd64.get_kv_cache_shape if actual_head_dim == 64 \
|
|
58
|
+
else rpa.get_kv_cache_shape
|
|
59
|
+
)
|
|
60
|
+
shape = list(
|
|
61
|
+
get_kv_cache_shape_fn(total_num_pages, page_size,
|
|
62
|
+
actual_num_kv_heads // model_cnt,
|
|
63
|
+
actual_head_dim, kv_dtype))
|
|
64
|
+
shape[2] *= model_cnt
|
|
40
65
|
return tuple(shape)
|
|
41
66
|
|
|
42
67
|
|
|
@@ -48,6 +73,7 @@ def create_kv_caches(
|
|
|
48
73
|
mesh: Mesh,
|
|
49
74
|
layer_names: List[str],
|
|
50
75
|
cache_dtype: jnp.dtype = DEFAULT_KV_CACHE_DTYPE,
|
|
76
|
+
use_mla: bool = False,
|
|
51
77
|
) -> List[jax.Array]:
|
|
52
78
|
"""
|
|
53
79
|
Creates a list of KV cache where each array mapps to single attention layer.
|
|
@@ -74,12 +100,16 @@ def create_kv_caches(
|
|
|
74
100
|
|
|
75
101
|
cache_shape = get_kv_cache_shape_with_mesh(mesh, num_blocks, block_size,
|
|
76
102
|
num_kv_heads, head_size,
|
|
77
|
-
cache_dtype)
|
|
103
|
+
cache_dtype, use_mla)
|
|
78
104
|
|
|
79
|
-
|
|
80
|
-
mesh,
|
|
81
|
-
|
|
82
|
-
|
|
105
|
+
if use_mla:
|
|
106
|
+
sharding = NamedSharding(mesh,
|
|
107
|
+
PartitionSpec(ShardingAxisName.MLP_TENSOR))
|
|
108
|
+
else:
|
|
109
|
+
sharding = NamedSharding(
|
|
110
|
+
mesh,
|
|
111
|
+
PartitionSpec(ShardingAxisName.ATTN_DATA, None,
|
|
112
|
+
ShardingAxisName.ATTN_HEAD))
|
|
83
113
|
|
|
84
114
|
def _allocate() -> jax.Array:
|
|
85
115
|
return jnp.empty(
|
|
@@ -94,7 +124,8 @@ def create_kv_caches(
|
|
|
94
124
|
return kv_caches
|
|
95
125
|
|
|
96
126
|
|
|
97
|
-
def
|
|
127
|
+
def get_attention_page_size_bytes(mesh: Mesh,
|
|
128
|
+
kv_cache_specs: dict[str, Any]) -> int:
|
|
98
129
|
"""
|
|
99
130
|
Calculate KV cache page size of RPA kernel.
|
|
100
131
|
|
|
@@ -107,14 +138,16 @@ def get_rpa_page_size_bytes(mesh: Mesh, kv_cache_specs: dict[str, Any]) -> int:
|
|
|
107
138
|
"""
|
|
108
139
|
|
|
109
140
|
# Import it here to avoid circular import.
|
|
110
|
-
from vllm.v1.kv_cache_interface import AttentionSpec
|
|
141
|
+
from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec
|
|
111
142
|
|
|
112
143
|
page_size_bytes_set = set()
|
|
113
144
|
for kv_cache_spec in kv_cache_specs.values():
|
|
114
145
|
assert isinstance(kv_cache_spec, AttentionSpec)
|
|
115
146
|
|
|
116
147
|
dtype = t2j_dtype(kv_cache_spec.dtype)
|
|
117
|
-
bits = dtypes.bit_width(dtype)
|
|
148
|
+
bits = (dtypes.bit_width(dtype) if hasattr(dtypes, "bit_width") else
|
|
149
|
+
dtypes.itemsize_bits(dtype))
|
|
150
|
+
use_mla = isinstance(kv_cache_spec, MLAAttentionSpec)
|
|
118
151
|
|
|
119
152
|
kv_cache_shape = get_kv_cache_shape_with_mesh(
|
|
120
153
|
mesh=mesh,
|
|
@@ -123,6 +156,7 @@ def get_rpa_page_size_bytes(mesh: Mesh, kv_cache_specs: dict[str, Any]) -> int:
|
|
|
123
156
|
actual_num_kv_heads=kv_cache_spec.num_kv_heads,
|
|
124
157
|
actual_head_dim=kv_cache_spec.head_size,
|
|
125
158
|
kv_dtype=dtype,
|
|
159
|
+
use_mla=use_mla,
|
|
126
160
|
)
|
|
127
161
|
page_size_bytes = (bits * np.prod(kv_cache_shape)) // 8
|
|
128
162
|
page_size_bytes_set.add(page_size_bytes)
|