tpu-inference 0.11.1.dev202511150811__py3-none-any.whl
This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of tpu-inference might be problematic.
- tests/__init__.py +0 -0
- tests/core/__init__.py +0 -0
- tests/core/test_core_tpu.py +513 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +53 -0
- tests/core/test_dp_scheduler.py +899 -0
- tests/core/test_init.py +49 -0
- tests/kernels/__init__.py +0 -0
- tests/kernels/fused_moe_v1_test.py +105 -0
- tests/kernels/mla_v1_test.py +396 -0
- tests/kernels/quantized_matmul_kernel_test.py +191 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
- tests/lora/__init__.py +0 -0
- tests/lora/conftest.py +32 -0
- tests/lora/test_bgmv.py +43 -0
- tests/lora/test_layers.py +654 -0
- tests/lora/test_lora.py +133 -0
- tests/lora/utils.py +96 -0
- tests/test_base.py +201 -0
- tests/test_envs.py +182 -0
- tests/test_quantization.py +836 -0
- tests/test_tpu_info.py +120 -0
- tests/test_utils.py +236 -0
- tpu_inference/__init__.py +34 -0
- tpu_inference/core/__init__.py +0 -0
- tpu_inference/core/core_tpu.py +786 -0
- tpu_inference/core/disagg_executor.py +118 -0
- tpu_inference/core/disagg_utils.py +51 -0
- tpu_inference/core/sched/__init__.py +0 -0
- tpu_inference/core/sched/dp_scheduler.py +523 -0
- tpu_inference/distributed/__init__.py +0 -0
- tpu_inference/distributed/jax_parallel_state.py +67 -0
- tpu_inference/distributed/tpu_connector.py +728 -0
- tpu_inference/distributed/utils.py +59 -0
- tpu_inference/env_override.py +9 -0
- tpu_inference/envs.py +107 -0
- tpu_inference/executors/__init__.py +0 -0
- tpu_inference/executors/ray_distributed_executor.py +362 -0
- tpu_inference/experimental/__init__.py +0 -0
- tpu_inference/experimental/llama3_jax_stashed.py +258 -0
- tpu_inference/kernels/__init__.py +0 -0
- tpu_inference/kernels/collectives/__init__.py +0 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +0 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/fused_moe/__init__.py +0 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +1035 -0
- tpu_inference/kernels/mla/__init__.py +0 -0
- tpu_inference/kernels/mla/v1/__init__.py +0 -0
- tpu_inference/kernels/mla/v1/kernel.py +1349 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1478 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
- tpu_inference/layers/__init__.py +0 -0
- tpu_inference/layers/common/__init__.py +0 -0
- tpu_inference/layers/common/attention_interface.py +390 -0
- tpu_inference/layers/common/attention_metadata.py +34 -0
- tpu_inference/layers/common/binary_search.py +295 -0
- tpu_inference/layers/common/quant_methods.py +8 -0
- tpu_inference/layers/common/sharding.py +582 -0
- tpu_inference/layers/jax/__init__.py +0 -0
- tpu_inference/layers/jax/attention/__init__.py +0 -0
- tpu_inference/layers/jax/attention/attention.py +255 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
- tpu_inference/layers/jax/base.py +151 -0
- tpu_inference/layers/jax/constants.py +88 -0
- tpu_inference/layers/jax/layers.py +301 -0
- tpu_inference/layers/jax/misc.py +16 -0
- tpu_inference/layers/jax/moe/__init__.py +0 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
- tpu_inference/layers/jax/moe/moe.py +209 -0
- tpu_inference/layers/jax/rope.py +280 -0
- tpu_inference/layers/jax/rope_interface.py +214 -0
- tpu_inference/layers/jax/sample/__init__.py +0 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
- tpu_inference/layers/jax/sample/sampling.py +96 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
- tpu_inference/layers/jax/transformer_block.py +107 -0
- tpu_inference/layers/vllm/__init__.py +0 -0
- tpu_inference/layers/vllm/attention.py +221 -0
- tpu_inference/layers/vllm/fused_moe.py +507 -0
- tpu_inference/layers/vllm/linear_common.py +186 -0
- tpu_inference/layers/vllm/quantization/__init__.py +39 -0
- tpu_inference/layers/vllm/quantization/awq.py +207 -0
- tpu_inference/layers/vllm/quantization/common.py +105 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +386 -0
- tpu_inference/layers/vllm/sharding.py +230 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +0 -0
- tpu_inference/lora/torch_lora_ops.py +103 -0
- tpu_inference/lora/torch_punica_tpu.py +311 -0
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1219 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/__init__.py +0 -0
- tpu_inference/models/common/__init__.py +0 -0
- tpu_inference/models/common/model_loader.py +444 -0
- tpu_inference/models/jax/__init__.py +0 -0
- tpu_inference/models/jax/deepseek_v3.py +868 -0
- tpu_inference/models/jax/gpt_oss.py +492 -0
- tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
- tpu_inference/models/jax/llama3.py +375 -0
- tpu_inference/models/jax/llama4.py +629 -0
- tpu_inference/models/jax/llama_eagle3.py +333 -0
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +375 -0
- tpu_inference/models/jax/qwen2_5_vl.py +1103 -0
- tpu_inference/models/jax/qwen3.py +302 -0
- tpu_inference/models/jax/utils/__init__.py +0 -0
- tpu_inference/models/jax/utils/file_utils.py +96 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +653 -0
- tpu_inference/models/jax/utils/weight_utils.py +529 -0
- tpu_inference/models/vllm/__init__.py +0 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +286 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
- tpu_inference/platforms/__init__.py +2 -0
- tpu_inference/platforms/tpu_platform.py +269 -0
- tpu_inference/runner/__init__.py +0 -0
- tpu_inference/runner/block_table.py +122 -0
- tpu_inference/runner/compilation_manager.py +780 -0
- tpu_inference/runner/input_batch.py +435 -0
- tpu_inference/runner/kv_cache.py +132 -0
- tpu_inference/runner/kv_cache_manager.py +479 -0
- tpu_inference/runner/lora_utils.py +92 -0
- tpu_inference/runner/multimodal_manager.py +217 -0
- tpu_inference/runner/persistent_batch_manager.py +244 -0
- tpu_inference/runner/speculative_decoding_manager.py +248 -0
- tpu_inference/runner/structured_decoding_manager.py +88 -0
- tpu_inference/runner/tpu_runner.py +1620 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +0 -0
- tpu_inference/spec_decode/jax/__init__.py +0 -0
- tpu_inference/spec_decode/jax/eagle3.py +367 -0
- tpu_inference/tpu_info.py +77 -0
- tpu_inference/utils.py +317 -0
- tpu_inference/worker/__init__.py +0 -0
- tpu_inference/worker/tpu_worker.py +321 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/METADATA +107 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/RECORD +179 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/WHEEL +5 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/top_level.txt +2 -0
tpu_inference/runner/compilation_manager.py
@@ -0,0 +1,780 @@
import os
import time
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple

