tpu-inference 0.11.1.dev202511270815__py3-none-any.whl → 0.13.0rc2.post7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic. Click here for more details.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +14 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +312 -0
- tests/layers/vllm/test_unquantized.py +651 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +21 -3
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +110 -12
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +2 -45
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +15 -10
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +92 -8
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +22 -1
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +167 -97
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +26 -19
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +31 -9
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +210 -260
- tpu_inference/layers/vllm/linear_common.py +57 -22
- tpu_inference/layers/vllm/quantization/__init__.py +16 -0
- tpu_inference/layers/vllm/quantization/awq.py +15 -1
- tpu_inference/layers/vllm/quantization/common.py +33 -18
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +280 -210
- tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
- tpu_inference/layers/vllm/sharding.py +21 -4
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +77 -36
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +14 -0
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +18 -4
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +91 -31
- tpu_inference/models/jax/utils/weight_utils.py +39 -2
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +20 -4
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +47 -71
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +158 -63
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +53 -30
- tpu_inference/runner/lora_utils.py +14 -0
- tpu_inference/runner/multimodal_manager.py +15 -1
- tpu_inference/runner/persistent_batch_manager.py +54 -2
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +14 -0
- tpu_inference/runner/tpu_runner.py +105 -57
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +65 -19
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +72 -44
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +65 -52
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/METADATA +11 -9
- tpu_inference-0.13.0rc2.post7.dist-info/RECORD +261 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202511270815.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/top_level.txt +0 -0
|
@@ -1,13 +1,27 @@
|
|
|
1
|
-
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
2
15
|
import time
|
|
3
16
|
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
|
|
4
17
|
|
|
5
18
|
import jax
|
|
6
19
|
import jax.numpy as jnp
|
|
7
20
|
import numpy as np
|
|
8
|
-
import vllm.envs as
|
|
21
|
+
import vllm.envs as vllm_envs
|
|
9
22
|
from jax.sharding import NamedSharding, PartitionSpec
|
|
10
23
|
|
|
24
|
+
import tpu_inference.envs as envs
|
|
11
25
|
from tpu_inference.core.disagg_utils import is_disagg_enabled
|
|
12
26
|
from tpu_inference.layers.common.attention_metadata import AttentionMetadata
|
|
13
27
|
from tpu_inference.layers.common.sharding import ShardingAxisName
|
|
@@ -15,6 +29,8 @@ from tpu_inference.layers.jax.sample.sampling import sample
|
|
|
15
29
|
from tpu_inference.layers.jax.sample.sampling_metadata import \
|
|
16
30
|
TPUSupportedSamplingMetadata
|
|
17
31
|
from tpu_inference.logger import init_logger
|
|
32
|
+
from tpu_inference.models.jax.jax_intermediate_tensor import \
|
|
33
|
+
JaxIntermediateTensors
|
|
18
34
|
from tpu_inference.utils import device_array
|
|
19
35
|
|
|
20
36
|
if TYPE_CHECKING:
|
|
@@ -30,10 +46,12 @@ class CompilationManager:
|
|
|
30
46
|
|
|
31
47
|
def __init__(self, runner: "TPUModelRunner"):
|
|
32
48
|
self.runner = runner
|
|
33
|
-
|
|
49
|
+
self._sampling_precompiled = False
|
|
50
|
+
self._gather_logprobs_precompiled = False
|
|
51
|
+
if not vllm_envs.VLLM_DISABLE_COMPILE_CACHE:
|
|
34
52
|
logger.info("Enabling JAX compile cache.")
|
|
35
53
|
jax.config.update("jax_compilation_cache_dir",
|
|
36
|
-
|
|
54
|
+
vllm_envs.VLLM_XLA_CACHE_PATH)
|
|
37
55
|
|
|
38
56
|
def _create_dummy_tensor(self,
|
|
39
57
|
shape: Tuple[int, ...],
|
|
@@ -67,8 +85,7 @@ class CompilationManager:
|
|
|
67
85
|
logger.info("Compilation finished in %.2f [secs].", end - start)
|
|
68
86
|
|
|
69
87
|
def capture_model(self) -> None:
|
|
70
|
-
if
|
|
71
|
-
False) or self.runner.model_config.enforce_eager:
|
|
88
|
+
if envs.SKIP_JAX_PRECOMPILE or self.runner.model_config.enforce_eager:
|
|
72
89
|
return
|
|
73
90
|
logger.info("Precompile all the subgraphs with possible input shapes.")
|
|
74
91
|
|
|
@@ -81,11 +98,17 @@ class CompilationManager:
|
|
|
81
98
|
self._precompile_backbone_with_inputs_embeds()
|
|
82
99
|
if self.runner.scheduler_config.async_scheduling:
|
|
83
100
|
self._precompile_substitute_placeholder_token()
|
|
101
|
+
if not self.runner.is_last_rank:
|
|
102
|
+
return
|
|
84
103
|
self._precompile_select_from_array()
|
|
85
104
|
self._precompile_compute_logits()
|
|
105
|
+
# Skip sampling if already precompiled before KV cache allocation
|
|
106
|
+
if not self._sampling_precompiled:
|
|
107
|
+
self._precompile_sampling()
|
|
86
108
|
self._precompile_disagg_utils()
|
|
87
|
-
|
|
88
|
-
self.
|
|
109
|
+
# Skip gather_logprobs if already precompiled before KV cache allocation
|
|
110
|
+
if not self._gather_logprobs_precompiled:
|
|
111
|
+
self._precompile_gather_logprobs()
|
|
89
112
|
self._precompile_structured_decoding()
|
|
90
113
|
if self.runner.speculative_config:
|
|
91
114
|
self._precompile_speculative_decoding()
|
|
@@ -104,7 +127,7 @@ class CompilationManager:
|
|
|
104
127
|
|
|
105
128
|
self._run_compilation(
|
|
106
129
|
"input_embeddings_merger",
|
|
107
|
-
self.runner.
|
|
130
|
+
self.runner.embed_input_ids_fn,
|
|
108
131
|
self.runner.state,
|
|
109
132
|
dummy_input_ids,
|
|
110
133
|
dummy_multimodal_embeddings,
|
|
@@ -113,15 +136,22 @@ class CompilationManager:
|
|
|
113
136
|
|
|
114
137
|
self._run_compilation(
|
|
115
138
|
"input_embeddings_merger_text_only",
|
|
116
|
-
self.runner.
|
|
139
|
+
self.runner.embed_input_ids_fn,
|
|
117
140
|
self.runner.state,
|
|
118
141
|
dummy_input_ids,
|
|
119
142
|
None,
|
|
120
143
|
num_tokens=num_tokens,
|
|
121
144
|
)
|
|
122
145
|
|
|
123
|
-
def _precompile_backbone_helper(self,
|
|
124
|
-
|
|
146
|
+
def _precompile_backbone_helper(self,
|
|
147
|
+
name,
|
|
148
|
+
*,
|
|
149
|
+
input_ids,
|
|
150
|
+
positions,
|
|
151
|
+
inputs_embeds,
|
|
152
|
+
intermediate_tensors=None,
|
|
153
|
+
is_first_rank=True,
|
|
154
|
+
is_last_rank=True) -> None:
|
|
125
155
|
num_tokens = None
|
|
126
156
|
if input_ids is not None:
|
|
127
157
|
num_tokens = input_ids.shape[0]
|
|
@@ -181,10 +211,14 @@ class CompilationManager:
|
|
|
181
211
|
inputs_embeds,
|
|
182
212
|
layer_name_to_kvcache_index,
|
|
183
213
|
lora_metadata,
|
|
214
|
+
intermediate_tensors,
|
|
215
|
+
is_first_rank,
|
|
216
|
+
is_last_rank,
|
|
184
217
|
):
|
|
185
218
|
kv_caches, hidden_states, _ = self.runner.model_fn(
|
|
186
219
|
state, kv_caches, input_ids, attention_metadata, inputs_embeds,
|
|
187
|
-
positions, layer_name_to_kvcache_index, lora_metadata
|
|
220
|
+
positions, layer_name_to_kvcache_index, lora_metadata,
|
|
221
|
+
intermediate_tensors, is_first_rank, is_last_rank)
|
|
188
222
|
self.runner.kv_caches = kv_caches
|
|
189
223
|
return hidden_states
|
|
190
224
|
|
|
@@ -207,6 +241,9 @@ class CompilationManager:
|
|
|
207
241
|
inputs_embeds,
|
|
208
242
|
tuple(self.runner.layer_name_to_kvcache_index.items()),
|
|
209
243
|
lora_metadata,
|
|
244
|
+
intermediate_tensors,
|
|
245
|
+
is_first_rank,
|
|
246
|
+
is_last_rank,
|
|
210
247
|
num_tokens=num_tokens,
|
|
211
248
|
)
|
|
212
249
|
|
|
@@ -257,6 +294,7 @@ class CompilationManager:
|
|
|
257
294
|
)
|
|
258
295
|
|
|
259
296
|
def _precompile_backbone_text_only(self) -> None:
|
|
297
|
+
hidden_size = self.runner.model_config.get_hidden_size()
|
|
260
298
|
for num_tokens in self.runner.num_tokens_paddings:
|
|
261
299
|
dp_sharding = NamedSharding(
|
|
262
300
|
self.runner.mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, )
|
|
@@ -266,10 +304,28 @@ class CompilationManager:
|
|
|
266
304
|
dp_sharding)
|
|
267
305
|
positions = self._create_dummy_tensor((num_tokens, ), jnp.int32,
|
|
268
306
|
dp_sharding)
|
|
269
|
-
self.
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
307
|
+
is_first_rank = self.runner.is_first_rank
|
|
308
|
+
is_last_rank = self.runner.is_last_rank
|
|
309
|
+
if is_first_rank:
|
|
310
|
+
intermediate_tensors = None
|
|
311
|
+
else:
|
|
312
|
+
hidden_states = self._create_dummy_tensor(
|
|
313
|
+
(num_tokens, hidden_size), jnp.bfloat16)
|
|
314
|
+
residual = self._create_dummy_tensor((num_tokens, hidden_size),
|
|
315
|
+
jnp.bfloat16)
|
|
316
|
+
intermediate_tensors = JaxIntermediateTensors(
|
|
317
|
+
tensors={
|
|
318
|
+
"hidden_states": hidden_states,
|
|
319
|
+
"residual": residual
|
|
320
|
+
})
|
|
321
|
+
self._precompile_backbone_helper(
|
|
322
|
+
f"worker{self.runner.rank} backbone",
|
|
323
|
+
input_ids=input_ids,
|
|
324
|
+
positions=positions,
|
|
325
|
+
inputs_embeds=None,
|
|
326
|
+
intermediate_tensors=intermediate_tensors,
|
|
327
|
+
is_first_rank=is_first_rank,
|
|
328
|
+
is_last_rank=is_last_rank)
|
|
273
329
|
|
|
274
330
|
def _precompile_backbone_with_inputs_embeds(self) -> None:
|
|
275
331
|
hidden_size = self.runner.model_config.get_hidden_size()
|
|
@@ -283,10 +339,28 @@ class CompilationManager:
|
|
|
283
339
|
else:
|
|
284
340
|
positions = self._create_dummy_tensor((num_tokens, ),
|
|
285
341
|
jnp.int32)
|
|
286
|
-
self.
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
342
|
+
is_first_rank = self.runner.is_first_rank
|
|
343
|
+
is_last_rank = self.runner.is_last_rank
|
|
344
|
+
if not is_first_rank:
|
|
345
|
+
hidden_states = self._create_dummy_tensor(
|
|
346
|
+
(num_tokens, hidden_size), jnp.bfloat16)
|
|
347
|
+
residual = self._create_dummy_tensor((num_tokens, hidden_size),
|
|
348
|
+
jnp.bfloat16)
|
|
349
|
+
intermediate_tensors = JaxIntermediateTensors(
|
|
350
|
+
tensors={
|
|
351
|
+
"hidden_states": hidden_states,
|
|
352
|
+
"residual": residual
|
|
353
|
+
})
|
|
354
|
+
else:
|
|
355
|
+
intermediate_tensors = None
|
|
356
|
+
self._precompile_backbone_helper(
|
|
357
|
+
f"worker{self.runner.rank} backbone with embeds",
|
|
358
|
+
input_ids=None,
|
|
359
|
+
positions=positions,
|
|
360
|
+
inputs_embeds=inputs_embeds,
|
|
361
|
+
intermediate_tensors=intermediate_tensors,
|
|
362
|
+
is_first_rank=is_first_rank,
|
|
363
|
+
is_last_rank=is_last_rank)
|
|
290
364
|
|
|
291
365
|
def _precompile_select_from_array_helper(
|
|
292
366
|
self,
|
|
@@ -354,7 +428,7 @@ class CompilationManager:
|
|
|
354
428
|
self.runner.mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, None))
|
|
355
429
|
dp_size = self.runner.vllm_config.sharding_config.total_dp_size
|
|
356
430
|
self._precompile_select_from_array_helper(
|
|
357
|
-
name="select all logits",
|
|
431
|
+
name=f"worker{self.runner.rank} select all logits",
|
|
358
432
|
source_paddings=self.runner.num_tokens_paddings,
|
|
359
433
|
indices_paddings=index_paddings,
|
|
360
434
|
hidden_dim=hsize,
|
|
@@ -365,7 +439,8 @@ class CompilationManager:
|
|
|
365
439
|
if self.runner.speculative_config:
|
|
366
440
|
vocab_size = self.runner.model_config.get_vocab_size()
|
|
367
441
|
self._precompile_select_from_array_helper(
|
|
368
|
-
name=
|
|
442
|
+
name=
|
|
443
|
+
f"worker{self.runner.rank} select bonus tokens for spec decoding",
|
|
369
444
|
source_paddings=self.runner.num_logits_paddings,
|
|
370
445
|
indices_paddings=self.runner.num_reqs_paddings,
|
|
371
446
|
hidden_dim=vocab_size,
|
|
@@ -373,7 +448,8 @@ class CompilationManager:
|
|
|
373
448
|
PartitionSpec(None, "model")),
|
|
374
449
|
)
|
|
375
450
|
self._precompile_select_from_array_helper(
|
|
376
|
-
name=
|
|
451
|
+
name=
|
|
452
|
+
f"worker{self.runner.rank} select target tokens for spec decoding",
|
|
377
453
|
source_paddings=self.runner.num_logits_paddings,
|
|
378
454
|
indices_paddings=self.runner.num_logits_paddings,
|
|
379
455
|
hidden_dim=vocab_size,
|
|
@@ -396,7 +472,7 @@ class CompilationManager:
|
|
|
396
472
|
np.array([num_reqs], dtype=np.int32)):
|
|
397
473
|
lora_metadata = self.runner.lora_utils.extract_lora_metadata()
|
|
398
474
|
self._run_compilation(
|
|
399
|
-
"compute_logits",
|
|
475
|
+
f"worker{self.runner.rank} compute_logits",
|
|
400
476
|
self.runner.compute_logits_fn,
|
|
401
477
|
self.runner.state,
|
|
402
478
|
hidden_states,
|
|
@@ -410,43 +486,48 @@ class CompilationManager:
|
|
|
410
486
|
for num_reqs in self.runner.num_reqs_paddings:
|
|
411
487
|
logits_sharding = NamedSharding(
|
|
412
488
|
self.runner.mesh,
|
|
413
|
-
PartitionSpec(ShardingAxisName.
|
|
489
|
+
PartitionSpec(ShardingAxisName.MLP_DATA,
|
|
490
|
+
ShardingAxisName.MLP_TENSOR))
|
|
414
491
|
dp_size = self.runner.vllm_config.sharding_config.total_dp_size
|
|
415
492
|
sampling_metadata_sharding = NamedSharding(
|
|
416
493
|
self.runner.mesh, PartitionSpec(
|
|
417
|
-
ShardingAxisName.
|
|
494
|
+
ShardingAxisName.MLP_DATA)) if dp_size > 1 else None
|
|
418
495
|
logits = self._create_dummy_tensor((num_reqs, hsize), jnp.bfloat16,
|
|
419
496
|
logits_sharding)
|
|
420
497
|
for do_sampling in (True, False):
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
498
|
+
for logprobs in (True, False):
|
|
499
|
+
if do_sampling:
|
|
500
|
+
temperature = np.full((num_reqs, ),
|
|
501
|
+
0.7,
|
|
502
|
+
dtype=np.float32)
|
|
503
|
+
top_k = np.full((num_reqs, ), 20, dtype=np.int32)
|
|
504
|
+
top_p = np.full((num_reqs, ), 0.8, dtype=np.float32)
|
|
505
|
+
(temperature, top_k, top_p) = device_array(
|
|
506
|
+
self.runner.mesh, (temperature, top_k, top_p),
|
|
507
|
+
sharding=sampling_metadata_sharding)
|
|
508
|
+
else:
|
|
509
|
+
temperature = None
|
|
510
|
+
top_k = None
|
|
511
|
+
top_p = None
|
|
512
|
+
|
|
513
|
+
sampling_metadata = TPUSupportedSamplingMetadata(
|
|
514
|
+
temperature=temperature,
|
|
515
|
+
top_k=top_k,
|
|
516
|
+
top_p=top_p,
|
|
517
|
+
do_sampling=do_sampling,
|
|
518
|
+
logprobs=logprobs)
|
|
519
|
+
self._run_compilation(
|
|
520
|
+
f"worker{self.runner.rank} sample",
|
|
521
|
+
sample,
|
|
522
|
+
self.runner.rng_params_for_sampling,
|
|
523
|
+
self.runner.mesh,
|
|
524
|
+
logits,
|
|
525
|
+
sampling_metadata,
|
|
526
|
+
num_reqs=num_reqs,
|
|
527
|
+
do_sampling=do_sampling,
|
|
528
|
+
)
|
|
529
|
+
|
|
530
|
+
self._sampling_precompiled = True
|
|
450
531
|
|
|
451
532
|
def _precompile_disagg_utils(self) -> None:
|
|
452
533
|
if not is_disagg_enabled():
|
|
@@ -476,10 +557,18 @@ class CompilationManager:
|
|
|
476
557
|
logger.info("Compiling gather_logprobs with different input shapes.")
|
|
477
558
|
hsize = self.runner.model_config.get_vocab_size()
|
|
478
559
|
for num_reqs in self.runner.num_reqs_paddings:
|
|
479
|
-
|
|
480
|
-
|
|
560
|
+
logits_sharding = NamedSharding(
|
|
561
|
+
self.runner.mesh,
|
|
562
|
+
PartitionSpec(ShardingAxisName.MLP_DATA,
|
|
563
|
+
ShardingAxisName.MLP_TENSOR))
|
|
564
|
+
token_ids_sharding = NamedSharding(
|
|
565
|
+
self.runner.mesh, PartitionSpec(ShardingAxisName.MLP_DATA, ))
|
|
566
|
+
logits = self._create_dummy_tensor((num_reqs, hsize), jnp.bfloat16,
|
|
567
|
+
logits_sharding)
|
|
568
|
+
token_ids = self._create_dummy_tensor((num_reqs, ), jnp.int32,
|
|
569
|
+
token_ids_sharding)
|
|
481
570
|
self._run_compilation(
|
|
482
|
-
"gather_logprobs",
|
|
571
|
+
f"worker{self.runner.rank} gather_logprobs",
|
|
483
572
|
self.runner._compute_and_gather_logprobs,
|
|
484
573
|
logits,
|
|
485
574
|
token_ids,
|
|
@@ -487,6 +576,8 @@ class CompilationManager:
|
|
|
487
576
|
num_reqs=num_reqs,
|
|
488
577
|
)
|
|
489
578
|
|
|
579
|
+
self._gather_logprobs_precompiled = True
|
|
580
|
+
|
|
490
581
|
def _precompile_speculative_decoding(self) -> None:
|
|
491
582
|
logger.info(
|
|
492
583
|
"Compiling speculative_decoding with different input shapes.")
|
|
@@ -531,7 +622,7 @@ class CompilationManager:
|
|
|
531
622
|
do_sampling=do_sampling)
|
|
532
623
|
|
|
533
624
|
self._run_compilation(
|
|
534
|
-
compilation_name,
|
|
625
|
+
f"worker{self.runner.rank} {compilation_name}",
|
|
535
626
|
self.runner.rejection_sampler,
|
|
536
627
|
draft_token_ids,
|
|
537
628
|
num_draft_tokens,
|
|
@@ -601,6 +692,7 @@ class CompilationManager:
|
|
|
601
692
|
self._run_compilation(
|
|
602
693
|
"eagle3_get_draft_token_ids",
|
|
603
694
|
self.runner.drafter._get_draft_token_ids,
|
|
695
|
+
self.runner.drafter.state,
|
|
604
696
|
hidden_states,
|
|
605
697
|
num_logits=num_logits,
|
|
606
698
|
)
|
|
@@ -645,9 +737,9 @@ class CompilationManager:
|
|
|
645
737
|
num_reqs,
|
|
646
738
|
):
|
|
647
739
|
target_hidden_states, input_ids, last_token_indices, _ = self.runner.drafter._filter_token_and_prepare_initial_inputs(
|
|
648
|
-
token_indices, query_start_loc,
|
|
649
|
-
aux_hidden_states, attention_metadata,
|
|
650
|
-
num_reqs)
|
|
740
|
+
self.runner.drafter.state, token_indices, query_start_loc,
|
|
741
|
+
seq_lens, input_ids, aux_hidden_states, attention_metadata,
|
|
742
|
+
next_token_ids, num_reqs)
|
|
651
743
|
return target_hidden_states, input_ids, last_token_indices
|
|
652
744
|
|
|
653
745
|
input_ids = self._create_dummy_tensor((num_tokens, ), jnp.int32)
|
|
@@ -724,6 +816,7 @@ class CompilationManager:
|
|
|
724
816
|
self._run_compilation(
|
|
725
817
|
"eagle3_prepare_hidden_states_and_input_ids",
|
|
726
818
|
self.runner.drafter._prepare_hidden_states_and_input_ids,
|
|
819
|
+
self.runner.drafter.state,
|
|
727
820
|
aux_hidden_states,
|
|
728
821
|
query_start_loc,
|
|
729
822
|
target_token_ids,
|
|
@@ -758,6 +851,7 @@ class CompilationManager:
|
|
|
758
851
|
self._run_compilation(
|
|
759
852
|
"eagle3_select_inputs_for_loop_speculation",
|
|
760
853
|
self.runner.drafter._select_inputs_for_loop_speculation,
|
|
854
|
+
self.runner.drafter.state,
|
|
761
855
|
positions,
|
|
762
856
|
hidden_states,
|
|
763
857
|
hidden_states,
|
|
@@ -768,6 +862,7 @@ class CompilationManager:
|
|
|
768
862
|
self._run_compilation(
|
|
769
863
|
"eagle3_select_draft_token_ids",
|
|
770
864
|
self.runner.drafter._select_draft_token_ids,
|
|
865
|
+
self.runner.drafter.state,
|
|
771
866
|
hidden_states,
|
|
772
867
|
last_token_indices,
|
|
773
868
|
num_tokens=num_tokens,
|
tpu_inference/runner/kv_cache.py
CHANGED
|
@@ -1,3 +1,17 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
1
15
|
from typing import Any, List
|
|
2
16
|
|
|
3
17
|
import jax
|
|
@@ -7,6 +21,7 @@ from jax._src import dtypes
|
|
|
7
21
|
from jax.sharding import Mesh, NamedSharding, PartitionSpec
|
|
8
22
|
from torchax.ops.mappings import t2j_dtype
|
|
9
23
|
|
|
24
|
+
import tpu_inference.kernels.mla.v1.kernel as mla
|
|
10
25
|
import tpu_inference.kernels.ragged_paged_attention.v3.kernel as rpa
|
|
11
26
|
import tpu_inference.kernels.ragged_paged_attention.v3.kernel_hd64 as rpa_hd64
|
|
12
27
|
from tpu_inference.layers.common.sharding import ShardingAxisName
|
|
@@ -17,9 +32,13 @@ logger = init_logger(__name__)
|
|
|
17
32
|
DEFAULT_KV_CACHE_DTYPE = jnp.bfloat16
|
|
18
33
|
|
|
19
34
|
|
|
20
|
-
def get_kv_cache_shape_with_mesh(mesh: Mesh,
|
|
21
|
-
|
|
22
|
-
|
|
35
|
+
def get_kv_cache_shape_with_mesh(mesh: Mesh,
|
|
36
|
+
total_num_pages: int,
|
|
37
|
+
page_size: int,
|
|
38
|
+
actual_num_kv_heads: int,
|
|
39
|
+
actual_head_dim: int,
|
|
40
|
+
kv_dtype: any,
|
|
41
|
+
use_mla: bool = False):
|
|
23
42
|
"""Gets the KV cache shape based on the mesh configuration."""
|
|
24
43
|
|
|
25
44
|
model_cnt = mesh.shape["model"]
|
|
@@ -28,15 +47,21 @@ def get_kv_cache_shape_with_mesh(mesh: Mesh, total_num_pages: int,
|
|
|
28
47
|
# specific model, rather than being determined by the head_dim. If new
|
|
29
48
|
# models are introduced with a head_dim of 64, this will require additional
|
|
30
49
|
# model-specific adjustments.
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
50
|
+
if use_mla:
|
|
51
|
+
get_kv_cache_shape_fn = mla.get_kv_cache_shape
|
|
52
|
+
shape = list(
|
|
53
|
+
get_kv_cache_shape_fn(total_num_pages, page_size, actual_head_dim,
|
|
54
|
+
kv_dtype))
|
|
55
|
+
else:
|
|
56
|
+
get_kv_cache_shape_fn = (
|
|
57
|
+
rpa_hd64.get_kv_cache_shape if actual_head_dim == 64 \
|
|
58
|
+
else rpa.get_kv_cache_shape
|
|
59
|
+
)
|
|
60
|
+
shape = list(
|
|
61
|
+
get_kv_cache_shape_fn(total_num_pages, page_size,
|
|
62
|
+
actual_num_kv_heads // model_cnt,
|
|
63
|
+
actual_head_dim, kv_dtype))
|
|
64
|
+
shape[2] *= model_cnt
|
|
40
65
|
return tuple(shape)
|
|
41
66
|
|
|
42
67
|
|
|
@@ -48,6 +73,7 @@ def create_kv_caches(
|
|
|
48
73
|
mesh: Mesh,
|
|
49
74
|
layer_names: List[str],
|
|
50
75
|
cache_dtype: jnp.dtype = DEFAULT_KV_CACHE_DTYPE,
|
|
76
|
+
use_mla: bool = False,
|
|
51
77
|
) -> List[jax.Array]:
|
|
52
78
|
"""
|
|
53
79
|
Creates a list of KV caches where each array maps to a single attention layer.
|
|
@@ -74,12 +100,16 @@ def create_kv_caches(
|
|
|
74
100
|
|
|
75
101
|
cache_shape = get_kv_cache_shape_with_mesh(mesh, num_blocks, block_size,
|
|
76
102
|
num_kv_heads, head_size,
|
|
77
|
-
cache_dtype)
|
|
103
|
+
cache_dtype, use_mla)
|
|
78
104
|
|
|
79
|
-
|
|
80
|
-
mesh,
|
|
81
|
-
|
|
82
|
-
|
|
105
|
+
if use_mla:
|
|
106
|
+
sharding = NamedSharding(mesh,
|
|
107
|
+
PartitionSpec(ShardingAxisName.MLP_TENSOR))
|
|
108
|
+
else:
|
|
109
|
+
sharding = NamedSharding(
|
|
110
|
+
mesh,
|
|
111
|
+
PartitionSpec(ShardingAxisName.ATTN_DATA, None,
|
|
112
|
+
ShardingAxisName.ATTN_HEAD))
|
|
83
113
|
|
|
84
114
|
def _allocate() -> jax.Array:
|
|
85
115
|
return jnp.empty(
|
|
@@ -94,7 +124,8 @@ def create_kv_caches(
|
|
|
94
124
|
return kv_caches
|
|
95
125
|
|
|
96
126
|
|
|
97
|
-
def
|
|
127
|
+
def get_attention_page_size_bytes(mesh: Mesh,
|
|
128
|
+
kv_cache_specs: dict[str, Any]) -> int:
|
|
98
129
|
"""
|
|
99
130
|
Calculate KV cache page size of RPA kernel.
|
|
100
131
|
|
|
@@ -107,14 +138,16 @@ def get_rpa_page_size_bytes(mesh: Mesh, kv_cache_specs: dict[str, Any]) -> int:
|
|
|
107
138
|
"""
|
|
108
139
|
|
|
109
140
|
# Import it here to avoid circular import.
|
|
110
|
-
from vllm.v1.kv_cache_interface import AttentionSpec
|
|
141
|
+
from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec
|
|
111
142
|
|
|
112
143
|
page_size_bytes_set = set()
|
|
113
144
|
for kv_cache_spec in kv_cache_specs.values():
|
|
114
145
|
assert isinstance(kv_cache_spec, AttentionSpec)
|
|
115
146
|
|
|
116
147
|
dtype = t2j_dtype(kv_cache_spec.dtype)
|
|
117
|
-
bits = dtypes.bit_width(dtype)
|
|
148
|
+
bits = (dtypes.bit_width(dtype) if hasattr(dtypes, "bit_width") else
|
|
149
|
+
dtypes.itemsize_bits(dtype))
|
|
150
|
+
use_mla = isinstance(kv_cache_spec, MLAAttentionSpec)
|
|
118
151
|
|
|
119
152
|
kv_cache_shape = get_kv_cache_shape_with_mesh(
|
|
120
153
|
mesh=mesh,
|
|
@@ -123,6 +156,7 @@ def get_rpa_page_size_bytes(mesh: Mesh, kv_cache_specs: dict[str, Any]) -> int:
|
|
|
123
156
|
actual_num_kv_heads=kv_cache_spec.num_kv_heads,
|
|
124
157
|
actual_head_dim=kv_cache_spec.head_size,
|
|
125
158
|
kv_dtype=dtype,
|
|
159
|
+
use_mla=use_mla,
|
|
126
160
|
)
|
|
127
161
|
page_size_bytes = (bits * np.prod(kv_cache_shape)) // 8
|
|
128
162
|
page_size_bytes_set.add(page_size_bytes)
|