tpu-inference 0.12.0.dev20251222__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (260)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +67 -0
  6. tests/core/test_dp_scheduler.py +724 -0
  7. tests/core/test_init.py +63 -0
  8. tests/distributed/__init__.py +13 -0
  9. tests/distributed/test_distributed_utils.py +120 -0
  10. tests/distributed/test_tpu_connector.py +478 -0
  11. tests/e2e/__init__.py +13 -0
  12. tests/e2e/test_async_scheduler.py +211 -0
  13. tests/e2e/test_data_parallel.py +393 -0
  14. tests/e2e/test_local_disagg.py +257 -0
  15. tests/e2e/test_model_loader.py +268 -0
  16. tests/e2e/test_multi_modal_inference.py +111 -0
  17. tests/e2e/test_pipeline_parallel.py +265 -0
  18. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  19. tests/e2e/test_sampling_params.py +269 -0
  20. tests/e2e/test_speculative_decoding.py +291 -0
  21. tests/e2e/test_structured_decoding.py +46 -0
  22. tests/executors/__init__.py +13 -0
  23. tests/executors/test_ray_distributed_executor.py +199 -0
  24. tests/experimental/__init__.py +13 -0
  25. tests/experimental/test_llama3_jax_stashed.py +208 -0
  26. tests/kernels/__init__.py +13 -0
  27. tests/kernels/collectives/__init__.py +13 -0
  28. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  29. tests/kernels/fused_moe_v1_test.py +388 -0
  30. tests/kernels/gmm_test.py +205 -0
  31. tests/kernels/mla_v1_test.py +498 -0
  32. tests/kernels/quantized_matmul_kernel_test.py +159 -0
  33. tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
  34. tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
  35. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
  36. tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
  37. tests/layers/__init__.py +13 -0
  38. tests/layers/common/__init__.py +13 -0
  39. tests/layers/common/test_attention_interface.py +156 -0
  40. tests/layers/common/test_quantization.py +149 -0
  41. tests/layers/jax/__init__.py +13 -0
  42. tests/layers/jax/attention/__init__.py +13 -0
  43. tests/layers/jax/attention/test_common_attention.py +103 -0
  44. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  45. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  46. tests/layers/jax/moe/__init__.py +13 -0
  47. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  48. tests/layers/jax/sample/__init__.py +13 -0
  49. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  50. tests/layers/jax/sample/test_sampling.py +115 -0
  51. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  52. tests/layers/jax/test_layers.py +155 -0
  53. tests/layers/jax/test_qwix.py +969 -0
  54. tests/layers/jax/test_rope.py +93 -0
  55. tests/layers/jax/test_sharding.py +159 -0
  56. tests/layers/jax/test_transformer_block.py +152 -0
  57. tests/layers/vllm/__init__.py +13 -0
  58. tests/layers/vllm/test_attention.py +363 -0
  59. tests/layers/vllm/test_awq.py +405 -0
  60. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
  62. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
  63. tests/layers/vllm/test_fp8.py +17 -0
  64. tests/layers/vllm/test_mxfp4.py +297 -0
  65. tests/layers/vllm/test_unquantized.py +621 -0
  66. tests/layers/vllm/utils.py +72 -0
  67. tests/lora/__init__.py +13 -0
  68. tests/lora/conftest.py +46 -0
  69. tests/lora/test_bgmv.py +57 -0
  70. tests/lora/test_layers.py +666 -0
  71. tests/lora/test_lora.py +147 -0
  72. tests/lora/test_lora_perf.py +67 -0
  73. tests/lora/utils.py +88 -0
  74. tests/models/__init__.py +13 -0
  75. tests/models/common/__init__.py +13 -0
  76. tests/models/common/test_model_loader.py +455 -0
  77. tests/models/jax/__init__.py +13 -0
  78. tests/models/jax/test_deepseek_v3.py +401 -0
  79. tests/models/jax/test_llama3.py +184 -0
  80. tests/models/jax/test_llama4.py +298 -0
  81. tests/models/jax/test_llama_eagle3.py +197 -0
  82. tests/models/jax/test_llama_guard_4.py +242 -0
  83. tests/models/jax/test_qwen2.py +172 -0
  84. tests/models/jax/test_qwen2_5_vl.py +606 -0
  85. tests/models/jax/test_qwen3.py +169 -0
  86. tests/models/jax/test_weight_loading.py +180 -0
  87. tests/models/jax/utils/__init__.py +13 -0
  88. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  89. tests/platforms/__init__.py +13 -0
  90. tests/platforms/test_tpu_platform.py +54 -0
  91. tests/runner/__init__.py +13 -0
  92. tests/runner/test_block_table.py +395 -0
  93. tests/runner/test_input_batch.py +226 -0
  94. tests/runner/test_kv_cache.py +220 -0
  95. tests/runner/test_kv_cache_manager.py +498 -0
  96. tests/runner/test_multimodal_manager.py +429 -0
  97. tests/runner/test_persistent_batch_manager.py +84 -0
  98. tests/runner/test_speculative_decoding_manager.py +368 -0
  99. tests/runner/test_structured_decoding_manager.py +220 -0
  100. tests/runner/test_tpu_runner.py +202 -0
  101. tests/runner/test_tpu_runner_dp.py +1033 -0
  102. tests/runner/test_tpu_runner_mesh.py +200 -0
  103. tests/runner/test_utils.py +411 -0
  104. tests/spec_decode/__init__.py +13 -0
  105. tests/spec_decode/test_eagle3.py +311 -0
  106. tests/test_base.py +215 -0
  107. tests/test_envs.py +280 -0
  108. tests/test_tpu_info.py +134 -0
  109. tests/test_utils.py +193 -0
  110. tests/worker/__init__.py +13 -0
  111. tests/worker/tpu_worker_test.py +414 -0
  112. tpu_inference/__init__.py +67 -0
  113. tpu_inference/core/__init__.py +13 -0
  114. tpu_inference/core/core_tpu.py +786 -0
  115. tpu_inference/core/disagg_executor.py +118 -0
  116. tpu_inference/core/disagg_utils.py +49 -0
  117. tpu_inference/core/sched/__init__.py +13 -0
  118. tpu_inference/core/sched/dp_scheduler.py +814 -0
  119. tpu_inference/distributed/__init__.py +13 -0
  120. tpu_inference/distributed/jax_parallel_state.py +81 -0
  121. tpu_inference/distributed/tpu_connector.py +732 -0
  122. tpu_inference/distributed/utils.py +112 -0
  123. tpu_inference/env_override.py +9 -0
  124. tpu_inference/envs.py +191 -0
  125. tpu_inference/executors/__init__.py +13 -0
  126. tpu_inference/executors/ray_distributed_executor.py +399 -0
  127. tpu_inference/experimental/__init__.py +13 -0
  128. tpu_inference/experimental/llama3_jax_stashed.py +272 -0
  129. tpu_inference/kernels/__init__.py +13 -0
  130. tpu_inference/kernels/collectives/__init__.py +13 -0
  131. tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
  132. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
  133. tpu_inference/kernels/collectives/util.py +47 -0
  134. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  135. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  136. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  137. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  138. tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
  139. tpu_inference/kernels/megablox/__init__.py +13 -0
  140. tpu_inference/kernels/megablox/common.py +54 -0
  141. tpu_inference/kernels/megablox/gmm.py +646 -0
  142. tpu_inference/kernels/mla/__init__.py +13 -0
  143. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  144. tpu_inference/kernels/mla/v1/kernel.py +1340 -0
  145. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  146. tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
  147. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  148. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  149. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  150. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  151. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
  152. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
  153. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  154. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  155. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
  156. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
  157. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
  158. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
  159. tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
  160. tpu_inference/layers/__init__.py +13 -0
  161. tpu_inference/layers/common/__init__.py +13 -0
  162. tpu_inference/layers/common/attention_interface.py +403 -0
  163. tpu_inference/layers/common/attention_metadata.py +48 -0
  164. tpu_inference/layers/common/binary_search.py +295 -0
  165. tpu_inference/layers/common/quant_methods.py +23 -0
  166. tpu_inference/layers/common/quantization.py +270 -0
  167. tpu_inference/layers/common/sharding.py +600 -0
  168. tpu_inference/layers/jax/__init__.py +13 -0
  169. tpu_inference/layers/jax/attention/__init__.py +13 -0
  170. tpu_inference/layers/jax/attention/attention.py +268 -0
  171. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
  172. tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
  173. tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
  174. tpu_inference/layers/jax/base.py +165 -0
  175. tpu_inference/layers/jax/constants.py +101 -0
  176. tpu_inference/layers/jax/layers.py +315 -0
  177. tpu_inference/layers/jax/misc.py +30 -0
  178. tpu_inference/layers/jax/moe/__init__.py +13 -0
  179. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
  180. tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
  181. tpu_inference/layers/jax/moe/moe.py +249 -0
  182. tpu_inference/layers/jax/pp_utils.py +53 -0
  183. tpu_inference/layers/jax/rope.py +294 -0
  184. tpu_inference/layers/jax/rope_interface.py +228 -0
  185. tpu_inference/layers/jax/sample/__init__.py +13 -0
  186. tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
  187. tpu_inference/layers/jax/sample/sampling.py +110 -0
  188. tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
  189. tpu_inference/layers/jax/transformer_block.py +121 -0
  190. tpu_inference/layers/vllm/__init__.py +13 -0
  191. tpu_inference/layers/vllm/attention.py +221 -0
  192. tpu_inference/layers/vllm/fused_moe.py +502 -0
  193. tpu_inference/layers/vllm/linear_common.py +221 -0
  194. tpu_inference/layers/vllm/quantization/__init__.py +55 -0
  195. tpu_inference/layers/vllm/quantization/awq.py +221 -0
  196. tpu_inference/layers/vllm/quantization/common.py +124 -0
  197. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  198. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
  199. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
  200. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  201. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
  202. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
  203. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  204. tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
  205. tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
  206. tpu_inference/layers/vllm/sharding.py +244 -0
  207. tpu_inference/logger.py +10 -0
  208. tpu_inference/lora/__init__.py +13 -0
  209. tpu_inference/lora/torch_lora_ops.py +98 -0
  210. tpu_inference/lora/torch_punica_tpu.py +310 -0
  211. tpu_inference/models/__init__.py +13 -0
  212. tpu_inference/models/common/__init__.py +13 -0
  213. tpu_inference/models/common/model_loader.py +520 -0
  214. tpu_inference/models/jax/__init__.py +13 -0
  215. tpu_inference/models/jax/deepseek_v3.py +978 -0
  216. tpu_inference/models/jax/gpt_oss.py +508 -0
  217. tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
  218. tpu_inference/models/jax/llama3.py +436 -0
  219. tpu_inference/models/jax/llama4.py +643 -0
  220. tpu_inference/models/jax/llama_eagle3.py +350 -0
  221. tpu_inference/models/jax/llama_guard_4.py +375 -0
  222. tpu_inference/models/jax/qwen2.py +390 -0
  223. tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
  224. tpu_inference/models/jax/qwen3.py +318 -0
  225. tpu_inference/models/jax/utils/__init__.py +13 -0
  226. tpu_inference/models/jax/utils/file_utils.py +110 -0
  227. tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
  228. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  229. tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
  230. tpu_inference/models/jax/utils/weight_utils.py +621 -0
  231. tpu_inference/models/vllm/__init__.py +13 -0
  232. tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
  233. tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
  234. tpu_inference/platforms/__init__.py +16 -0
  235. tpu_inference/platforms/tpu_platform.py +258 -0
  236. tpu_inference/runner/__init__.py +13 -0
  237. tpu_inference/runner/block_table.py +122 -0
  238. tpu_inference/runner/compilation_manager.py +890 -0
  239. tpu_inference/runner/input_batch.py +435 -0
  240. tpu_inference/runner/kv_cache.py +166 -0
  241. tpu_inference/runner/kv_cache_manager.py +508 -0
  242. tpu_inference/runner/lora_utils.py +106 -0
  243. tpu_inference/runner/multimodal_manager.py +231 -0
  244. tpu_inference/runner/persistent_batch_manager.py +296 -0
  245. tpu_inference/runner/speculative_decoding_manager.py +262 -0
  246. tpu_inference/runner/structured_decoding_manager.py +101 -0
  247. tpu_inference/runner/tpu_runner.py +1768 -0
  248. tpu_inference/runner/utils.py +426 -0
  249. tpu_inference/spec_decode/__init__.py +13 -0
  250. tpu_inference/spec_decode/jax/__init__.py +13 -0
  251. tpu_inference/spec_decode/jax/eagle3.py +430 -0
  252. tpu_inference/tpu_info.py +92 -0
  253. tpu_inference/utils.py +345 -0
  254. tpu_inference/worker/__init__.py +13 -0
  255. tpu_inference/worker/tpu_worker.py +468 -0
  256. tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
  257. tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
  258. tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
  259. tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
  260. tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
@@ -0,0 +1,898 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import time
+ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
+
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+ import vllm.envs as vllm_envs
+ from jax.sharding import NamedSharding, PartitionSpec
+
+ import tpu_inference.envs as envs
+ from tpu_inference.core.disagg_utils import is_disagg_enabled
+ from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+ from tpu_inference.layers.common.sharding import ShardingAxisName
+ from tpu_inference.layers.jax.sample.sampling import sample
+ from tpu_inference.layers.jax.sample.sampling_metadata import \
+     TPUSupportedSamplingMetadata
+ from tpu_inference.logger import init_logger
+ from tpu_inference.models.jax.jax_intermediate_tensor import \
+     JaxIntermediateTensors
+ from tpu_inference.utils import device_array
+
+ if TYPE_CHECKING:
+     from tpu_inference.runner.tpu_runner import TPUModelRunner
+
+ logger = init_logger(__name__)
+
+ # Constants for block bucketing in disaggregated utilities
+ BLOCK_BUCKETS = [1, 2, 4, 8, 16, 32, 64]
+
+
+ class CompilationManager:
+
+     def __init__(self, runner: "TPUModelRunner"):
+         self.runner = runner
+         self._sampling_precompiled = False
+         self._gather_logprobs_precompiled = False
+         if not vllm_envs.VLLM_DISABLE_COMPILE_CACHE:
+             logger.info("Enabling JAX compile cache.")
+             jax.config.update("jax_compilation_cache_dir",
+                               vllm_envs.VLLM_XLA_CACHE_PATH)
+
+     def _create_dummy_tensor(self,
+                              shape: Tuple[int, ...],
+                              dtype: Any,
+                              sharding: Optional[NamedSharding] = None) -> Any:
+         """Helper to create dummy tensors for precompilation."""
+         tensor = jnp.ones(shape, dtype=dtype)
+         if sharding:
+             return device_array(self.runner.mesh, tensor, sharding=sharding)
+         return device_array(self.runner.mesh, tensor)
+
+     def _should_skip_padding_combination(self, outer_val: int, inner_val: int,
+                                          only_equal: bool) -> bool:
+         """Helper to determine if we should skip this padding combination."""
+         if only_equal:
+             return inner_val != outer_val
+         return inner_val > outer_val
+
+     def _run_compilation(self, name: str, fn: Callable, *args,
+                          **kwargs) -> None:
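+         # NOTE: **kwargs only label the shape bucket for the log line below;
+         # fn itself is invoked with *args alone.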
+         logger.info(f"Precompile {name} --> {kwargs}")
+         start = time.perf_counter()
+         result = fn(*args)
+         if result is not None:
+             if isinstance(result, tuple):
+                 for r in result:
+                     r.block_until_ready()
+             else:
+                 result.block_until_ready()
+         end = time.perf_counter()
+         logger.info("Compilation finished in %.2f [secs].", end - start)
+
+     def capture_model(self) -> None:
+         if envs.SKIP_JAX_PRECOMPILE or self.runner.model_config.enforce_eager:
+             return
+         logger.info("Precompile all the subgraphs with possible input shapes.")
+
+         with self.runner.maybe_setup_dummy_loras(self.runner.lora_config):
+             self._precompile_backbone_text_only()
+             if self.runner.is_multimodal_model:
+                 self.runner.precompile_vision_encoder_fn(
+                     self._run_compilation, )
+                 self._precompile_input_embeddings_merger()
+                 self._precompile_backbone_with_inputs_embeds()
+             if self.runner.scheduler_config.async_scheduling:
+                 self._precompile_substitute_placeholder_token()
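+             # Logits computation, sampling and the decoding helpers below
+             # only run on the last pipeline-parallel rank.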
+             if not self.runner.is_last_rank:
+                 return
+             self._precompile_select_from_array()
+             self._precompile_compute_logits()
+             # Skip sampling if already precompiled before KV cache allocation
+             if not self._sampling_precompiled:
+                 self._precompile_sampling()
+             self._precompile_disagg_utils()
+             # Skip gather_logprobs if already precompiled before KV cache allocation
+             if not self._gather_logprobs_precompiled:
+                 self._precompile_gather_logprobs()
+             self._precompile_structured_decoding()
+             if self.runner.speculative_config:
+                 self._precompile_speculative_decoding()
+
+     def _precompile_input_embeddings_merger(self) -> None:
+         for num_tokens in self.runner.num_tokens_paddings:
+             hidden_size = self.runner.vllm_config.model_config.get_hidden_size(
+             )
+             sharding = NamedSharding(self.runner.mesh, PartitionSpec())
+             dummy_multimodal_embeddings = self._create_dummy_tensor(
+                 (num_tokens, hidden_size),
+                 self.runner.vllm_config.model_config.dtype,
+                 sharding=sharding)
+             dummy_input_ids = self._create_dummy_tensor((num_tokens, ),
+                                                         jnp.int32)
+
+             self._run_compilation(
+                 "input_embeddings_merger",
+                 self.runner.get_input_embeddings_fn,
+                 self.runner.state,
+                 dummy_input_ids,
+                 dummy_multimodal_embeddings,
+                 num_tokens=num_tokens,
+             )
+
+             self._run_compilation(
+                 "input_embeddings_merger_text_only",
+                 self.runner.get_input_embeddings_fn,
+                 self.runner.state,
+                 dummy_input_ids,
+                 None,
+                 num_tokens=num_tokens,
+             )
+
+     def _precompile_backbone_helper(self,
+                                     name,
+                                     *,
+                                     input_ids,
+                                     positions,
+                                     inputs_embeds,
+                                     intermediate_tensors=None,
+                                     is_first_rank=True,
+                                     is_last_rank=True) -> None:
+         num_tokens = None
+         if input_ids is not None:
+             num_tokens = input_ids.shape[0]
+         elif inputs_embeds is not None:
+             num_tokens = inputs_embeds.shape[0]
+         assert num_tokens is not None
+
+         dp_size = self.runner.vllm_config.sharding_config.total_dp_size
+         dp_sharding = NamedSharding(
+             self.runner.mesh, PartitionSpec(
+                 ShardingAxisName.ATTN_DATA, )) if dp_size > 1 else None
+
+         # Keep existing pattern for complex array operations
+         seq_lens = self._create_dummy_tensor((self.runner.max_num_reqs, ),
+                                              jnp.int32, dp_sharding)
+         query_start_loc = self._create_dummy_tensor(
+             (self.runner.max_num_reqs + dp_size, ), jnp.int32, dp_sharding)
+
+         # Keep existing pattern for specific value arrays
+         request_distribution = np.array([0, 0, 0] * dp_size, dtype=np.int32)
+         request_distribution = device_array(self.runner.mesh,
+                                             request_distribution,
+                                             sharding=dp_sharding)
+
+         attention_metadata_per_layer: Dict[str, AttentionMetadata] = {}
+         uniform_attention_metadata: AttentionMetadata = None
+         for kv_cache_gid, kv_cache_group in enumerate(
+                 self.runner.kv_cache_config.kv_cache_groups):
+             block_tables = self.runner.block_tables_cpu[
+                 kv_cache_gid][:self.runner.max_num_reqs]
+             block_tables = block_tables.reshape(-1)
+             block_tables = device_array(self.runner.mesh,
+                                         block_tables,
+                                         sharding=dp_sharding)
+
+             attention_metadata_gid = AttentionMetadata(
+                 input_positions=positions,
+                 block_tables=block_tables,
+                 seq_lens=seq_lens,
+                 query_start_loc=query_start_loc,
+                 request_distribution=request_distribution,
+             )
+             if not self.runner.use_hybrid_kvcache:
+                 # all layers share the same attention metadata
+                 uniform_attention_metadata = attention_metadata_gid
+             else:
+                 for layer_name in kv_cache_group.layer_names:
+                     attention_metadata_per_layer[
+                         layer_name] = attention_metadata_gid
+
+         def model_fn_wrapper(
+             state,
+             kv_caches,
+             input_ids,
+             attention_metadata,
+             positions,
+             inputs_embeds,
+             layer_name_to_kvcache_index,
+             lora_metadata,
+             intermediate_tensors,
+             is_first_rank,
+             is_last_rank,
+         ):
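+             # Keep the returned KV cache buffers on the runner (model_fn may
+             # donate its inputs) so later precompile calls reuse live arrays.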
+             kv_caches, hidden_states, _ = self.runner.model_fn(
+                 state, kv_caches, input_ids, attention_metadata, inputs_embeds,
+                 positions, layer_name_to_kvcache_index, lora_metadata,
+                 intermediate_tensors, is_first_rank, is_last_rank)
+             self.runner.kv_caches = kv_caches
+             return hidden_states
+
+         with self.runner.maybe_select_dummy_loras(
+                 self.runner.lora_config, np.array([num_tokens],
+                                                   dtype=np.int32)):
+             lora_metadata = self.runner.lora_utils.extract_lora_metadata()
+             if self.runner.use_hybrid_kvcache:
+                 attention_metadata = attention_metadata_per_layer
+             else:
+                 attention_metadata = uniform_attention_metadata
+             self._run_compilation(
+                 name,
+                 model_fn_wrapper,
+                 self.runner.state,
+                 self.runner.kv_caches,
+                 input_ids,
+                 attention_metadata,
+                 positions,
+                 inputs_embeds,
+                 tuple(self.runner.layer_name_to_kvcache_index.items()),
+                 lora_metadata,
+                 intermediate_tensors,
+                 is_first_rank,
+                 is_last_rank,
+                 num_tokens=num_tokens,
+             )
+
+     def _precompile_substitute_placeholder_token(self) -> None:
+         """Precompiles the token substitution function for all expected input shapes.
+
+         It iterates through all potential padded token lengths
+         (`num_tokens_paddings`) and request batch sizes (`num_reqs_paddings`)
+         that the scheduler is expected to handle, ensuring a compiled version
+         is ready for each combination.
+         """
+
+         for num_tokens in self.runner.num_tokens_paddings:
+             dp_sharding = NamedSharding(
+                 self.runner.mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, )
+             ) if self.runner.vllm_config.sharding_config.total_dp_size > 1 else None
+
+             for num_reqs in self.runner.num_reqs_paddings:
+                 padded_token_in_tpu_cur_input_indices = np.zeros(
+                     (num_tokens, ), dtype=np.int32)
+                 padded_token_in_tpu_pre_next_tokens_indices = np.zeros(
+                     (num_tokens, ), dtype=np.int32)
+                 (padded_token_in_tpu_cur_input_indices,
+                  padded_token_in_tpu_pre_next_tokens_indices) = device_array(
+                      self.runner.mesh,
+                      (padded_token_in_tpu_cur_input_indices,
+                       padded_token_in_tpu_pre_next_tokens_indices))
+
+                 input_ids = self._create_dummy_tensor((num_tokens, ),
+                                                       jnp.int32, dp_sharding)
+                 # Needs to align with the sampling output shape.
+                 next_tokens = self._create_dummy_tensor(
+                     (num_reqs, ),
+                     jnp.int32,
+                     sharding=dp_sharding,
+                 )
+                 placeholder_num = 1
+                 self._run_compilation(
+                     "_substitute_placeholder_token_fn",
+                     self.runner._substitute_placeholder_token_fn,
+                     input_ids,
+                     padded_token_in_tpu_cur_input_indices,
+                     padded_token_in_tpu_pre_next_tokens_indices,
+                     next_tokens,
+                     placeholder_num,
+                     num_tokens=num_tokens,
+                     num_reqs=num_reqs,
+                 )
+
+     def _precompile_backbone_text_only(self) -> None:
+         hidden_size = self.runner.model_config.get_hidden_size()
+         for num_tokens in self.runner.num_tokens_paddings:
+             dp_sharding = NamedSharding(
+                 self.runner.mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, )
+             ) if self.runner.vllm_config.sharding_config.total_dp_size > 1 else None
+
+             input_ids = self._create_dummy_tensor((num_tokens, ), jnp.int32,
+                                                   dp_sharding)
+             positions = self._create_dummy_tensor((num_tokens, ), jnp.int32,
+                                                   dp_sharding)
+             is_first_rank = self.runner.is_first_rank
+             is_last_rank = self.runner.is_last_rank
+             if is_first_rank:
+                 intermediate_tensors = None
+             else:
+                 hidden_states = self._create_dummy_tensor(
+                     (num_tokens, hidden_size), jnp.bfloat16)
+                 residual = self._create_dummy_tensor((num_tokens, hidden_size),
+                                                      jnp.bfloat16)
+                 intermediate_tensors = JaxIntermediateTensors(
+                     tensors={
+                         "hidden_states": hidden_states,
+                         "residual": residual
+                     })
+             self._precompile_backbone_helper(
+                 f"worker{self.runner.rank} backbone",
+                 input_ids=input_ids,
+                 positions=positions,
+                 inputs_embeds=None,
+                 intermediate_tensors=intermediate_tensors,
+                 is_first_rank=is_first_rank,
+                 is_last_rank=is_last_rank)
+
+     def _precompile_backbone_with_inputs_embeds(self) -> None:
+         hidden_size = self.runner.model_config.get_hidden_size()
+         dtype = self.runner.model_config.dtype
+         for num_tokens in self.runner.num_tokens_paddings:
+             inputs_embeds = self._create_dummy_tensor(
+                 (num_tokens, hidden_size), dtype)
+             if self.runner.uses_mrope:
+                 positions = self._create_dummy_tensor((3, num_tokens),
+                                                       jnp.int32)
+             else:
+                 positions = self._create_dummy_tensor((num_tokens, ),
+                                                       jnp.int32)
+             is_first_rank = self.runner.is_first_rank
+             is_last_rank = self.runner.is_last_rank
+             if not is_first_rank:
+                 hidden_states = self._create_dummy_tensor(
+                     (num_tokens, hidden_size), jnp.bfloat16)
+                 residual = self._create_dummy_tensor((num_tokens, hidden_size),
+                                                      jnp.bfloat16)
+                 intermediate_tensors = JaxIntermediateTensors(
+                     tensors={
+                         "hidden_states": hidden_states,
+                         "residual": residual
+                     })
+             else:
+                 intermediate_tensors = None
+             self._precompile_backbone_helper(
+                 f"worker{self.runner.rank} backbone with embeds",
+                 input_ids=None,
+                 positions=positions,
+                 inputs_embeds=inputs_embeds,
+                 intermediate_tensors=intermediate_tensors,
+                 is_first_rank=is_first_rank,
+                 is_last_rank=is_last_rank)
+
+     def _precompile_select_from_array_helper(
+         self,
+         name: str,
+         source_paddings: List[int],
+         indices_paddings: List[int],
+         hidden_dim: int,
+         input_sharding: Optional[NamedSharding] = None,
+         indices_sharding: Optional[NamedSharding] = None,
+         only_equal_paddings: bool = False,
+         check_should_skip_padding: bool = True,
+     ) -> None:
+         """Precompile select_from_array operations with various input shape combinations.
+
+         This helper method generates and precompiles the select_from_array function for different
+         combinations of array sizes and index counts. The operation being precompiled is
+         array[indices] where:
+         - array has shape (array_size, hidden_dim)
+         - indices has shape (indices_count,)
+         - result has shape (indices_count, hidden_dim)
+
+         This is essential for TPU compilation as JAX needs to precompile functions with all
+         possible input shapes that will be encountered during runtime.
+
+         Args:
+             name: Descriptive name for logging purposes (e.g., "select all logits")
+             source_paddings: List of possible sizes for the array being indexed (first dimension)
+             indices_paddings: List of possible counts of indices to select
+             hidden_dim: Second dimension size of the array (e.g., hidden_size or vocab_size)
+             input_sharding / indices_sharding: Optional sharding specifications for distributed computation
+             only_equal_paddings: If True, only compile when array size equals indices count
+             check_should_skip_padding: If True, check whether to skip certain padding combinations to reduce compilation time
+         """
+         logger.info(f"Compiling select_from_array for {name}.")
+         for array_size in source_paddings:
+             for indices_count in indices_paddings:
+                 if check_should_skip_padding and self._should_skip_padding_combination(
+                         array_size, indices_count, only_equal_paddings):
+                     continue
+
+                 input_tensor = self._create_dummy_tensor(
+                     (array_size, hidden_dim), jnp.bfloat16, input_sharding)
+                 indices_to_select = self._create_dummy_tensor(
+                     (indices_count, ), jnp.int32, indices_sharding)
+
+                 self._run_compilation(
+                     f"select_from_array [{name}]",
+                     self.runner._select_from_array_fn, input_tensor,
+                     indices_to_select, **{
+                         "array_size": array_size,
+                         "index_size": indices_count
+                     })
+
+     def _precompile_select_from_array(self) -> None:
+         logger.info("Compiling select_from_array with different input shapes.")
+         hsize = self.runner.model_config.get_hidden_size()
+
+         if self.runner.speculative_config:
+             index_paddings = self.runner.num_logits_paddings
+         else:
+             index_paddings = self.runner.num_reqs_paddings
+         dp_sharding = NamedSharding(self.runner.mesh,
+                                     PartitionSpec(ShardingAxisName.ATTN_DATA))
+         hidden_states_sharding = NamedSharding(
+             self.runner.mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, None))
+         dp_size = self.runner.vllm_config.sharding_config.total_dp_size
+         self._precompile_select_from_array_helper(
+             name=f"worker{self.runner.rank} select all logits",
+             source_paddings=self.runner.num_tokens_paddings,
+             indices_paddings=index_paddings,
+             hidden_dim=hsize,
+             input_sharding=hidden_states_sharding,
+             indices_sharding=dp_sharding if dp_size > 1 else None,
+         )
+
+         if self.runner.speculative_config:
+             vocab_size = self.runner.model_config.get_vocab_size()
+             self._precompile_select_from_array_helper(
+                 name=
+                 f"worker{self.runner.rank} select bonus tokens for spec decoding",
+                 source_paddings=self.runner.num_logits_paddings,
+                 indices_paddings=self.runner.num_reqs_paddings,
+                 hidden_dim=vocab_size,
+                 input_sharding=NamedSharding(self.runner.mesh,
+                                              PartitionSpec(None, "model")),
+             )
+             self._precompile_select_from_array_helper(
+                 name=
+                 f"worker{self.runner.rank} select target tokens for spec decoding",
+                 source_paddings=self.runner.num_logits_paddings,
+                 indices_paddings=self.runner.num_logits_paddings,
+                 hidden_dim=vocab_size,
+                 input_sharding=NamedSharding(self.runner.mesh,
+                                              PartitionSpec(None, "model")),
+                 only_equal_paddings=True,
+             )
+
+     def _precompile_compute_logits(self) -> None:
+         logger.info("Compiling compute_logits with different input shapes.")
+         hsize = self.runner.model_config.get_hidden_size()
+         leading_shape = self.runner.num_reqs_paddings if not self.runner.speculative_config else self.runner.num_logits_paddings
+         dp_sharding = NamedSharding(self.runner.mesh,
+                                     PartitionSpec(ShardingAxisName.ATTN_DATA))
+         for num_reqs in leading_shape:
+             hidden_states = self._create_dummy_tensor(
+                 (num_reqs, hsize), jnp.bfloat16, dp_sharding)
+             with self.runner.maybe_select_dummy_loras(
+                     self.runner.lora_config,
+                     np.array([num_reqs], dtype=np.int32)):
+                 lora_metadata = self.runner.lora_utils.extract_lora_metadata()
+                 self._run_compilation(
+                     f"worker{self.runner.rank} compute_logits",
+                     self.runner.compute_logits_fn,
+                     self.runner.state,
+                     hidden_states,
+                     lora_metadata,
+                     num_reqs=num_reqs,
+                 )
+
+     def _precompile_sampling(self) -> None:
+         logger.info("Compiling sampling with different input shapes.")
+         hsize = self.runner.model_config.get_vocab_size()
+         for num_reqs in self.runner.num_reqs_paddings:
+             logits_sharding = NamedSharding(
+                 self.runner.mesh,
+                 PartitionSpec(ShardingAxisName.MLP_DATA,
+                               ShardingAxisName.MLP_TENSOR))
+             dp_size = self.runner.vllm_config.sharding_config.total_dp_size
+             sampling_metadata_sharding = NamedSharding(
+                 self.runner.mesh, PartitionSpec(
+                     ShardingAxisName.MLP_DATA)) if dp_size > 1 else None
+             logits = self._create_dummy_tensor((num_reqs, hsize), jnp.bfloat16,
+                                                logits_sharding)
+             for do_sampling in (True, False):
+                 if do_sampling:
+                     temperature = np.full((num_reqs, ), 0.7, dtype=np.float32)
+                     top_k = np.full((num_reqs, ), 20, dtype=np.int32)
+                     top_p = np.full((num_reqs, ), 0.8, dtype=np.float32)
+                     (temperature, top_k,
+                      top_p) = device_array(self.runner.mesh,
+                                            (temperature, top_k, top_p),
+                                            sharding=sampling_metadata_sharding)
+                 else:
+                     temperature = None
+                     top_k = None
+                     top_p = None
+
+                 sampling_metadata = TPUSupportedSamplingMetadata(
+                     temperature=temperature,
+                     top_k=top_k,
+                     top_p=top_p,
+                     do_sampling=do_sampling,
+                 )
+                 self._run_compilation(
+                     f"worker{self.runner.rank} sample",
+                     sample,
+                     self.runner.rng_params_for_sampling,
+                     self.runner.mesh,
+                     logits,
+                     sampling_metadata,
+                     num_reqs=num_reqs,
+                     do_sampling=do_sampling,
+                 )
+
+         self._sampling_precompiled = True
+
+     def _precompile_disagg_utils(self) -> None:
+         if not is_disagg_enabled():
+             return
+         logger.info(
+             "Compiling disaggregated utils with different input shapes.")
+         block_size = self.runner.block_size
+         for num_blocks in range(1, self.runner.max_num_blocks_per_req // 2):
+             logger.info(
+                 f"Precompile slice and insert for num_blocks {num_blocks}")
+             block_numbers = list(range(1, num_blocks + 1))
+             kv_cache_slices = self.runner.kv_cache_manager.get_kv_cache_for_block_ids(
+                 block_numbers)
+             # Prevent the slices from being freed by the insert before this operation finishes
+             for layer_cache in kv_cache_slices:
+                 layer_cache.block_until_ready()
+             self.runner.kv_caches = self.runner.kv_cache_manager._jitted_insert_continuous_kv_cache(
+                 block_size,
+                 self.runner.kv_caches,
+                 kv_cache_slices,
+                 block_numbers[0],
+             )
+             for layer_cache in self.runner.kv_caches:
+                 layer_cache.block_until_ready()
+
+     def _precompile_gather_logprobs(self) -> None:
+         logger.info("Compiling gather_logprobs with different input shapes.")
+         hsize = self.runner.model_config.get_vocab_size()
+         for num_reqs in self.runner.num_reqs_paddings:
+             logits = self._create_dummy_tensor((num_reqs, hsize), jnp.bfloat16)
+             token_ids = self._create_dummy_tensor((num_reqs, ), jnp.int32)
+             self._run_compilation(
+                 f"worker{self.runner.rank} gather_logprobs",
+                 self.runner._compute_and_gather_logprobs,
+                 logits,
+                 token_ids,
+                 self.runner.model_config.max_logprobs,
+                 num_reqs=num_reqs,
+             )
+
+         self._gather_logprobs_precompiled = True
+
+     def _precompile_speculative_decoding(self) -> None:
+         logger.info(
+             "Compiling speculative_decoding with different input shapes.")
+         self._precompile_rejection_sampler()
+         if self.runner.speculative_config.method == "eagle3":
+             self._precompile_eagle3_helpers()
+
+     def _precompile_rejection_sampler(self) -> None:
+         logger.info("Compiling rejection_sampler with different input shapes.")
+         vocab_size = self.runner.model_config.get_vocab_size()
+         for num_logits in self.runner.num_logits_paddings:
+             for num_reqs in self.runner.num_reqs_paddings:
+                 sharding = NamedSharding(self.runner.mesh,
+                                          PartitionSpec(None, "model"))
+                 target_probs = self._create_dummy_tensor(
+                     (num_logits, vocab_size), jnp.bfloat16, sharding)
+                 draft_token_ids = self._create_dummy_tensor((num_logits, ),
+                                                             jnp.int32)
+                 num_draft_tokens = self._create_dummy_tensor((num_reqs, ),
+                                                              jnp.int32)
+                 bonus_token_ids = self._create_dummy_tensor((num_reqs, ),
+                                                             jnp.int32)
+
+                 for do_sampling in (False, True):
+                     draft_probs = None
+                     if do_sampling:
+                         compilation_name = "random_rejection_sampler"
+                         temperature = self._create_dummy_tensor((num_reqs, ),
+                                                                 jnp.float32)
+                         top_k = self._create_dummy_tensor((num_reqs, ),
+                                                           jnp.int32)
+                         top_p = self._create_dummy_tensor((num_reqs, ),
+                                                           jnp.float32)
+                         sampling_metadata = TPUSupportedSamplingMetadata(
+                             temperature=temperature,
+                             top_k=top_k,
+                             top_p=top_p,
+                             do_sampling=do_sampling)
+                     else:
+                         compilation_name = "greedy_rejection_sampler"
+                         sampling_metadata = TPUSupportedSamplingMetadata(
+                             do_sampling=do_sampling)
+
+                     self._run_compilation(
+                         f"worker{self.runner.rank} {compilation_name}",
+                         self.runner.rejection_sampler,
+                         draft_token_ids,
+                         num_draft_tokens,
+                         draft_probs,
+                         target_probs,
+                         bonus_token_ids,
+                         sampling_metadata,
+                         self.runner.rng_params_for_sampling,
+                         num_logits=num_logits,
+                         num_reqs=num_reqs,
+                         do_sampling=do_sampling,
+                     )
+
+     def _precompile_eagle3_helpers(self) -> None:
+         logger.info(
+             "Compiling eagle3 jitted helpers with different input shapes.")
+         target_hidden_size = self.runner.model_config.get_hidden_size()
+         draft_hidden_size = self.runner.speculative_config.draft_model_config.get_hidden_size(
+         )
+         dtype = self.runner.model_config.dtype
+
+         num_kv_cache_groups = len(self.runner.kv_cache_config.kv_cache_groups)
+         draft_kv_cache_group_id = num_kv_cache_groups - 1
+         block_tables = self.runner.input_batch.block_table[
+             draft_kv_cache_group_id].get_cpu_tensor().reshape(-1)
+         block_tables = jax.device_put(
+             block_tables, NamedSharding(self.runner.mesh,
+                                         PartitionSpec(None, )))
+
+         selected_positions = self._create_dummy_tensor(
+             (self.runner.max_num_reqs, ), jnp.int32)
+         seq_lens = self._create_dummy_tensor((self.runner.max_num_reqs, ),
+                                              jnp.int32)
+         query_start_loc = self._create_dummy_tensor(
+             (self.runner.max_num_reqs + 1, ), jnp.int32)
+         self._run_compilation(
+             "_update_inputs_for_loop_speculation for the first loop",
+             self.runner.drafter._update_inputs_for_loop_speculation,
+             selected_positions, seq_lens, block_tables)
+         self._run_compilation(
+             "_update_inputs_for_loop_speculation for the subsequent loops",
+             self.runner.drafter._update_inputs_for_loop_speculation,
+             selected_positions, seq_lens, block_tables)
+
+         request_distribution = np.array([0, 0, 0], dtype=np.int32)
+         request_distribution = device_array(self.runner.mesh,
+                                             request_distribution)
+
+         for num_reqs_padding in self.runner.num_reqs_paddings:
+             for i in range(1, self.runner.drafter.num_speculative_tokens + 1):
+                 draft_token_ids_list = [
+                     self._create_dummy_tensor(
+                         (num_reqs_padding, ), jnp.int32,
+                         NamedSharding(self.runner.mesh, PartitionSpec()))
+                     for _ in range(i)
+                 ]
+                 self._run_compilation(
+                     "eagle3_stack_draft_token_ids",
+                     self.runner.drafter._stack_draft_token_ids,
+                     draft_token_ids_list,
+                     num_reqs=num_reqs_padding,
+                     draft_token_ids_list_length=len(draft_token_ids_list))
+
+         for num_logits in self.runner.num_logits_paddings:
+             hidden_states = self._create_dummy_tensor(
+                 (num_logits, draft_hidden_size), jnp.bfloat16)
+             self._run_compilation(
+                 "eagle3_get_draft_token_ids",
+                 self.runner.drafter._get_draft_token_ids,
+                 self.runner.drafter.state,
+                 hidden_states,
+                 num_logits=num_logits,
+             )
+
+         input_ids_loop = self._create_dummy_tensor(
+             (self.runner.max_num_reqs, ), jnp.int32,
+             NamedSharding(self.runner.mesh, PartitionSpec()))
+         draft_hidden_state_loop = self._create_dummy_tensor(
+             (self.runner.max_num_reqs, draft_hidden_size), dtype,
+             NamedSharding(self.runner.mesh, PartitionSpec(None, None)))
+         next_token_ids = self._create_dummy_tensor(
+             (self.runner.max_num_reqs, ), jnp.int32)
+         last_token_indices = self._create_dummy_tensor(
+             (self.runner.max_num_reqs, ), jnp.int32)
+         for num_tokens in self.runner.num_tokens_paddings:
+             aux_hidden_states = [
+                 self._create_dummy_tensor((num_tokens, target_hidden_size),
+                                           dtype),
+                 self._create_dummy_tensor((num_tokens, target_hidden_size),
+                                           dtype),
+                 self._create_dummy_tensor((num_tokens, target_hidden_size),
+                                           dtype),
+             ]
+
+             positions = self._create_dummy_tensor((num_tokens, ), jnp.int32)
+             attention_metadata = AttentionMetadata(
+                 input_positions=positions,
+                 block_tables=block_tables,
+                 seq_lens=seq_lens,
+                 query_start_loc=query_start_loc,
+                 request_distribution=request_distribution,
+             )
+
+             def filter_token_and_prepare_initial_inputs_wrapper(
+                 token_indices,
+                 query_start_loc,
+                 seq_lens,
+                 input_ids,
+                 aux_hidden_states,
+                 attention_metadata,
+                 next_token_ids,
+                 num_reqs,
+             ):
+                 target_hidden_states, input_ids, last_token_indices, _ = self.runner.drafter._filter_token_and_prepare_initial_inputs(
+                     self.runner.drafter.state, token_indices, query_start_loc,
+                     seq_lens, input_ids, aux_hidden_states, attention_metadata,
+                     next_token_ids, num_reqs)
+                 return target_hidden_states, input_ids, last_token_indices
+
+             input_ids = self._create_dummy_tensor((num_tokens, ), jnp.int32)
+             aux_hidden_states = [
+                 self._create_dummy_tensor(
+                     (num_tokens, target_hidden_size), jnp.bfloat16,
+                     NamedSharding(self.runner.mesh, PartitionSpec(None,
+                                                                   None))),
+                 self._create_dummy_tensor(
+                     (num_tokens, target_hidden_size), jnp.bfloat16,
+                     NamedSharding(self.runner.mesh, PartitionSpec(None,
+                                                                   None))),
+                 self._create_dummy_tensor(
+                     (num_tokens, target_hidden_size), jnp.bfloat16,
+                     NamedSharding(self.runner.mesh, PartitionSpec(None,
+                                                                   None))),
+             ]
+             # TODO(ranlihao): This will increase the precompilation latency. Find proper range for token_indices.
+             for padded_total_num_tokens in [
+                     num_tokens,
+                     min(num_tokens * 2, self.runner.num_tokens_paddings[-1])
+             ]:
+                 token_indices = self._create_dummy_tensor(
+                     (padded_total_num_tokens, ), jnp.int32)
+                 self._run_compilation(
+                     "eagle3_filter_token_and_prepare_initial_inputs",
+                     filter_token_and_prepare_initial_inputs_wrapper,
+                     token_indices,
+                     query_start_loc,
+                     seq_lens,
+                     input_ids,
+                     aux_hidden_states,
+                     attention_metadata,
+                     next_token_ids,
+                     device_array(
+                         self.runner.mesh,
+                         np.asarray([self.runner.input_batch.num_reqs],
+                                    dtype=np.int32)),
+                     num_tokens=num_tokens,
+                 )
+
+             def draft_model_fn_wrapper(
+                 state,
+                 kv_caches,
+                 input_ids,
+                 draft_hidden_states,
+                 attention_metadata,
+             ):
+                 kv_caches, hidden_states, _ = self.runner.drafter.model_fn(
+                     state, kv_caches, input_ids, draft_hidden_states,
+                     attention_metadata)
+                 self.runner.kv_caches = kv_caches
+                 return hidden_states
+
+             draft_hidden_states = self._create_dummy_tensor(
+                 (num_tokens, draft_hidden_size), dtype,
+                 NamedSharding(self.runner.mesh, PartitionSpec(None, "model")))
+             input_ids = self._create_dummy_tensor(
+                 (num_tokens, ), jnp.int32,
+                 NamedSharding(self.runner.mesh, PartitionSpec()))
+             self._run_compilation(
+                 "eagle3_draft_model_fn",
+                 draft_model_fn_wrapper,
+                 self.runner.drafter.state,
+                 self.runner.kv_caches,
+                 input_ids,
+                 draft_hidden_states,
+                 attention_metadata,
+                 num_tokens=num_tokens,
+             )
+             target_token_ids = self._create_dummy_tensor((num_tokens, ),
+                                                          jnp.int32)
+
+             self._run_compilation(
+                 "eagle3_prepare_hidden_states_and_input_ids",
+                 self.runner.drafter._prepare_hidden_states_and_input_ids,
+                 self.runner.drafter.state,
+                 aux_hidden_states,
+                 query_start_loc,
+                 target_token_ids,
+                 next_token_ids,
+                 device_array(
+                     self.runner.mesh,
+                     np.asarray([self.runner.input_batch.num_reqs],
+                                dtype=np.int32)),
+                 num_tokens=num_tokens,
+             )
+
+             attention_metadata.query_start_loc = jax.device_put(
+                 attention_metadata.query_start_loc,
+                 NamedSharding(self.runner.mesh, PartitionSpec()))
+             attention_metadata.input_positions = self._create_dummy_tensor(
+                 (self.runner.max_num_reqs, ), jnp.int32)
+             self._run_compilation(
+                 "draft_model_fn in a loop",
+                 draft_model_fn_wrapper,
+                 self.runner.drafter.state,
+                 self.runner.kv_caches,
+                 input_ids_loop,
+                 draft_hidden_state_loop,
+                 attention_metadata,
+                 num_tokens=num_tokens,
+             )
+
+             hidden_states = self._create_dummy_tensor(
+                 (num_tokens, draft_hidden_size), jnp.bfloat16,
+                 NamedSharding(self.runner.mesh, PartitionSpec(None, None)))
+
+             self._run_compilation(
+                 "eagle3_select_inputs_for_loop_speculation",
+                 self.runner.drafter._select_inputs_for_loop_speculation,
+                 self.runner.drafter.state,
+                 positions,
+                 hidden_states,
+                 hidden_states,
+                 last_token_indices,
+                 num_tokens=num_tokens,
+             )
+
+             self._run_compilation(
+                 "eagle3_select_draft_token_ids",
+                 self.runner.drafter._select_draft_token_ids,
+                 self.runner.drafter.state,
+                 hidden_states,
+                 last_token_indices,
+                 num_tokens=num_tokens,
+             )
+
+     def _precompile_structured_decoding(self) -> None:
+         logger.info(
+             "Compiling structured_decoding with different input shapes.")
+         if self.runner.vllm_config.sharding_config.total_dp_size > 1:
+             logger.warning(
+                 "Structured decoding precompilation skipped since structured decoding is not supported with DP."
+             )
+             return
+         for num_reqs in self.runner.num_reqs_paddings:
+             dummy_logits = self._create_dummy_tensor(
+                 (num_reqs, self.runner.vocab_size), jnp.bfloat16)
+             dummy_require_struct_decoding = self.runner.require_structured_out_cpu[:
+                                                                                    num_reqs]
+             dummy_grammar_bitmask = self.runner.grammar_bitmask_cpu[:num_reqs]
+
+             (dummy_logits, dummy_require_struct_decoding,
+              dummy_grammar_bitmask, arange) = device_array(
+                  self.runner.mesh,
+                  (dummy_logits, dummy_require_struct_decoding,
+                   dummy_grammar_bitmask, self.runner.structured_decode_arange))
+
+             self._run_compilation(
+                 "structured_decode",
+                 self.runner.structured_decoding_manager.structured_decode_fn,
+                 dummy_require_struct_decoding,
+                 dummy_grammar_bitmask,
+                 dummy_logits,
+                 arange,
+                 num_reqs=num_reqs,
+             )
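
The pattern repeated throughout this file is the same: build dummy inputs at every padded shape the runner can encounter, call the jitted function once per shape so XLA compiles it ahead of time, and block until compilation finishes. A minimal self-contained sketch of that pattern (illustrative names and shapes only, not part of the package):

    import time
    import jax
    import jax.numpy as jnp

    @jax.jit
    def forward(x):
        # Stand-in for a model step; any jitted function works here.
        return jnp.tanh(x @ x.T)

    NUM_TOKENS_PADDINGS = [16, 32, 64]  # hypothetical shape buckets

    for num_tokens in NUM_TOKENS_PADDINGS:
        dummy = jnp.ones((num_tokens, 8), dtype=jnp.bfloat16)
        start = time.perf_counter()
        # The first call at a new shape triggers compilation; block_until_ready
        # waits for it so later real requests hit the compile cache instead.
        forward(dummy).block_until_ready()
        print(f"num_tokens={num_tokens}: {time.perf_counter() - start:.2f}s")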