import jax
import jax.numpy as jnp
import numpy as np
import vllm.envs as envs
from jax.sharding import NamedSharding, PartitionSpec

from tpu_inference.core.disagg_utils import is_disagg_enabled
from tpu_inference.layers.common.attention_metadata import AttentionMetadata
from tpu_inference.layers.common.sharding import ShardingAxisName
from tpu_inference.layers.jax.sample.sampling import sample
from tpu_inference.layers.jax.sample.sampling_metadata import \
    TPUSupportedSamplingMetadata
from tpu_inference.logger import init_logger
from tpu_inference.utils import device_array

if TYPE_CHECKING:
    from tpu_inference.runner.tpu_runner import TPUModelRunner

logger = init_logger(__name__)

# Constants for block bucketing in disaggregated utilities
BLOCK_BUCKETS = [1, 2, 4, 8, 16, 32, 64]


class CompilationManager:

    def __init__(self, runner: "TPUModelRunner"):
        self.runner = runner
        if not envs.VLLM_DISABLE_COMPILE_CACHE:
            logger.info("Enabling JAX compile cache.")
            jax.config.update("jax_compilation_cache_dir",
                              envs.VLLM_XLA_CACHE_PATH)

    def _create_dummy_tensor(self,
                             shape: Tuple[int, ...],
                             dtype: Any,
                             sharding: Optional[NamedSharding] = None) -> Any:
        """Helper to create dummy tensors for precompilation."""
        tensor = jnp.ones(shape, dtype=dtype)
        if sharding:
            return device_array(self.runner.mesh, tensor, sharding=sharding)
        return device_array(self.runner.mesh, tensor)

    def _should_skip_padding_combination(self, outer_val: int, inner_val: int,
                                         only_equal: bool) -> bool:
        """Helper to determine if we should skip this padding combination."""
        if only_equal:
            return inner_val != outer_val
        return inner_val > outer_val

    def _run_compilation(self, name: str, fn: Callable, *args,
                         **kwargs) -> None:
        logger.info(f"Precompile {name} --> {kwargs}")
        start = time.perf_counter()
        result = fn(*args)
        if result is not None:
            if isinstance(result, tuple):
                for r in result:
                    r.block_until_ready()
            else:
                result.block_until_ready()
        end = time.perf_counter()
        logger.info("Compilation finished in %.2f [secs].", end - start)

    def capture_model(self) -> None:
        if os.getenv("SKIP_JAX_PRECOMPILE",
                     False) or self.runner.model_config.enforce_eager:
            return
        logger.info("Precompile all the subgraphs with possible input shapes.")

        with self.runner.maybe_setup_dummy_loras(self.runner.lora_config):
            self._precompile_backbone_text_only()
            if self.runner.is_multimodal_model:
                self.runner.precompile_vision_encoder_fn(
                    self._run_compilation, )
                self._precompile_input_embeddings_merger()
                self._precompile_backbone_with_inputs_embeds()
            if self.runner.scheduler_config.async_scheduling:
                self._precompile_substitute_placeholder_token()
            self._precompile_select_from_array()
            self._precompile_compute_logits()
            self._precompile_disagg_utils()
            self._precompile_sampling()
            self._precompile_gather_logprobs()
            self._precompile_structured_decoding()
            if self.runner.speculative_config:
                self._precompile_speculative_decoding()

    def _precompile_input_embeddings_merger(self) -> None:
        for num_tokens in self.runner.num_tokens_paddings:
            hidden_size = self.runner.vllm_config.model_config.get_hidden_size(
            )
            sharding = NamedSharding(self.runner.mesh, PartitionSpec())
            dummy_multimodal_embeddings = self._create_dummy_tensor(
                (num_tokens, hidden_size),
                self.runner.vllm_config.model_config.dtype,
                sharding=sharding)
            dummy_input_ids = self._create_dummy_tensor((num_tokens, ),
                                                        jnp.int32)

            self._run_compilation(
                "input_embeddings_merger",
                self.runner.get_input_embeddings_fn,
                self.runner.state,
                dummy_input_ids,
                dummy_multimodal_embeddings,
                num_tokens=num_tokens,
            )

            self._run_compilation(
                "input_embeddings_merger_text_only",
                self.runner.get_input_embeddings_fn,
                self.runner.state,
                dummy_input_ids,
                None,
                num_tokens=num_tokens,
            )

    def _precompile_backbone_helper(self, name, *, input_ids, positions,
                                    inputs_embeds) -> None:
        num_tokens = None
        if input_ids is not None:
            num_tokens = input_ids.shape[0]
        elif inputs_embeds is not None:
            num_tokens = inputs_embeds.shape[0]
        assert num_tokens is not None

        dp_size = self.runner.vllm_config.sharding_config.total_dp_size
        dp_sharding = NamedSharding(
            self.runner.mesh, PartitionSpec(
                ShardingAxisName.ATTN_DATA, )) if dp_size > 1 else None

        # Keep existing pattern for complex array operations
        block_tables = self.runner.block_table_cpu[:self.runner.max_num_reqs]
        block_tables = block_tables.reshape(-1)
        block_tables = device_array(self.runner.mesh,
                                    block_tables,
                                    sharding=dp_sharding)

        seq_lens = self._create_dummy_tensor((self.runner.max_num_reqs, ),
                                             jnp.int32, dp_sharding)
        query_start_loc = self._create_dummy_tensor(
            (self.runner.max_num_reqs + dp_size, ), jnp.int32, dp_sharding)

        # Keep existing pattern for specific value arrays
        request_distribution = np.array([0, 0, 0] * dp_size, dtype=np.int32)
        request_distribution = device_array(self.runner.mesh,
                                            request_distribution,
                                            sharding=dp_sharding)

        attention_metadata = AttentionMetadata(
            input_positions=positions,
            block_tables=block_tables,
            seq_lens=seq_lens,
            query_start_loc=query_start_loc,
            request_distribution=request_distribution,
        )

        def model_fn_wrapper(
            state,
            kv_caches,
            input_ids,
            attention_metadata,
            inputs_embeds,
            layer_name_to_kvcache_index,
            lora_metadata,
        ):
            kv_caches, hidden_states, _ = self.runner.model_fn(
                state, kv_caches, input_ids, attention_metadata, inputs_embeds,
                layer_name_to_kvcache_index, lora_metadata)
            self.runner.kv_caches = kv_caches
            return hidden_states

        with self.runner.maybe_select_dummy_loras(
                self.runner.lora_config, np.array([num_tokens],
                                                  dtype=np.int32)):
            lora_metadata = self.runner.lora_utils.extract_lora_metadata()
            self._run_compilation(
                name,
                model_fn_wrapper,
                self.runner.state,
                self.runner.kv_caches,
                input_ids,
                attention_metadata,
                inputs_embeds,
                tuple(self.runner.layer_name_to_kvcache_index.items()),
                lora_metadata,
                num_tokens=num_tokens,
            )

    def _precompile_substitute_placeholder_token(self) -> None:
        """Precompiles the token substitution function for all expected input shapes.

        It iterates through all potential padded token lengths
        (`num_tokens_paddings`) and request batch sizes (`num_reqs_paddings`)
        that the scheduler is expected to handle, ensuring a compiled version
        is ready for each combination.
        """

        for num_tokens in self.runner.num_tokens_paddings:
            dp_sharding = NamedSharding(
                self.runner.mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, )
            ) if self.runner.vllm_config.sharding_config.total_dp_size > 1 else None

            for num_reqs in self.runner.num_reqs_paddings:
                padded_token_in_tpu_cur_input_indices = np.zeros(
                    (num_tokens, ), dtype=np.int32)
                padded_token_in_tpu_pre_next_tokens_indices = np.zeros(
                    (num_tokens, ), dtype=jnp.int32)
                (padded_token_in_tpu_cur_input_indices,
                 padded_token_in_tpu_pre_next_tokens_indices) = device_array(
                     self.runner.mesh,
                     (padded_token_in_tpu_cur_input_indices,
                      padded_token_in_tpu_pre_next_tokens_indices))

                input_ids = self._create_dummy_tensor((num_tokens, ),
                                                      jnp.int32, dp_sharding)
                # Need align to the sampling output
                next_tokens = self._create_dummy_tensor(
                    (num_reqs, ),
                    jnp.int32,
                    sharding=dp_sharding,
                )
                placeholder_num = 1
                self._run_compilation(
                    "_substitute_placeholder_token_fn",
                    self.runner._substitute_placeholder_token_fn,
                    input_ids,
                    padded_token_in_tpu_cur_input_indices,
                    padded_token_in_tpu_pre_next_tokens_indices,
                    next_tokens,
                    placeholder_num,
                    num_tokens=num_tokens,
                    num_reqs=num_reqs,
                )

    def _precompile_backbone_text_only(self) -> None:
        for num_tokens in self.runner.num_tokens_paddings:
            dp_sharding = NamedSharding(
                self.runner.mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, )
            ) if self.runner.vllm_config.sharding_config.total_dp_size > 1 else None

            input_ids = self._create_dummy_tensor((num_tokens, ), jnp.int32,
                                                  dp_sharding)
            positions = self._create_dummy_tensor((num_tokens, ), jnp.int32,
                                                  dp_sharding)
            self._precompile_backbone_helper("backbone",
                                             input_ids=input_ids,
                                             positions=positions,
                                             inputs_embeds=None)

    def _precompile_backbone_with_inputs_embeds(self) -> None:
        hidden_size = self.runner.model_config.get_hidden_size()
        dtype = self.runner.model_config.dtype
        for num_tokens in self.runner.num_tokens_paddings:
            inputs_embeds = self._create_dummy_tensor(
                (num_tokens, hidden_size), dtype)
            if self.runner.uses_mrope:
                positions = self._create_dummy_tensor((3, num_tokens),
                                                      jnp.int32)
            else:
                positions = self._create_dummy_tensor((num_tokens, ),
                                                      jnp.int32)
            self._precompile_backbone_helper("backbone with embeds",
                                             input_ids=None,
                                             positions=positions,
                                             inputs_embeds=inputs_embeds)

    def _precompile_select_from_array_helper(
        self,
        name: str,
        source_paddings: List[int],
        indices_paddings: List[int],
        hidden_dim: int,
        input_sharding: Optional[NamedSharding] = None,
        indices_sharding: Optional[NamedSharding] = None,
        only_equal_paddings: bool = False,
        check_should_skip_padding: bool = True,
    ) -> None:
        """Precompile select_from_array operations with various input shape combinations.

        This helper method generates and precompiles the select_from_array function for different
        combinations of array sizes and index counts. The operation being precompiled is
        array[indices] where:
        - array has shape (array_size, hidden_dim)
        - indices has shape (indices_count,)
        - result has shape (indices_count, hidden_dim)

        This is essential for TPU compilation as JAX needs to precompile functions with all
        possible input shapes that will be encountered during runtime.

        Args:
            name: Descriptive name for logging purposes (e.g., "select all logits")
            source_paddings: List of possible sizes for the array being indexed (first dimension)
            indices_paddings: List of possible counts of indices to select
            hidden_dim: Second dimension size of the array (e.g., hidden_size or vocab_size)
            sharding: Optional sharding specification for distributed computation
            only_equal_paddings: If True, only compile when array size equals indices count
            check_should_skip_padding: If True, check whether to skip certain padding combinations to reduce compilation time
        """
        logger.info(f"Compiling select_from_array for {name}.")
        for array_size in source_paddings:
            for indices_count in indices_paddings:
                if check_should_skip_padding and self._should_skip_padding_combination(
                        array_size, indices_count, only_equal_paddings):
                    continue

                input_tensor = self._create_dummy_tensor(
                    (array_size, hidden_dim), jnp.bfloat16, input_sharding)
                indices_to_select = self._create_dummy_tensor(
                    (indices_count, ), jnp.int32, indices_sharding)

                self._run_compilation(
                    f"select_from_array [{name}]",
                    self.runner._select_from_array_fn, input_tensor,
                    indices_to_select, **{
                        "array_size": array_size,
                        "index_size": indices_count
                    })

    def _precompile_select_from_array(self) -> None:
        logger.info("Compiling select_from_array with different input shapes.")
        hsize = self.runner.model_config.get_hidden_size()

        if self.runner.speculative_config:
            index_paddings = self.runner.num_logits_paddings
        else:
            index_paddings = self.runner.num_reqs_paddings
        dp_sharding = NamedSharding(self.runner.mesh,
                                    PartitionSpec(ShardingAxisName.ATTN_DATA))
        dp_size = self.runner.vllm_config.sharding_config.total_dp_size
        self._precompile_select_from_array_helper(
            name="select all logits",
            source_paddings=self.runner.num_tokens_paddings,
            indices_paddings=index_paddings,
            hidden_dim=hsize,
            input_sharding=dp_sharding,
            indices_sharding=dp_sharding if dp_size > 1 else None,
        )

        if self.runner.speculative_config:
            vocab_size = self.runner.model_config.get_vocab_size()
            self._precompile_select_from_array_helper(
                name="select bonus tokens for spec decoding",
                source_paddings=self.runner.num_logits_paddings,
                indices_paddings=self.runner.num_reqs_paddings,
                hidden_dim=vocab_size,
                input_sharding=NamedSharding(self.runner.mesh,
                                             PartitionSpec(None, "model")),
            )
            self._precompile_select_from_array_helper(
                name="select target tokens for spec decoding",
                source_paddings=self.runner.num_logits_paddings,
                indices_paddings=self.runner.num_logits_paddings,
                hidden_dim=vocab_size,
                input_sharding=NamedSharding(self.runner.mesh,
                                             PartitionSpec(None, "model")),
                only_equal_paddings=True,
            )

    def _precompile_compute_logits(self) -> None:
        logger.info("Compiling compute_logits with different input shapes.")
        hsize = self.runner.model_config.get_hidden_size()
        leading_shape = self.runner.num_reqs_paddings if not self.runner.speculative_config else self.runner.num_logits_paddings
        dp_sharding = NamedSharding(self.runner.mesh,
                                    PartitionSpec(ShardingAxisName.ATTN_DATA))
        for num_reqs in leading_shape:
            hidden_states = self._create_dummy_tensor(
                (num_reqs, hsize), jnp.bfloat16, dp_sharding)
            with self.runner.maybe_select_dummy_loras(
                    self.runner.lora_config,
                    np.array([num_reqs], dtype=np.int32)):
                lora_metadata = self.runner.lora_utils.extract_lora_metadata()
                self._run_compilation(
                    "compute_logits",
                    self.runner.compute_logits_fn,
                    self.runner.state,
                    hidden_states,
                    lora_metadata,
                    num_reqs=num_reqs,
                )

    def _precompile_sampling(self) -> None:
        logger.info("Compiling sampling with different input shapes.")
        hsize = self.runner.model_config.get_vocab_size()
        for num_reqs in self.runner.num_reqs_paddings:
            logits_sharding = NamedSharding(
                self.runner.mesh,
                PartitionSpec(ShardingAxisName.ATTN_DATA, "model"))
            dp_size = self.runner.vllm_config.sharding_config.total_dp_size
            sampling_metadata_sharding = NamedSharding(
                self.runner.mesh, PartitionSpec(
                    ShardingAxisName.ATTN_DATA)) if dp_size > 1 else None
            logits = self._create_dummy_tensor((num_reqs, hsize), jnp.bfloat16,
                                               logits_sharding)
            for do_sampling in (True, False):
                if do_sampling:
                    temperature = np.full((num_reqs, ), 0.7, dtype=np.float32)
                    top_k = np.full((num_reqs, ), 20, dtype=np.int32)
                    top_p = np.full((num_reqs, ), 0.8, dtype=np.float32)
                    (temperature, top_k,
                     top_p) = device_array(self.runner.mesh,
                                           (temperature, top_k, top_p),
                                           sharding=sampling_metadata_sharding)
                else:
                    temperature = None
                    top_k = None
                    top_p = None

                sampling_metadata = TPUSupportedSamplingMetadata(
                    temperature=temperature,
                    top_k=top_k,
                    top_p=top_p,
                    do_sampling=do_sampling,
                )
                self._run_compilation(
                    "sample",
                    sample,
                    self.runner.rng_params_for_sampling,
                    self.runner.mesh,
                    logits,
                    sampling_metadata,
                    num_reqs=num_reqs,
                    do_sampling=do_sampling,
                )

    def _precompile_disagg_utils(self) -> None:
        if not is_disagg_enabled():
            return
        logger.info(
            "Compiling disaggregated util with different input shapes.")
        block_size = self.runner.block_size
        for num_blocks in range(1, self.runner.max_num_blocks_per_req // 2):
            logger.info(
                f"Precompile slice and insert for num_blocks {num_blocks}")
            block_numbers = list(range(1, num_blocks + 1))
            kv_cache_slices = self.runner.kv_cache_manager.get_kv_cache_for_block_ids(
                block_numbers)
            # Prevent the slices from getting freed by insert before finishing this operation
            for layer_cache in kv_cache_slices:
                layer_cache.block_until_ready()
            self.runner.kv_caches = self.runner.kv_cache_manager._jitted_insert_continuous_kv_cache(
                block_size,
                self.runner.kv_caches,
                kv_cache_slices,
                block_numbers[0],
            )
            for layer_cache in self.runner.kv_caches:
                layer_cache.block_until_ready()

    def _precompile_gather_logprobs(self) -> None:
        logger.info("Compiling gather_logprobs with different input shapes.")
        hsize = self.runner.model_config.get_vocab_size()
        for num_reqs in self.runner.num_reqs_paddings:
            logits = self._create_dummy_tensor((num_reqs, hsize), jnp.bfloat16)
            token_ids = self._create_dummy_tensor((num_reqs, ), jnp.int32)
            self._run_compilation(
                "gather_logprobs",
                self.runner._compute_and_gather_logprobs,
                logits,
                token_ids,
                self.runner.model_config.max_logprobs,
                num_reqs=num_reqs,
            )

    def _precompile_speculative_decoding(self) -> None:
        logger.info(
            "Compiling speculative_decoding with different input shapes.")
        self._precompile_rejection_sampler()
        if self.runner.speculative_config.method == "eagle3":
            self._precompile_eagle3_helpers()

    def _precompile_rejection_sampler(self) -> None:
        logger.info("Compiling rejection_sampler with different input shapes.")
        vocab_size = self.runner.model_config.get_vocab_size()
        for num_logits in self.runner.num_logits_paddings:
            for num_reqs in self.runner.num_reqs_paddings:
                sharding = NamedSharding(self.runner.mesh,
                                         PartitionSpec(None, "model"))
                target_probs = self._create_dummy_tensor(
                    (num_logits, vocab_size), jnp.bfloat16, sharding)
                draft_token_ids = self._create_dummy_tensor((num_logits, ),
                                                            jnp.int32)
                num_draft_tokens = self._create_dummy_tensor((num_reqs, ),
                                                             jnp.int32)
                bonus_token_ids = self._create_dummy_tensor((num_reqs, ),
                                                            jnp.int32)

                for do_sampling in (False, True):
                    draft_probs = None
                    if do_sampling:
                        compilation_name = "random_rejection_sampler"
                        temperature = self._create_dummy_tensor((num_reqs, ),
                                                                np.float32)
                        top_k = self._create_dummy_tensor((num_reqs, ),
                                                          np.int32)
                        top_p = self._create_dummy_tensor((num_reqs, ),
                                                          np.float32)
                        sampling_metadata = TPUSupportedSamplingMetadata(
                            temperature=temperature,
                            top_k=top_k,
                            top_p=top_p,
                            do_sampling=do_sampling)
                    else:
                        compilation_name = "greedy_rejection_sampler"
                        sampling_metadata = TPUSupportedSamplingMetadata(
                            do_sampling=do_sampling)

                    self._run_compilation(
                        compilation_name,
                        self.runner.rejection_sampler,
                        draft_token_ids,
                        num_draft_tokens,
                        draft_probs,
                        target_probs,
                        bonus_token_ids,
                        sampling_metadata,
                        self.runner.rng_params_for_sampling,
                        num_logits=num_logits,
                        num_reqs=num_reqs,
                        do_sampling=do_sampling,
                    )

    def _precompile_eagle3_helpers(self) -> None:
        logger.info(
            "Compiling eagle3 jitted helpers with different input shapes.")
        hidden_size = self.runner.model_config.get_hidden_size()
        dtype = self.runner.model_config.dtype

        num_kv_cache_groups = len(self.runner.kv_cache_config.kv_cache_groups)
        draft_kv_cache_group_id = num_kv_cache_groups - 1
        block_tables = self.runner.input_batch.block_table[
            draft_kv_cache_group_id].get_cpu_tensor().reshape(-1)
        block_tables = jax.device_put(
            block_tables, NamedSharding(self.runner.mesh,
                                        PartitionSpec(None, )))

        selected_positions = self._create_dummy_tensor(
            (self.runner.max_num_reqs, ), jnp.int32)
        seq_lens = self._create_dummy_tensor((self.runner.max_num_reqs, ),
                                             jnp.int32)
        query_start_loc = self._create_dummy_tensor(
            (self.runner.max_num_reqs + 1, ), jnp.int32)
        self._run_compilation(
            "_update_inputs_for_loop_speculation for the first loop",
            self.runner.drafter._update_inputs_for_loop_speculation,
            selected_positions, seq_lens, block_tables)
        self._run_compilation(
            "_update_inputs_for_loop_speculation for the subsequent loops",
            self.runner.drafter._update_inputs_for_loop_speculation,
            selected_positions, seq_lens, block_tables)

        request_distribution = np.array([0, 0, 0], dtype=np.int32)
        request_distribution = device_array(self.runner.mesh,
                                            request_distribution)

        for num_reqs_padding in self.runner.num_reqs_paddings:
            for i in range(1, self.runner.drafter.num_speculative_tokens + 1):
                draft_token_ids_list = [
                    self._create_dummy_tensor(
                        (num_reqs_padding, ), jnp.int32,
                        NamedSharding(self.runner.mesh, PartitionSpec()))
                    for _ in range(i)
                ]
                self._run_compilation(
                    "eagle3_stack_draft_token_ids",
                    self.runner.drafter._stack_draft_token_ids,
                    draft_token_ids_list,
                    num_reqs=num_reqs_padding,
                    draft_token_ids_list_length=len(draft_token_ids_list))

        for num_logits in self.runner.num_logits_paddings:
            hidden_states = self._create_dummy_tensor(
                (num_logits, hidden_size), jnp.bfloat16)
            self._run_compilation(
                "eagle3_get_draft_token_ids",
                self.runner.drafter._get_draft_token_ids,
                hidden_states,
                num_logits=num_logits,
            )

        input_ids_loop = self._create_dummy_tensor(
            (self.runner.max_num_reqs, ), jnp.int32,
            NamedSharding(self.runner.mesh, PartitionSpec()))
        target_hidden_state_loop = self._create_dummy_tensor(
            (self.runner.max_num_reqs, hidden_size), dtype,
            NamedSharding(self.runner.mesh, PartitionSpec(None, None)))
        next_token_ids = self._create_dummy_tensor(
            (self.runner.max_num_reqs, ), jnp.int32)
        last_token_indices = self._create_dummy_tensor(
            (self.runner.max_num_reqs, ), jnp.int32)
        for num_tokens in self.runner.num_tokens_paddings:
            aux_hidden_states = [
                self._create_dummy_tensor((num_tokens, hidden_size), dtype),
                self._create_dummy_tensor((num_tokens, hidden_size), dtype),
                self._create_dummy_tensor((num_tokens, hidden_size), dtype),
            ]

            positions = self._create_dummy_tensor((num_tokens, ), jnp.int32)
            attention_metadata = AttentionMetadata(
                input_positions=positions,
                block_tables=block_tables,
                seq_lens=seq_lens,
                query_start_loc=query_start_loc,
                request_distribution=request_distribution,
            )

            def filter_token_and_prepare_initial_inputs_wrapper(
                token_indices,
                query_start_loc,
                seq_lens,
                input_ids,
                aux_hidden_states,
                attention_metadata,
                next_token_ids,
                num_reqs,
            ):
                target_hidden_states, input_ids, last_token_indices, _ = self.runner.drafter._filter_token_and_prepare_initial_inputs(
                    token_indices, query_start_loc, seq_lens, input_ids,
                    aux_hidden_states, attention_metadata, next_token_ids,
                    num_reqs)
                return target_hidden_states, input_ids, last_token_indices

            input_ids = self._create_dummy_tensor((num_tokens, ), jnp.int32)
            aux_hidden_states = [
                self._create_dummy_tensor(
                    (num_tokens, hidden_size), jnp.bfloat16,
                    NamedSharding(self.runner.mesh, PartitionSpec(None,
                                                                  None))),
                self._create_dummy_tensor(
                    (num_tokens, hidden_size), jnp.bfloat16,
                    NamedSharding(self.runner.mesh, PartitionSpec(None,
                                                                  None))),
                self._create_dummy_tensor(
                    (num_tokens, hidden_size), jnp.bfloat16,
                    NamedSharding(self.runner.mesh, PartitionSpec(None,
                                                                  None))),
            ]
            # TODO(ranlihao): This will increase the precompilation latency. Find proper range for token_indices.
            for padded_total_num_tokens in [
                    num_tokens,
                    min(num_tokens * 2, self.runner.num_tokens_paddings[-1])
            ]:
                token_indices = self._create_dummy_tensor(
                    (padded_total_num_tokens, ), jnp.int32)
                self._run_compilation(
                    "eagle3_filter_token_and_prepare_initial_inputs",
                    filter_token_and_prepare_initial_inputs_wrapper,
                    token_indices,
                    query_start_loc,
                    seq_lens,
                    input_ids,
                    aux_hidden_states,
                    attention_metadata,
                    next_token_ids,
                    device_array(
                        self.runner.mesh,
                        np.asarray([self.runner.input_batch.num_reqs],
                                   dtype=jnp.int32)),
                    num_tokens=num_tokens,
                )

            def draft_model_fn_wrapper(
                state,
                kv_caches,
                input_ids,
                target_hidden_states,
                attention_metadata,
            ):
                kv_caches, hidden_states, _ = self.runner.drafter.model_fn(
                    state, kv_caches, input_ids, target_hidden_states,
                    attention_metadata)
                self.runner.kv_caches = kv_caches
                return hidden_states

            target_hidden_states = self._create_dummy_tensor(
                (num_tokens, hidden_size), dtype,
                NamedSharding(self.runner.mesh, PartitionSpec(None, "model")))
            input_ids = self._create_dummy_tensor(
                (num_tokens, ), jnp.int32,
                NamedSharding(self.runner.mesh, PartitionSpec()))
            self._run_compilation(
                "eagle3_draft_model_fn",
                draft_model_fn_wrapper,
                self.runner.drafter.state,
                self.runner.kv_caches,
                input_ids,
                target_hidden_states,
                attention_metadata,
                num_tokens=num_tokens,
            )
            target_token_ids = self._create_dummy_tensor((num_tokens, ),
                                                         jnp.int32)

            self._run_compilation(
                "eagle3_prepare_hidden_states_and_input_ids",
                self.runner.drafter._prepare_hidden_states_and_input_ids,
                aux_hidden_states,
                query_start_loc,
                target_token_ids,
                next_token_ids,
                device_array(
                    self.runner.mesh,
                    np.asarray([self.runner.input_batch.num_reqs],
                               dtype=jnp.int32)),
                num_tokens=num_tokens,
            )

            attention_metadata.query_start_loc = jax.device_put(
                attention_metadata.query_start_loc,
                NamedSharding(self.runner.mesh, PartitionSpec()))
            attention_metadata.input_positions = self._create_dummy_tensor(
                (self.runner.max_num_reqs, ), jnp.int32)
            self._run_compilation(
                "draft_model_fn in a loop",
                draft_model_fn_wrapper,
                self.runner.drafter.state,
                self.runner.kv_caches,
                input_ids_loop,
                target_hidden_state_loop,
                attention_metadata,
                num_tokens=num_tokens,
            )

            hidden_states = self._create_dummy_tensor(
                (num_tokens, hidden_size), jnp.bfloat16,
                NamedSharding(self.runner.mesh, PartitionSpec(None, None)))

            self._run_compilation(
                "eagle3_select_inputs_for_loop_speculation",
                self.runner.drafter._select_inputs_for_loop_speculation,
                positions,
                hidden_states,
                hidden_states,
                last_token_indices,
                num_tokens=num_tokens,
            )

            self._run_compilation(
                "eagle3_select_draft_token_ids",
                self.runner.drafter._select_draft_token_ids,
                hidden_states,
                last_token_indices,
                num_tokens=num_tokens,
            )

    def _precompile_structured_decoding(self) -> None:
        logger.info(
            "Compiling structured_decoding with different input shapes.")
        if self.runner.vllm_config.sharding_config.total_dp_size > 1:
            logger.warning(
                "Structured decoding precompilation skipped since structured decoding is not supported with DP."
            )
            return
        for num_reqs in self.runner.num_reqs_paddings:
            dummy_logits = self._create_dummy_tensor(
                (num_reqs, self.runner.vocab_size), jnp.bfloat16)
            dummy_require_struct_decoding = self.runner.require_structured_out_cpu[:
                                                                                   num_reqs]
            dummy_grammar_bitmask = self.runner.grammar_bitmask_cpu[:num_reqs]

            (dummy_logits, dummy_require_struct_decoding,
             dummy_grammar_bitmask, arange) = device_array(
                 self.runner.mesh,
                 (dummy_logits, dummy_require_struct_decoding,
                  dummy_grammar_bitmask, self.runner.structured_decode_arange))

            self._run_compilation(
                "structured_decode",
                self.runner.structured_decoding_manager.structured_decode_fn,
                dummy_require_struct_decoding,
                dummy_grammar_bitmask,
                dummy_logits,
                arange,
                num_reqs=num_reqs,
            )