tpu-inference 0.12.0.dev20251222__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (260) hide show
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +67 -0
  6. tests/core/test_dp_scheduler.py +724 -0
  7. tests/core/test_init.py +63 -0
  8. tests/distributed/__init__.py +13 -0
  9. tests/distributed/test_distributed_utils.py +120 -0
  10. tests/distributed/test_tpu_connector.py +478 -0
  11. tests/e2e/__init__.py +13 -0
  12. tests/e2e/test_async_scheduler.py +211 -0
  13. tests/e2e/test_data_parallel.py +393 -0
  14. tests/e2e/test_local_disagg.py +257 -0
  15. tests/e2e/test_model_loader.py +268 -0
  16. tests/e2e/test_multi_modal_inference.py +111 -0
  17. tests/e2e/test_pipeline_parallel.py +265 -0
  18. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  19. tests/e2e/test_sampling_params.py +269 -0
  20. tests/e2e/test_speculative_decoding.py +291 -0
  21. tests/e2e/test_structured_decoding.py +46 -0
  22. tests/executors/__init__.py +13 -0
  23. tests/executors/test_ray_distributed_executor.py +199 -0
  24. tests/experimental/__init__.py +13 -0
  25. tests/experimental/test_llama3_jax_stashed.py +208 -0
  26. tests/kernels/__init__.py +13 -0
  27. tests/kernels/collectives/__init__.py +13 -0
  28. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  29. tests/kernels/fused_moe_v1_test.py +388 -0
  30. tests/kernels/gmm_test.py +205 -0
  31. tests/kernels/mla_v1_test.py +498 -0
  32. tests/kernels/quantized_matmul_kernel_test.py +159 -0
  33. tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
  34. tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
  35. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
  36. tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
  37. tests/layers/__init__.py +13 -0
  38. tests/layers/common/__init__.py +13 -0
  39. tests/layers/common/test_attention_interface.py +156 -0
  40. tests/layers/common/test_quantization.py +149 -0
  41. tests/layers/jax/__init__.py +13 -0
  42. tests/layers/jax/attention/__init__.py +13 -0
  43. tests/layers/jax/attention/test_common_attention.py +103 -0
  44. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  45. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  46. tests/layers/jax/moe/__init__.py +13 -0
  47. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  48. tests/layers/jax/sample/__init__.py +13 -0
  49. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  50. tests/layers/jax/sample/test_sampling.py +115 -0
  51. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  52. tests/layers/jax/test_layers.py +155 -0
  53. tests/layers/jax/test_qwix.py +969 -0
  54. tests/layers/jax/test_rope.py +93 -0
  55. tests/layers/jax/test_sharding.py +159 -0
  56. tests/layers/jax/test_transformer_block.py +152 -0
  57. tests/layers/vllm/__init__.py +13 -0
  58. tests/layers/vllm/test_attention.py +363 -0
  59. tests/layers/vllm/test_awq.py +405 -0
  60. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
  62. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
  63. tests/layers/vllm/test_fp8.py +17 -0
  64. tests/layers/vllm/test_mxfp4.py +297 -0
  65. tests/layers/vllm/test_unquantized.py +621 -0
  66. tests/layers/vllm/utils.py +72 -0
  67. tests/lora/__init__.py +13 -0
  68. tests/lora/conftest.py +46 -0
  69. tests/lora/test_bgmv.py +57 -0
  70. tests/lora/test_layers.py +666 -0
  71. tests/lora/test_lora.py +147 -0
  72. tests/lora/test_lora_perf.py +67 -0
  73. tests/lora/utils.py +88 -0
  74. tests/models/__init__.py +13 -0
  75. tests/models/common/__init__.py +13 -0
  76. tests/models/common/test_model_loader.py +455 -0
  77. tests/models/jax/__init__.py +13 -0
  78. tests/models/jax/test_deepseek_v3.py +401 -0
  79. tests/models/jax/test_llama3.py +184 -0
  80. tests/models/jax/test_llama4.py +298 -0
  81. tests/models/jax/test_llama_eagle3.py +197 -0
  82. tests/models/jax/test_llama_guard_4.py +242 -0
  83. tests/models/jax/test_qwen2.py +172 -0
  84. tests/models/jax/test_qwen2_5_vl.py +606 -0
  85. tests/models/jax/test_qwen3.py +169 -0
  86. tests/models/jax/test_weight_loading.py +180 -0
  87. tests/models/jax/utils/__init__.py +13 -0
  88. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  89. tests/platforms/__init__.py +13 -0
  90. tests/platforms/test_tpu_platform.py +54 -0
  91. tests/runner/__init__.py +13 -0
  92. tests/runner/test_block_table.py +395 -0
  93. tests/runner/test_input_batch.py +226 -0
  94. tests/runner/test_kv_cache.py +220 -0
  95. tests/runner/test_kv_cache_manager.py +498 -0
  96. tests/runner/test_multimodal_manager.py +429 -0
  97. tests/runner/test_persistent_batch_manager.py +84 -0
  98. tests/runner/test_speculative_decoding_manager.py +368 -0
  99. tests/runner/test_structured_decoding_manager.py +220 -0
  100. tests/runner/test_tpu_runner.py +202 -0
  101. tests/runner/test_tpu_runner_dp.py +1033 -0
  102. tests/runner/test_tpu_runner_mesh.py +200 -0
  103. tests/runner/test_utils.py +411 -0
  104. tests/spec_decode/__init__.py +13 -0
  105. tests/spec_decode/test_eagle3.py +311 -0
  106. tests/test_base.py +215 -0
  107. tests/test_envs.py +280 -0
  108. tests/test_tpu_info.py +134 -0
  109. tests/test_utils.py +193 -0
  110. tests/worker/__init__.py +13 -0
  111. tests/worker/tpu_worker_test.py +414 -0
  112. tpu_inference/__init__.py +67 -0
  113. tpu_inference/core/__init__.py +13 -0
  114. tpu_inference/core/core_tpu.py +786 -0
  115. tpu_inference/core/disagg_executor.py +118 -0
  116. tpu_inference/core/disagg_utils.py +49 -0
  117. tpu_inference/core/sched/__init__.py +13 -0
  118. tpu_inference/core/sched/dp_scheduler.py +814 -0
  119. tpu_inference/distributed/__init__.py +13 -0
  120. tpu_inference/distributed/jax_parallel_state.py +81 -0
  121. tpu_inference/distributed/tpu_connector.py +732 -0
  122. tpu_inference/distributed/utils.py +112 -0
  123. tpu_inference/env_override.py +9 -0
  124. tpu_inference/envs.py +191 -0
  125. tpu_inference/executors/__init__.py +13 -0
  126. tpu_inference/executors/ray_distributed_executor.py +399 -0
  127. tpu_inference/experimental/__init__.py +13 -0
  128. tpu_inference/experimental/llama3_jax_stashed.py +272 -0
  129. tpu_inference/kernels/__init__.py +13 -0
  130. tpu_inference/kernels/collectives/__init__.py +13 -0
  131. tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
  132. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
  133. tpu_inference/kernels/collectives/util.py +47 -0
  134. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  135. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  136. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  137. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  138. tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
  139. tpu_inference/kernels/megablox/__init__.py +13 -0
  140. tpu_inference/kernels/megablox/common.py +54 -0
  141. tpu_inference/kernels/megablox/gmm.py +646 -0
  142. tpu_inference/kernels/mla/__init__.py +13 -0
  143. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  144. tpu_inference/kernels/mla/v1/kernel.py +1340 -0
  145. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  146. tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
  147. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  148. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  149. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  150. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  151. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
  152. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
  153. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  154. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  155. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
  156. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
  157. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
  158. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
  159. tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
  160. tpu_inference/layers/__init__.py +13 -0
  161. tpu_inference/layers/common/__init__.py +13 -0
  162. tpu_inference/layers/common/attention_interface.py +403 -0
  163. tpu_inference/layers/common/attention_metadata.py +48 -0
  164. tpu_inference/layers/common/binary_search.py +295 -0
  165. tpu_inference/layers/common/quant_methods.py +23 -0
  166. tpu_inference/layers/common/quantization.py +270 -0
  167. tpu_inference/layers/common/sharding.py +600 -0
  168. tpu_inference/layers/jax/__init__.py +13 -0
  169. tpu_inference/layers/jax/attention/__init__.py +13 -0
  170. tpu_inference/layers/jax/attention/attention.py +268 -0
  171. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
  172. tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
  173. tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
  174. tpu_inference/layers/jax/base.py +165 -0
  175. tpu_inference/layers/jax/constants.py +101 -0
  176. tpu_inference/layers/jax/layers.py +315 -0
  177. tpu_inference/layers/jax/misc.py +30 -0
  178. tpu_inference/layers/jax/moe/__init__.py +13 -0
  179. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
  180. tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
  181. tpu_inference/layers/jax/moe/moe.py +249 -0
  182. tpu_inference/layers/jax/pp_utils.py +53 -0
  183. tpu_inference/layers/jax/rope.py +294 -0
  184. tpu_inference/layers/jax/rope_interface.py +228 -0
  185. tpu_inference/layers/jax/sample/__init__.py +13 -0
  186. tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
  187. tpu_inference/layers/jax/sample/sampling.py +110 -0
  188. tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
  189. tpu_inference/layers/jax/transformer_block.py +121 -0
  190. tpu_inference/layers/vllm/__init__.py +13 -0
  191. tpu_inference/layers/vllm/attention.py +221 -0
  192. tpu_inference/layers/vllm/fused_moe.py +502 -0
  193. tpu_inference/layers/vllm/linear_common.py +221 -0
  194. tpu_inference/layers/vllm/quantization/__init__.py +55 -0
  195. tpu_inference/layers/vllm/quantization/awq.py +221 -0
  196. tpu_inference/layers/vllm/quantization/common.py +124 -0
  197. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  198. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
  199. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
  200. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  201. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
  202. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
  203. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  204. tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
  205. tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
  206. tpu_inference/layers/vllm/sharding.py +244 -0
  207. tpu_inference/logger.py +10 -0
  208. tpu_inference/lora/__init__.py +13 -0
  209. tpu_inference/lora/torch_lora_ops.py +98 -0
  210. tpu_inference/lora/torch_punica_tpu.py +310 -0
  211. tpu_inference/models/__init__.py +13 -0
  212. tpu_inference/models/common/__init__.py +13 -0
  213. tpu_inference/models/common/model_loader.py +520 -0
  214. tpu_inference/models/jax/__init__.py +13 -0
  215. tpu_inference/models/jax/deepseek_v3.py +978 -0
  216. tpu_inference/models/jax/gpt_oss.py +508 -0
  217. tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
  218. tpu_inference/models/jax/llama3.py +436 -0
  219. tpu_inference/models/jax/llama4.py +643 -0
  220. tpu_inference/models/jax/llama_eagle3.py +350 -0
  221. tpu_inference/models/jax/llama_guard_4.py +375 -0
  222. tpu_inference/models/jax/qwen2.py +390 -0
  223. tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
  224. tpu_inference/models/jax/qwen3.py +318 -0
  225. tpu_inference/models/jax/utils/__init__.py +13 -0
  226. tpu_inference/models/jax/utils/file_utils.py +110 -0
  227. tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
  228. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  229. tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
  230. tpu_inference/models/jax/utils/weight_utils.py +621 -0
  231. tpu_inference/models/vllm/__init__.py +13 -0
  232. tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
  233. tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
  234. tpu_inference/platforms/__init__.py +16 -0
  235. tpu_inference/platforms/tpu_platform.py +258 -0
  236. tpu_inference/runner/__init__.py +13 -0
  237. tpu_inference/runner/block_table.py +122 -0
  238. tpu_inference/runner/compilation_manager.py +890 -0
  239. tpu_inference/runner/input_batch.py +435 -0
  240. tpu_inference/runner/kv_cache.py +166 -0
  241. tpu_inference/runner/kv_cache_manager.py +508 -0
  242. tpu_inference/runner/lora_utils.py +106 -0
  243. tpu_inference/runner/multimodal_manager.py +231 -0
  244. tpu_inference/runner/persistent_batch_manager.py +296 -0
  245. tpu_inference/runner/speculative_decoding_manager.py +262 -0
  246. tpu_inference/runner/structured_decoding_manager.py +101 -0
  247. tpu_inference/runner/tpu_runner.py +1768 -0
  248. tpu_inference/runner/utils.py +426 -0
  249. tpu_inference/spec_decode/__init__.py +13 -0
  250. tpu_inference/spec_decode/jax/__init__.py +13 -0
  251. tpu_inference/spec_decode/jax/eagle3.py +430 -0
  252. tpu_inference/tpu_info.py +92 -0
  253. tpu_inference/utils.py +345 -0
  254. tpu_inference/worker/__init__.py +13 -0
  255. tpu_inference/worker/tpu_worker.py +468 -0
  256. tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
  257. tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
  258. tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
  259. tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
  260. tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
@@ -0,0 +1,1768 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import copy
16
+ import functools
17
+ import logging
18
+ import random
19
+ from contextlib import nullcontext
20
+ from dataclasses import dataclass
21
+ from typing import Any, Callable, Dict, List, Optional, Tuple, cast
22
+
23
+ import jax
24
+ import jax.numpy as jnp
25
+ import jaxtyping
26
+ import numpy as np
27
+ import vllm.envs as vllm_envs
28
+ from flax import nnx
29
+ from jax.experimental import mesh_utils
30
+ from jax.sharding import NamedSharding, PartitionSpec
31
+ from vllm.config import VllmConfig
32
+ from vllm.distributed import get_pp_group
33
+ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
34
+ has_kv_transfer_group)
35
+ from vllm.forward_context import set_forward_context
36
+ from vllm.tasks import SupportedTask
37
+ from vllm.utils.math_utils import cdiv
38
+ from vllm.v1.core.sched.output import GrammarOutput
39
+ from vllm.v1.core.sched.output import SchedulerOutput as VllmSchedulerOutput
40
+ from vllm.v1.kv_cache_interface import KVCacheConfig
41
+ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
42
+ DraftTokenIds, KVConnectorOutput, LogprobsLists,
43
+ ModelRunnerOutput)
44
+ from vllm.v1.request import Request
45
+ from vllm.v1.spec_decode.ngram_proposer import NgramProposer
46
+ from vllm.v1.worker.kv_connector_model_runner_mixin import \
47
+ KVConnectorModelRunnerMixin
48
+ from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
49
+
50
+ import tpu_inference.envs as envs
51
+ from tpu_inference import utils as common_utils
52
+ from tpu_inference.layers.common.attention_metadata import AttentionMetadata
53
+ from tpu_inference.layers.common.sharding import (MESH_AXIS_NAMES,
54
+ MESH_AXIS_NAMES_2D,
55
+ ShardingAxisName,
56
+ ShardingConfigManager)
57
+ from tpu_inference.layers.jax.sample.rejection_sampler import RejectionSampler
58
+ from tpu_inference.layers.jax.sample.sampling import (compute_logprobs,
59
+ gather_logprobs, sample)
60
+ from tpu_inference.layers.jax.sample.sampling_metadata import \
61
+ TPUSupportedSamplingMetadata
62
+ from tpu_inference.logger import init_logger
63
+ from tpu_inference.models.common.model_loader import get_model
64
+ from tpu_inference.models.jax.jax_intermediate_tensor import \
65
+ JaxIntermediateTensors
66
+ from tpu_inference.models.jax.utils.weight_utils import (
67
+ shard_put, transfer_state_with_mappings)
68
+ from tpu_inference.runner import utils as runner_utils
69
+ from tpu_inference.runner.compilation_manager import CompilationManager
70
+ from tpu_inference.runner.input_batch import CachedRequestState, InputBatch
71
+ from tpu_inference.runner.kv_cache_manager import KVCacheManager
72
+ from tpu_inference.runner.lora_utils import LoraUtils
73
+ from tpu_inference.runner.multimodal_manager import MultiModalManager
74
+ from tpu_inference.runner.persistent_batch_manager import \
75
+ PersistentBatchManager
76
+ from tpu_inference.runner.speculative_decoding_manager import (
77
+ SpecDecodeMetadata, SpeculativeDecodingManager)
78
+ from tpu_inference.runner.structured_decoding_manager import \
79
+ StructuredDecodingManager
80
+ from tpu_inference.spec_decode.jax.eagle3 import Eagle3Proposer
81
+ from tpu_inference.utils import (device_array, make_optimized_mesh,
82
+ time_function, to_jax_dtype, to_torch_dtype)
83
+
84
+ logger = init_logger(__name__)
85
+
86
+ logging.getLogger("torchax.tensor").setLevel(logging.ERROR)
87
+
88
+ INVALID_TOKEN_ID = -1
89
+ # Smallest output size
90
+ MIN_NUM_SEQS = 8
91
+
92
+ DUMMY_METADATA = AttentionMetadata(
93
+ input_positions=[],
94
+ block_tables=[],
95
+ request_distribution=[0, 0, 0],
96
+ )
97
+
98
+
99
+ class AsyncTPUModelRunnerOutput(AsyncModelRunnerOutput):
100
+ """Holds asynchronous model output specifically from a TPU runner.
101
+
102
+ This class acts as a wrapper around the standard ModelRunnerOutput. Its
103
+ primary purpose is to hold references to data still on the TPU device
104
+ (like the `next_tokens` JAX array) without blocking the main thread.
105
+
106
+ The `get_output()` method is called to resolve these async results,
107
+ triggering the JAX device-to-host (CPU) data transfer and populating
108
+ the final `ModelRunnerOutput` object.
109
+ """
110
+
111
+ def __init__(
112
+ self,
113
+ model_runner_output: ModelRunnerOutput,
114
+ next_tokens: jax.Array,
115
+ num_reqs: int,
116
+ discard_sampled_tokens_req_indices: list[int],
117
+ logits_indices_selector: Optional[List[int]] = None,
118
+ ):
119
+ self._model_runner_output = model_runner_output
120
+ self._next_tokens = next_tokens
121
+ self._num_reqs = num_reqs
122
+ self._discard_sampled_tokens_req_indices = discard_sampled_tokens_req_indices
123
+ self.logits_indices_selector: list[int] = logits_indices_selector
124
+
125
+ def get_output(self) -> ModelRunnerOutput:
126
+ next_tokens_cpu = np.asarray(jax.device_get(self._next_tokens))
127
+ if self.logits_indices_selector is not None:
128
+ next_tokens_cpu = next_tokens_cpu[self.logits_indices_selector]
129
+ selected_token_ids = np.expand_dims(next_tokens_cpu[:self._num_reqs],
130
+ 1)
131
+ valid_sampled_token_ids = selected_token_ids.tolist()
132
+ for i in self._discard_sampled_tokens_req_indices:
133
+ valid_sampled_token_ids[i].clear()
134
+ self._model_runner_output.sampled_token_ids = valid_sampled_token_ids
135
+ return self._model_runner_output
136
+
137
+
138
+ @dataclass
139
+ class AsyncPreResults:
140
+ req_ids: list[str]
141
+ next_tokens: jax.Array
142
+ request_seq_lens: list[tuple[int, CachedRequestState, int]]
143
+ discard_sampled_tokens_req_indices: list[int]
144
+ placeholder_req_id_to_index: dict[str, int]
145
+ logits_indices_selector: Optional[List[int]] = None
146
+
147
+
148
+ @dataclass
149
+ class ExecuteModelState:
150
+ """Ephemeral cached state transferred between execute_model() and
151
+ sample_tokens(), after execute_model() returns None."""
152
+
153
+ scheduler_output: "VllmSchedulerOutput"
154
+ attn_metadata: AttentionMetadata
155
+ input_ids: Optional[jax.Array]
156
+ hidden_states: jax.Array
157
+ logits: jax.Array
158
+ aux_hidden_states: Optional[jax.Array]
159
+ spec_decode_metadata: Optional[SpecDecodeMetadata]
160
+ kv_connector_output: Optional[KVConnectorOutput]
161
+ logits_indices_selector: Optional[List[int]] = None
162
+ padded_num_reqs: Optional[int] = None
163
+
164
+
165
+ @functools.partial(jax.jit, donate_argnums=(0, 1, 2))
166
+ def _substitute_placeholder_token(
167
+ input_ids: jax.Array, token_in_tpu_cur_input_indices: jax.Array,
168
+ token_in_tpu_pre_next_tokens_indices: jax.Array,
169
+ next_tokens: jax.Array, placeholder_num: int):
170
+ """Substitute placeholder tokens from TPU for async scheduler
171
+
172
+ Padding for parallelisation of the substitute_placeholder_token_fn
173
+ [1, 3] => [1, 3, 0, 2, 4, 5, 6, 7, 8]
174
+ The reason for such a special padding instead of padding with -1 is:
175
+ An edge case when the end index needs to be updated and padding is required.
176
+ If we pad the array with -1, the _substitute_placeholder_token_fn will repeatedly update the end element with the original value
177
+ Although such a scenario is unlikely to happen in vLLM, it is best to eliminate any potential risks.
178
+
179
+ Args:
180
+ input_ids: possible input_ids size
181
+ token_in_tpu_cur_input_indices: replace holder idx in input_ids. Length the same to input_ids.
182
+ token_in_tpu_pre_next_tokens_indices: value idx in next_tokens. Length the same to input_ids.
183
+ next_tokens: next tokens on the TPU from previous step.
184
+ placeholder_num: number of placeholders. placeholder_num <= len(token_in_tpu_cur_input_indices)
185
+ Return:
186
+ input_ids after replace placeholder tokens
187
+ """
188
+ assert input_ids.shape == token_in_tpu_cur_input_indices.shape == token_in_tpu_pre_next_tokens_indices.shape, \
189
+ f"Shape mismatch: input_ids and index arrays must have identical shapes due to precompilation assumptions. " \
190
+ f"Got: {input_ids.shape=}, {token_in_tpu_cur_input_indices.shape=}, {token_in_tpu_pre_next_tokens_indices.shape=}"
191
+
192
+ # updates the input_ids for all placeholders.
193
+ mask = jnp.arange(input_ids.shape[0]) < placeholder_num
194
+ new_token_values = next_tokens[token_in_tpu_pre_next_tokens_indices]
195
+ original_values = input_ids[token_in_tpu_cur_input_indices]
196
+ update_values = jnp.where(mask, new_token_values, original_values)
197
+ return input_ids.at[token_in_tpu_cur_input_indices].set(update_values)
198
+
199
+
200
+ def _jax_logprobs_to_lists(logprobs_tensors,
201
+ logits_indices_selector=None,
202
+ cu_num_generated_tokens=None):
203
+ """Convert JAX LogprobsTensors to LogprobsLists by converting JAX arrays to numpy."""
204
+ log_token_ids_list = logprobs_tensors.logprob_token_ids.tolist()
205
+ logprobs_list = logprobs_tensors.logprobs.tolist()
206
+ selected_token_ranks_list = logprobs_tensors.selected_token_ranks.tolist()
207
+
208
+ if logits_indices_selector is not None:
209
+ log_token_ids_list = [
210
+ log_token_ids_list[i] for i in logits_indices_selector
211
+ ]
212
+ logprobs_list = [logprobs_list[i] for i in logits_indices_selector]
213
+ selected_token_ranks_list = [
214
+ selected_token_ranks_list[i] for i in logits_indices_selector
215
+ ]
216
+
217
+ return LogprobsLists(
218
+ logprob_token_ids=np.asarray(log_token_ids_list),
219
+ logprobs=np.asarray(logprobs_list),
220
+ sampled_token_ranks=np.asarray(selected_token_ranks_list),
221
+ cu_num_generated_tokens=cu_num_generated_tokens,
222
+ )
223
+
224
+
225
+ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
226
+
227
+ def __init__(
228
+ self,
229
+ vllm_config: VllmConfig,
230
+ devices: List[Any],
231
+ rank: int = 0,
232
+ is_first_rank: bool = True,
233
+ is_last_rank: bool = True,
234
+ ):
235
+ self.vllm_config = vllm_config
236
+ self.model_config = vllm_config.model_config
237
+ # TODO(jevinjiang): override block size based on RPA v3.
238
+ self.cache_config = vllm_config.cache_config
239
+ self.lora_config = vllm_config.lora_config
240
+ self.load_config = vllm_config.load_config
241
+ self.parallel_config = vllm_config.parallel_config
242
+ self.scheduler_config = vllm_config.scheduler_config
243
+ self.speculative_config = vllm_config.speculative_config
244
+ self.observability_config = vllm_config.observability_config
245
+ self.device_config = vllm_config.device_config
246
+
247
+ self.devices = devices
248
+ self.dtype = self.model_config.dtype
249
+ self.maybe_forbid_compile = runner_utils.ForbidCompile(
250
+ ) if envs.VLLM_XLA_CHECK_RECOMPILATION else nullcontext()
251
+ self.dp_size = self.vllm_config.sharding_config.total_dp_size
252
+ self.rank = rank
253
+ self.is_first_rank = is_first_rank
254
+ self.is_last_rank = is_last_rank
255
+
256
+ self._init_random()
257
+ self._init_mesh()
258
+ self._init_phased_profiling()
259
+ self._init_mm()
260
+ self._init_inputs()
261
+ self._init_speculative_decoding()
262
+
263
+ # Delegate functions to specific manager classes.
264
+ self.compilation_manager = CompilationManager(self)
265
+ if self.is_last_rank:
266
+ self.speculative_decoding_manager = SpeculativeDecodingManager(
267
+ self)
268
+ self.structured_decoding_manager = StructuredDecodingManager(self)
269
+ self.kv_cache_manager = KVCacheManager(self)
270
+ self.mm_manager = MultiModalManager(self)
271
+ self.persistent_batch_manager = PersistentBatchManager(
272
+ self.requests, self.input_batch, self.encoder_cache,
273
+ self.uses_mrope, self.model_config, self.is_last_rank)
274
+ self.lora_utils = LoraUtils(self)
275
+
276
+ cache_dtype = self.cache_config.cache_dtype
277
+ if cache_dtype == "auto":
278
+ cache_dtype = self.dtype
279
+ self.kv_cache_dtype = to_torch_dtype(cache_dtype)
280
+
281
+ self._pre_async_results: AsyncPreResults | None = None
282
+ self._substitute_placeholder_token_fn = _substitute_placeholder_token
283
+ self.execute_model_state: ExecuteModelState | None = None
284
+
285
+ def _init_random(self):
286
+ if self.model_config.seed is None:
287
+ self.model_config.seed = 0
288
+ random.seed(self.model_config.seed)
289
+ np.random.seed(self.model_config.seed)
290
+ self.rng_key = jax.random.key(self.model_config.seed)
291
+
292
+ def _init_mesh(self) -> None:
293
+ if envs.NEW_MODEL_DESIGN:
294
+ self.mesh = self._create_new_model_mesh()
295
+ else:
296
+ # NOTE(wenxindongwork): The new MoE kernel expects a 2D mesh, so we need
297
+ # to create a 2D mesh for now. We should make the new_model_mesh as the default
298
+ # in the future.
299
+ self.mesh = self._create_2d_mesh()
300
+
301
+ logger.info(f"Init mesh | mesh={self.mesh}")
302
+
303
+ def _create_new_model_mesh(self) -> jax.sharding.Mesh:
304
+ num_slices = envs.NUM_SLICES
305
+
306
+ logger.info(f"Creating new model mesh | devices={len(self.devices)}, "
307
+ f"num_slices={num_slices}")
308
+
309
+ if num_slices == 1:
310
+ devices_array = self._create_single_slice_mesh()
311
+ else:
312
+ devices_array = self._create_multi_slice_mesh(num_slices)
313
+
314
+ return jax.sharding.Mesh(devices_array, MESH_AXIS_NAMES)
315
+
316
+ def _create_single_slice_mesh(self) -> jax.Array:
317
+ sharding_strategy: ShardingConfigManager = self.vllm_config.sharding_config
318
+ mesh_shape = (
319
+ sharding_strategy.model_dp_size,
320
+ sharding_strategy.attn_dp_size,
321
+ sharding_strategy.expert_size,
322
+ sharding_strategy.tp_size,
323
+ )
324
+
325
+ return mesh_utils.create_device_mesh(
326
+ mesh_shape,
327
+ self.devices,
328
+ allow_split_physical_axes=True,
329
+ )
330
+
331
+ def _create_multi_slice_mesh(self, num_slices: int) -> jax.Array:
332
+ sharding_strategy: ShardingConfigManager = self.vllm_config.sharding_config
333
+ dp_inner = sharding_strategy.model_dp_size // num_slices
334
+
335
+ # Splits data parallelism across multiple slices.
336
+ ici_mesh_shape = (
337
+ dp_inner,
338
+ sharding_strategy.attn_dp_size,
339
+ sharding_strategy.expert_size,
340
+ sharding_strategy.tp_size,
341
+ )
342
+ dcn_mesh_shape = (num_slices, 1, 1, 1)
343
+
344
+ return mesh_utils.create_hybrid_device_mesh(
345
+ mesh_shape=ici_mesh_shape,
346
+ dcn_mesh_shape=dcn_mesh_shape,
347
+ devices=self.devices,
348
+ allow_split_physical_axes=True,
349
+ )
350
+
351
+ def _create_2d_mesh(self) -> jax.sharding.Mesh:
352
+
353
+ sharding_strategy: ShardingConfigManager = self.vllm_config.sharding_config
354
+ mesh_shape = (
355
+ sharding_strategy.model_dp_size,
356
+ sharding_strategy.tp_size,
357
+ )
358
+
359
+ enforce_device_order = (
360
+ self.vllm_config.sharding_config.device_indexes is not None
361
+ and len(self.vllm_config.sharding_config.device_indexes) > 0)
362
+
363
+ if enforce_device_order:
364
+ return jax.make_mesh(mesh_shape,
365
+ MESH_AXIS_NAMES_2D,
366
+ devices=self.devices)
367
+ else:
368
+ return make_optimized_mesh(mesh_shape,
369
+ MESH_AXIS_NAMES_2D,
370
+ devices=self.devices)
371
+
372
+ def _init_phased_profiling(self) -> None:
373
+ self.phased_profiling_dir = envs.PHASED_PROFILING_DIR
374
+ self.phase_based_profiler = None
375
+ if self.phased_profiling_dir:
376
+ self.phase_based_profiler = runner_utils.PhasedBasedProfiler(
377
+ self.phased_profiling_dir)
378
+
379
+ def _init_mm(self) -> None:
380
+ self.is_multimodal_model = None
381
+ self.uses_mrope = self.model_config.uses_mrope
382
+
383
+ def _init_speculative_decoding(self) -> None:
384
+ self.drafter = None
385
+ if self.speculative_config:
386
+ if self.speculative_config.method == "ngram":
387
+ self.drafter = NgramProposer(self.vllm_config)
388
+ elif self.speculative_config.method == "eagle3":
389
+ self.drafter = Eagle3Proposer(self.vllm_config, self)
390
+ else:
391
+ raise NotImplementedError(
392
+ "Unsupported speculative decoding method: "
393
+ f"{self.speculative_config.method}")
394
+ self.rejection_sampler = RejectionSampler()
395
+
396
+ def _init_inputs(self) -> None:
397
+ model_config = self.model_config
398
+ cache_config = self.cache_config
399
+ scheduler_config = self.scheduler_config
400
+
401
+ self.sliding_window = model_config.get_sliding_window()
402
+ self.block_size = cache_config.block_size
403
+ self.max_model_len = model_config.max_model_len
404
+ self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
405
+ # InputBatch needs to work with sampling tensors greater than padding
406
+ # to avoid dynamic shapes. Also, avoid suboptimal alignment.
407
+ # The total number of requests is dp_size * max_num_seqs
408
+ self.max_num_reqs = max(self.dp_size * scheduler_config.max_num_seqs,
409
+ MIN_NUM_SEQS)
410
+ # [16, 32, 64, 128, 256, 512, 1024, 2048]
411
+ self.num_tokens_paddings = runner_utils.get_token_paddings(
412
+ min_token_size=max(16, self.dp_size),
413
+ max_token_size=scheduler_config.max_num_batched_tokens *
414
+ self.dp_size,
415
+ padding_gap=vllm_envs.VLLM_TPU_BUCKET_PADDING_GAP)
416
+ self.num_tokens_paddings_per_dp = [
417
+ padding // self.dp_size for padding in self.num_tokens_paddings
418
+ ]
419
+ # In case `max_num_tokens < max(num_tokens_paddings)` use the actual
420
+ # padded max value to pre-allocate data structures and pre-compile.
421
+ self.max_num_tokens = self.num_tokens_paddings[-1]
422
+
423
+ # Request states.
424
+ self.requests: dict[str, CachedRequestState] = {}
425
+ # mm_hash -> encoder_output
426
+ self.encoder_cache: dict[str, jax.Array] = {}
427
+ self.input_batch = InputBatch(
428
+ max_num_reqs=self.max_num_reqs,
429
+ max_model_len=self.max_model_len,
430
+ max_num_batched_tokens=self.max_num_tokens,
431
+ pin_memory=False,
432
+ vocab_size=self.model_config.get_vocab_size(),
433
+ block_sizes=[self.block_size],
434
+ is_spec_decode=bool(self.vllm_config.speculative_config),
435
+ )
436
+
437
+ self.input_ids_cpu = np.zeros(self.max_num_tokens, dtype=np.int32)
438
+ self.positions_cpu = np.zeros(self.max_num_tokens, dtype=np.int32)
439
+ # Note: self.input_batch and self.block_tables_cpu are both initialized
440
+ # with only 1 block_size. For hybrid kv cache, it will be re-init
441
+ # in kv_cache_manager's maybe_reinitialize_input_batch.
442
+ self.block_tables_cpu = [
443
+ np.zeros((self.max_num_reqs, self.max_num_blocks_per_req),
444
+ dtype=np.int32)
445
+ ]
446
+
447
+ self.query_start_loc_cpu = np.zeros(self.max_num_reqs + self.dp_size,
448
+ dtype=np.int32)
449
+ self.seq_lens_cpu = np.zeros(self.max_num_reqs, dtype=np.int32)
450
+ self.logits_indices_cpu = np.zeros(self.max_num_reqs, dtype=np.int32)
451
+ # Range tensor with values [0 .. self.max_num_tokens - 1].
452
+ # Used to initialize positions / context_lens / seq_lens
453
+ # Keep in int64 to avoid overflow with long context
454
+ self.arange_cpu = np.arange(self.max_num_tokens, dtype=np.int64)
455
+ min_num_reqs = max(MIN_NUM_SEQS, self.dp_size)
456
+ self.num_reqs_paddings = runner_utils.get_req_paddings(
457
+ min_req_size=min_num_reqs, max_req_size=self.max_num_reqs)
458
+ self.num_reqs_paddings_per_dp = [
459
+ padding // self.dp_size for padding in self.num_reqs_paddings
460
+ ]
461
+
462
+ # Padding for logits. Without speculative decoding, each request has one position to select from.
463
+ # With speculative decoding, each request has multiple positions to select from.
464
+ max_logits_per_req = 1
465
+ if self.speculative_config:
466
+ max_logits_per_req = self.speculative_config.num_speculative_tokens + 1 # Including bonus token
467
+ self.num_logits_paddings = runner_utils.get_token_paddings(
468
+ min_token_size=MIN_NUM_SEQS,
469
+ max_token_size=self.max_num_reqs * max_logits_per_req,
470
+ padding_gap=0)
471
+ else:
472
+ self.num_logits_paddings = None
473
+
474
+ self.temperatures_cpu = np.zeros(self.max_num_tokens, dtype=np.float32)
475
+ self.top_ps_cpu = np.zeros(self.max_num_tokens, dtype=np.float32)
476
+ self.top_ks_cpu = np.zeros(self.max_num_tokens, dtype=np.int32)
477
+
478
+ # tensors for structured decoding
479
+ self.vocab_size = self.model_config.get_vocab_size()
480
+ self.grammar_bitmask_cpu = np.zeros(
481
+ (self.max_num_reqs, cdiv(self.vocab_size, 32)),
482
+ dtype=np.int32,
483
+ )
484
+ self.require_structured_out_cpu = np.zeros(
485
+ (self.max_num_reqs, 1),
486
+ dtype=np.bool_,
487
+ )
488
+ self.structured_decode_arange = np.arange(0, 32, dtype=np.int32)
489
+
490
+ # multi-modal support
491
+ # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
492
+
493
+ # NOTE: When M-RoPE is enabled, position ids are 3D regardless of
494
+ # the modality of inputs. For text-only inputs, each dimension has
495
+ # identical position IDs, making M-RoPE functionally equivalent to
496
+ # 1D-RoPE.
497
+ # See page 5 of https://arxiv.org/abs/2409.12191
498
+ self.mrope_positions_cpu = np.zeros((3, self.max_num_tokens),
499
+ dtype=np.int64)
500
+
501
+ def load_model(self):
502
+ self.model_fn, self.compute_logits_fn, self.combine_hidden_states_fn, multimodal_fns, self.state, self.lora_manager, self.model = get_model(
503
+ self.vllm_config,
504
+ self.rng_key,
505
+ self.mesh,
506
+ )
507
+
508
+ multimodal_fns = multimodal_fns or {}
509
+ self.precompile_vision_encoder_fn = multimodal_fns.get(
510
+ "precompile_vision_encoder_fn", None)
511
+ self.get_multimodal_embeddings_fn = multimodal_fns.get(
512
+ "get_multimodal_embeddings_fn", None)
513
+ self.get_input_embeddings_fn = multimodal_fns.get(
514
+ "get_input_embeddings_fn", None)
515
+ self.get_mrope_input_positions_fn = multimodal_fns.get(
516
+ "get_mrope_input_positions_fn", None)
517
+
518
+ if self.drafter is not None:
519
+ logger.info("Loading drafter model...")
520
+ self.drafter.load_model(self.state)
521
+
522
+ self.rng_params_for_sampling = nnx.Rngs(
523
+ jax.random.key(self.model_config.seed)).params()
524
+ self.is_multimodal_model = (
525
+ self.model_config.is_multimodal_model
526
+ and self.get_multimodal_embeddings_fn is not None and hasattr(
527
+ self.model_config.hf_config, "architectures"
528
+ ) #TODO: Remove Llama Guard 4 specific condition once the LG4 Vision portion is implemented
529
+ and len(self.model_config.hf_config.architectures) >= 1
530
+ and self.model_config.hf_config.architectures[0]
531
+ != "Llama4ForConditionalGeneration")
532
+
533
+ logger.info(f"Init model | "
534
+ f"hbm={common_utils.hbm_usage_gb(self.devices)}GiB")
535
+
536
+ def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
537
+ return ("generate", )
538
+
539
+ def get_kv_cache_spec(self):
540
+ return self.kv_cache_manager.get_kv_cache_spec()
541
+
542
+ def initialize_kv_cache(self,
543
+ kv_cache_config: KVCacheConfig,
544
+ topology_order_id: int = 0) -> None:
545
+ self.topology_order_id = topology_order_id
546
+ self.kv_cache_config = kv_cache_config
547
+ self.use_hybrid_kvcache = len(kv_cache_config.kv_cache_groups) > 1
548
+ self.kv_caches = []
549
+ self.kv_cache_manager.initialize_kv_cache(kv_cache_config)
550
+ if has_kv_transfer_group():
551
+ get_kv_transfer_group().register_runner(self)
552
+
553
+ def capture_model(self) -> None:
554
+ self.compilation_manager.capture_model()
555
+
556
+ @time_function
557
+ def execute_model(
558
+ self,
559
+ scheduler_output: "VllmSchedulerOutput",
560
+ intermediate_tensors: Optional[JaxIntermediateTensors] = None,
561
+ ) -> ModelRunnerOutput | JaxIntermediateTensors | None:
562
+ if self.execute_model_state is not None:
563
+ raise RuntimeError("State error: sample_tokens() must be called "
564
+ "after execute_model() returns None.")
565
+ _, output = self._execute_model(scheduler_output, intermediate_tensors)
566
+ return output
567
+
568
+ def sample_tokens(
569
+ self,
570
+ grammar_output: "GrammarOutput | None",
571
+ ) -> ModelRunnerOutput | AsyncTPUModelRunnerOutput:
572
+ if self.execute_model_state is None:
573
+ # This can happen in pipeline parallel case.
574
+ return EMPTY_MODEL_RUNNER_OUTPUT
575
+
576
+ (scheduler_output, attn_metadata, input_ids, hidden_states, logits,
577
+ aux_hidden_states, spec_decode_metadata, kv_connector_output,
578
+ logits_indices_selector,
579
+ padded_num_reqs) = (self.execute_model_state.scheduler_output,
580
+ self.execute_model_state.attn_metadata,
581
+ self.execute_model_state.input_ids,
582
+ self.execute_model_state.hidden_states,
583
+ self.execute_model_state.logits,
584
+ self.execute_model_state.aux_hidden_states,
585
+ self.execute_model_state.spec_decode_metadata,
586
+ self.execute_model_state.kv_connector_output,
587
+ self.execute_model_state.logits_indices_selector,
588
+ self.execute_model_state.padded_num_reqs)
589
+ self.execute_model_state = None
590
+
591
+ if grammar_output is not None:
592
+ (
593
+ require_struct_decoding, grammar_bitmask_padded, arange
594
+ ) = self.structured_decoding_manager.prepare_structured_decoding_input(
595
+ logits, grammar_output)
596
+ logits = self.structured_decoding_manager.structured_decode_fn(
597
+ require_struct_decoding,
598
+ grammar_bitmask_padded,
599
+ logits,
600
+ arange,
601
+ )
602
+ return self._sample_from_logits(
603
+ scheduler_output, attn_metadata, input_ids, hidden_states, logits,
604
+ aux_hidden_states, spec_decode_metadata, kv_connector_output,
605
+ logits_indices_selector, padded_num_reqs)
606
+
607
+ def _modify_prev_results(self):
608
+ # If copy to host has not been done, we just wait.
609
+ # device_get should return immediately as we have scheduled it in previous function call.
610
+ assert self._pre_async_results is not None, "When we call _modify_prev_results(), self._pre_async_results should already exist"
611
+ pre_req_ids = self._pre_async_results.req_ids
612
+ pre_next_tokens = self._pre_async_results.next_tokens
613
+ pre_request_seq_lens = self._pre_async_results.request_seq_lens
614
+ pre_discard_sampled_tokens_req_indices = self._pre_async_results.discard_sampled_tokens_req_indices
615
+ pre_logits_indices_selector = self._pre_async_results.logits_indices_selector
616
+
617
+ next_tokens_cpu = np.asarray(jax.device_get(pre_next_tokens))
618
+ if pre_logits_indices_selector is not None:
619
+ next_tokens_cpu = next_tokens_cpu[pre_logits_indices_selector]
620
+ selected_token_ids = np.expand_dims(next_tokens_cpu[:len(pre_req_ids)],
621
+ 1)
622
+ valid_sampled_token_ids = selected_token_ids.tolist()
623
+
624
+ # Mask out the sampled tokens that should not be sampled.
625
+ for i in pre_discard_sampled_tokens_req_indices:
626
+ valid_sampled_token_ids[i].clear()
627
+ # Append sampled tokens
628
+ for pre_req_idx, req_state, _ in pre_request_seq_lens:
629
+ sampled_ids = valid_sampled_token_ids[pre_req_idx]
630
+ if not sampled_ids:
631
+ continue
632
+
633
+ # If request not active in the *current* batch (e.g. finished or evicted), skip it.
634
+ req_id = pre_req_ids[pre_req_idx]
635
+ if req_id not in self.input_batch.req_id_to_index:
636
+ continue
637
+
638
+ req_idx = self.input_batch.req_id_to_index[req_id]
639
+ assert req_state is self.requests[
640
+ req_id], "The req_state should be valid and identical"
641
+
642
+ # Updated on previous execute
643
+ end_idx = self.input_batch.num_tokens_no_spec[req_idx]
644
+ assert len(sampled_ids) == 1, "do not support spec decode yet"
645
+ start_idx = end_idx - 1
646
+ assert end_idx <= self.max_model_len, (
647
+ "Sampled token IDs exceed the max model length. "
648
+ f"Total number of tokens: {end_idx} > max_model_len: "
649
+ f"{self.max_model_len}")
650
+
651
+ self.input_batch.token_ids_cpu[req_idx,
652
+ start_idx:end_idx] = sampled_ids
653
+ # Replace previous placeholder
654
+ req_state.output_token_ids[-1] = sampled_ids[-1]
655
+
656
+ def _update_placeholder(self,
657
+ discard_sampled_tokens_req_indices,
658
+ request_seq_lens,
659
+ logits_indices_selector=None):
660
+ placeholder_req_id_to_index: dict[str, int] = {}
661
+ discard_sampled_tokens_req_indices_set = set(
662
+ discard_sampled_tokens_req_indices)
663
+ for req_idx, req_state, _ in request_seq_lens:
664
+ if req_idx in discard_sampled_tokens_req_indices_set:
665
+ continue
666
+
667
+ start_idx = self.input_batch.num_tokens_no_spec[req_idx]
668
+ # Not supporting spec decode yet, assume only 1 new token
669
+ end_idx = start_idx + 1
670
+ assert end_idx <= self.max_model_len, (
671
+ "Sampled token IDs exceed the max model length. "
672
+ f"Total number of tokens: {end_idx} > max_model_len: "
673
+ f"{self.max_model_len}")
674
+
675
+ # Update cpu tokens at next execute and prepare input from tpu
676
+ self.input_batch.num_tokens_no_spec[req_idx] = end_idx
677
+ self.input_batch.num_tokens[req_idx] = end_idx
678
+
679
+ # For placeholder, should be update on next execute.
680
+ req_state.output_token_ids.extend([0])
681
+ if logits_indices_selector is None:
682
+ placeholder_req_id_to_index[req_state.req_id] = req_idx
683
+ else:
684
+ placeholder_req_id_to_index[
685
+ req_state.req_id] = logits_indices_selector[req_idx]
686
+ return placeholder_req_id_to_index
687
+
688
+ def _execute_model(
689
+ self,
690
+ scheduler_output: "VllmSchedulerOutput",
691
+ intermediate_tensors: Optional[JaxIntermediateTensors] = None,
692
+ ) -> tuple[AttentionMetadata, JaxIntermediateTensors | ModelRunnerOutput
693
+ | None]:
694
+ self.persistent_batch_manager.update_states(
695
+ scheduler_output, self.get_mrope_input_positions_fn)
696
+ if not scheduler_output.total_num_scheduled_tokens:
697
+ if has_kv_transfer_group():
698
+ return DUMMY_METADATA, self.kv_connector_no_forward(
699
+ scheduler_output, self.vllm_config)
700
+
701
+ # Return empty ModelRunnerOutput if there's no work to do.
702
+ # TODO(fhzhang): We rely on empty cycles to remove requests in input batch. Fix it to reduce overhead.
703
+ logger.debug(f"Nothing scheduled: {scheduler_output}!")
704
+ # NOTE(pooyam): There is no guarantee that scheduler is not sending empty output: https://github.com/vllm-project/vllm/blob/7cfea0df390c154c1026f77d3682e2733ca4aca8/vllm/v1/engine/core.py#L275
705
+ # Why they are not preventing that is not clear to me.
706
+ if len(scheduler_output.finished_req_ids) == 0:
707
+ logger.warning(
708
+ "Should not schedule a request that does nothing!")
709
+ # raise Exception(
710
+ # "Should not schedule a request that does nothing!")
711
+ return DUMMY_METADATA, EMPTY_MODEL_RUNNER_OUTPUT
712
+
713
+ # TODO(pooyam): I guess we can remove returning sampling_metadata in `_prepare_inputs` after https://github.com/njhill/vllm/commit/b7433ca1a47732394b1bdea4099d98389515954b
714
+ (
715
+ input_ids,
716
+ input_positions,
717
+ attn_metadata,
718
+ _,
719
+ logits_indices,
720
+ spec_decode_metadata,
721
+ logits_indices_selector,
722
+ padded_num_reqs,
723
+ ) = self._prepare_inputs(scheduler_output)
724
+
725
+ is_llama_guard_4 = (
726
+ hasattr(
727
+ self.model_config.hf_config, "architectures"
728
+ ) #TODO: Remove Llama Guard 4 specific condition once the LG4 Vision portion is implemented
729
+ and len(self.model_config.hf_config.architectures) >= 1
730
+ and self.model_config.hf_config.architectures[0]
731
+ == "Llama4ForConditionalGeneration")
732
+
733
+ # multi-modal support
734
+ if self.is_multimodal_model:
735
+ # Run the multimodal encoder if any.
736
+ # We have the modality embeds at this time.
737
+ self.mm_manager.execute_mm_encoder(scheduler_output)
738
+ mm_embeds = self.mm_manager.gather_mm_embeddings(
739
+ scheduler_output, input_ids.shape[0])
740
+ #TODO: Remove the follow elif statement once Llama Guard 4 Vision portion has been implemented
741
+ elif is_llama_guard_4 and any(
742
+ self.mm_manager.runner.requests[req_id].mm_features
743
+ for req_id in self.mm_manager.runner.input_batch.req_ids):
744
+ raise NotImplementedError(
745
+ "Llama Guard 4 (JAX) currently supports only text inputs. "
746
+ "Multimodal processing not yet implemented.")
747
+ else:
748
+ mm_embeds = []
749
+
750
+ # NOTE(Wenlong): For multi-modal model,
751
+ # it will embed the text tokens and merge with the existing modality embeds
752
+ # Later, the multi-modality model will take the embedding as the input.
753
+ # For text-only model, this does nothing. It will input the input_ids and
754
+ # leave the mebedding job inside the forward pass
755
+ input_ids, inputs_embeds = self._get_input_ids_embeds(
756
+ input_ids, mm_embeds)
757
+
758
+ lora_metadata = self.lora_utils.extract_lora_metadata()
759
+ # TODO: make _get_input_ids_embeds within this context
760
+ # NOTE: right now, mm model will use embeddings as the input,
761
+ # but text-only model will use input_ids
762
+ with self.maybe_forbid_compile:
763
+
764
+ with set_forward_context(
765
+ None,
766
+ self.vllm_config,
767
+ ), self.maybe_get_kv_connector_output(
768
+ scheduler_output) as kv_connector_output:
769
+ # NOTE(Wenlong): It takes both `input_ids` and `inputs_embeds`,
770
+ # but one of them would be `None`
771
+ (self.kv_caches, hidden_states,
772
+ aux_hidden_states) = self.model_fn(
773
+ self.state,
774
+ self.kv_caches,
775
+ input_ids,
776
+ attn_metadata,
777
+ inputs_embeds,
778
+ input_positions,
779
+ tuple(self.layer_name_to_kvcache_index.items()),
780
+ lora_metadata,
781
+ intermediate_tensors,
782
+ self.is_first_rank,
783
+ self.is_last_rank,
784
+ )
785
+ if not get_pp_group().is_last_rank:
786
+ assert isinstance(hidden_states, JaxIntermediateTensors)
787
+ hidden_states.kv_connector_output = kv_connector_output
788
+ return attn_metadata, hidden_states
789
+ hidden_states = self._select_from_array_fn(hidden_states,
790
+ logits_indices)
791
+ logits = self.compute_logits_fn(
792
+ self.state,
793
+ hidden_states,
794
+ lora_metadata,
795
+ )
796
+
797
+ self.execute_model_state = ExecuteModelState(
798
+ scheduler_output=scheduler_output,
799
+ attn_metadata=attn_metadata,
800
+ input_ids=input_ids,
801
+ hidden_states=hidden_states,
802
+ logits=logits,
803
+ aux_hidden_states=aux_hidden_states,
804
+ spec_decode_metadata=spec_decode_metadata,
805
+ kv_connector_output=kv_connector_output,
806
+ logits_indices_selector=logits_indices_selector,
807
+ padded_num_reqs=padded_num_reqs)
808
+ return attn_metadata, None
809
+
810
+ def _sample_from_logits(
811
+ self,
812
+ scheduler_output: "VllmSchedulerOutput",
813
+ attn_metadata: AttentionMetadata,
814
+ input_ids: Optional[jax.Array],
815
+ hidden_states: jax.Array,
816
+ logits: jax.Array,
817
+ aux_hidden_states: Optional[jax.Array],
818
+ spec_decode_metadata: Optional[SpecDecodeMetadata],
819
+ kv_connector_output: Optional[KVConnectorOutput],
820
+ logits_indices_selector: Optional[List[int]] = None,
821
+ padded_num_reqs: Optional[int] = None,
822
+ ) -> ModelRunnerOutput | AsyncTPUModelRunnerOutput:
823
+ if padded_num_reqs is None:
824
+ padded_num_reqs = runner_utils.get_padded_num_reqs_with_upper_limit(
825
+ self.input_batch.num_reqs, self.max_num_reqs)
826
+
827
+ sharding = None
828
+ if self.dp_size > 1:
829
+ sharding = NamedSharding(self.mesh,
830
+ PartitionSpec(ShardingAxisName.ATTN_DATA))
831
+
832
+ tpu_sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch(
833
+ self.mesh, self.input_batch, padded_num_reqs, sharding=sharding)
834
+
835
+ # TODO(pooyam): Should we move this to `_prepare_inputs`?
836
+ if tpu_sampling_metadata.do_sampling:
837
+ self.rng_params_for_sampling, step_rng = jax.random.split(
838
+ self.rng_params_for_sampling)
839
+ else:
840
+ step_rng = self.rng_params_for_sampling
841
+
842
+ if spec_decode_metadata is None:
843
+ next_tokens = sample(
844
+ step_rng,
845
+ self.mesh,
846
+ logits,
847
+ tpu_sampling_metadata,
848
+ )
849
+ else:
850
+ if tpu_sampling_metadata.do_sampling:
851
+ bonus_rng, rejection_rng = jax.random.split(step_rng)
852
+ else:
853
+ bonus_rng = step_rng
854
+ rejection_rng = step_rng
855
+ bonus_logits = self._select_from_array_fn(
856
+ logits, spec_decode_metadata.bonus_logits_indices)
857
+ bonus_token_ids = sample(
858
+ bonus_rng,
859
+ self.mesh,
860
+ bonus_logits,
861
+ tpu_sampling_metadata,
862
+ )
863
+ target_logits = self._select_from_array_fn(
864
+ logits, spec_decode_metadata.target_logits_indices)
865
+ next_tokens = self.rejection_sampler(
866
+ draft_token_ids=spec_decode_metadata.draft_token_ids,
867
+ num_draft_tokens=spec_decode_metadata.draft_lengths,
868
+ draft_probs=None,
869
+ target_logits=target_logits,
870
+ bonus_token_ids=bonus_token_ids,
871
+ sampling_metadata=tpu_sampling_metadata,
872
+ key=rejection_rng,
873
+ )
874
+
875
+ if tpu_sampling_metadata.logprobs:
876
+ logprobs = self._compute_and_gather_logprobs(
877
+ logits, next_tokens, self.model_config.max_logprobs)
878
+ else:
879
+ logprobs = None
880
+
881
+ num_reqs = self.input_batch.num_reqs
882
+
883
+ # Update the cache state concurrently. Code above will not block until
884
+ # We use `selected_token_ids`. Add mark_step if post-processing changes
885
+ request_seq_lens: list[tuple[int, CachedRequestState, int]] = []
886
+ discard_sampled_tokens_req_indices = []
887
+ for i, req_id in zip(range(num_reqs), self.input_batch.req_ids):
888
+ assert req_id is not None
889
+ req_state = self.requests[req_id]
890
+ seq_len = (req_state.num_computed_tokens +
891
+ scheduler_output.num_scheduled_tokens[req_id])
892
+ if seq_len >= req_state.num_tokens:
893
+ request_seq_lens.append((i, req_state, seq_len))
894
+ else:
895
+ # Ignore the sampled token from the partial request.
896
+ # Rewind the generator state as if the token was not sampled.
897
+ generator = self.input_batch.generators.get(i)
898
+ if generator is not None:
899
+ # This relies on cuda-specific torch-internal impl details
900
+ generator.set_offset(generator.get_offset() - 4)
901
+
902
+ # Record the index of the request that should not be sampled,
903
+ # so that we could clear the sampled tokens before returning.
904
+ discard_sampled_tokens_req_indices.append(i)
905
+
906
+ assert all(
907
+ req_id is not None for req_id in
908
+ self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
909
+ req_ids = cast(list[str], self.input_batch.req_ids[:num_reqs])
910
+
911
+ prompt_logprobs_dict = {}
912
+ for req_id in self.input_batch.req_ids[:num_reqs]:
913
+ prompt_logprobs_dict[req_id] = None
914
+
915
+ # If async scheduler enabled
916
+ if self.scheduler_config.async_scheduling:
917
+ # Get previous results from TPU and replace the placeholder.
918
+ if self._pre_async_results is not None:
919
+ assert not self.speculative_config and spec_decode_metadata is None, "Async scheduler does not support speculative decoding yet."
920
+ self._modify_prev_results()
921
+
922
+ # Set placeholder for next tokens that is not yet generated
923
+ placeholder_req_id_to_index: dict[
924
+ str, int] = self._update_placeholder(
925
+ discard_sampled_tokens_req_indices, request_seq_lens,
926
+ logits_indices_selector)
927
+
928
+ if logprobs is not None:
929
+ # Map logprobs back to the pre-dp shuffling order
930
+ logprobs_lists = _jax_logprobs_to_lists(
931
+ logprobs, logits_indices_selector)
932
+
933
+ else:
934
+ logprobs_lists = None
935
+
936
+ # Save the previous results
937
+ next_tokens = jax.copy_to_host_async(next_tokens)
938
+ self._pre_async_results = AsyncPreResults(
939
+ req_ids=req_ids,
940
+ next_tokens=next_tokens,
941
+ request_seq_lens=request_seq_lens,
942
+ discard_sampled_tokens_req_indices=
943
+ discard_sampled_tokens_req_indices,
944
+ placeholder_req_id_to_index=placeholder_req_id_to_index,
945
+ logits_indices_selector=logits_indices_selector)
946
+
947
+ # Return Model output to executor
948
+ model_runner_output = ModelRunnerOutput(
949
+ req_ids=req_ids,
950
+ req_id_to_index=copy.deepcopy(
951
+ self.input_batch.req_id_to_index),
952
+ sampled_token_ids=[], # Fill in async get
953
+ logprobs=logprobs_lists,
954
+ prompt_logprobs_dict=prompt_logprobs_dict,
955
+ pooler_output=[],
956
+ kv_connector_output=kv_connector_output,
957
+ )
958
+ # Return async_model_runner_output
959
+ async_model_runner_output = AsyncTPUModelRunnerOutput(
960
+ model_runner_output, next_tokens, num_reqs,
961
+ discard_sampled_tokens_req_indices, logits_indices_selector)
962
+ return async_model_runner_output
963
+
964
+ if spec_decode_metadata is None:
965
+ next_tokens = np.asarray(jax.device_get(next_tokens))
966
+ # Map tokens back to the pre-dp shuffling order
967
+ if logits_indices_selector is not None:
968
+ next_tokens = next_tokens[logits_indices_selector]
969
+ selected_token_ids = np.expand_dims(next_tokens[:num_reqs], 1)
970
+ valid_sampled_token_ids = selected_token_ids.tolist()
971
+ else:
972
+ valid_sampled_token_ids = self.rejection_sampler.parse_output(
973
+ next_tokens, self.input_batch.vocab_size,
974
+ spec_decode_metadata.draft_lengths_cpu, num_reqs,
975
+ spec_decode_metadata.draft_token_ids.shape[0])
976
+
977
+ # Mask out the sampled tokens that should not be sampled.
978
+ for i in discard_sampled_tokens_req_indices:
979
+ valid_sampled_token_ids[i].clear()
980
+ # Append sampled tokens
981
+ for req_idx, req_state, _ in request_seq_lens:
982
+ sampled_ids = valid_sampled_token_ids[req_idx]
983
+ if not sampled_ids:
984
+ continue
985
+
986
+ start_idx = self.input_batch.num_tokens_no_spec[req_idx]
987
+ end_idx = start_idx + len(sampled_ids)
988
+ assert end_idx <= self.max_model_len, (
989
+ "Sampled token IDs exceed the max model length. "
990
+ f"Total number of tokens: {end_idx} > max_model_len: "
991
+ f"{self.max_model_len}")
992
+
993
+ self.input_batch.token_ids_cpu[req_idx,
994
+ start_idx:end_idx] = sampled_ids
995
+ self.input_batch.num_tokens_no_spec[req_idx] = end_idx
996
+ self.input_batch.num_tokens[req_idx] = end_idx
997
+ req_state.output_token_ids.extend(sampled_ids)
998
+
999
+ if logprobs is not None:
1000
+ # Map logprobs back to the pre-dp shuffling order
1001
+ logprobs_lists = _jax_logprobs_to_lists(logprobs,
1002
+ logits_indices_selector)
1003
+ else:
1004
+ logprobs_lists = None
1005
+
1006
+ if self.speculative_config:
1007
+ with self.maybe_forbid_compile:
1008
+ self.speculative_decoding_manager.propose_draft_token_ids(
1009
+ valid_sampled_token_ids,
1010
+ aux_hidden_states,
1011
+ attn_metadata,
1012
+ spec_decode_metadata,
1013
+ scheduler_output,
1014
+ input_ids,
1015
+ )
1016
+
1017
+ model_runner_output = ModelRunnerOutput(
1018
+ req_ids=req_ids,
1019
+ req_id_to_index=self.input_batch.req_id_to_index,
1020
+ sampled_token_ids=valid_sampled_token_ids,
1021
+ logprobs=logprobs_lists,
1022
+ prompt_logprobs_dict=prompt_logprobs_dict,
1023
+ pooler_output=[],
1024
+ kv_connector_output=kv_connector_output,
1025
+ )
1026
+ return model_runner_output
1027
+
1028
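The async-scheduling branch above returns a placeholder ModelRunnerOutput while the device-to-host copy of next_tokens is still in flight; the sampled ids are read back (and placeholders substituted) on a later step. Below is a minimal sketch of that deferred-copy pattern using only standard JAX APIs; sample_step and the shapes are made up for illustration and are not part of the package.

import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical stand-in for the jitted sampling computation.
sample_step = jax.jit(lambda logits: jnp.argmax(logits, axis=-1))

logits = jnp.ones((4, 8))
next_tokens = sample_step(logits)            # stays on device for now
next_tokens.copy_to_host_async()             # start the D2H copy without blocking
# ... CPU-side bookkeeping for the current step would run here ...
host_tokens = np.asarray(jax.device_get(next_tokens))  # cheap if the copy already landed
print(host_tokens)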
+ @functools.partial(jax.jit, static_argnums=(0, ))
1029
+ def _select_from_array_fn(self, array, indices_to_select):
1030
+
1031
+ def select_local_fn(local_array, local_indices):
1032
+ return local_array[local_indices]
1033
+
1034
+ ret = jax.shard_map(
1035
+ select_local_fn,
1036
+ mesh=self.mesh,
1037
+ in_specs=(PartitionSpec(ShardingAxisName.ATTN_DATA),
1038
+ PartitionSpec(ShardingAxisName.ATTN_DATA)),
1039
+ out_specs=PartitionSpec(ShardingAxisName.ATTN_DATA))(
1040
+ array, indices_to_select)
1041
+
1042
+ return ret
1043
+
1044
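_select_from_array_fn above wraps a plain gather in shard_map so each data-parallel shard indexes only its local slice of the array. A minimal sketch of the same pattern, assuming a one-axis device mesh named "data" in place of ShardingAxisName.ATTN_DATA and made-up per-shard sizes:

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec
from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()), ("data",))

def select_local(local_array, local_indices):
    # Runs once per shard; indices are relative to the local slice.
    return local_array[local_indices]

select = shard_map(select_local, mesh=mesh,
                   in_specs=(PartitionSpec("data"), PartitionSpec("data")),
                   out_specs=PartitionSpec("data"))

n_dev = len(jax.devices())
array = jnp.arange(8 * n_dev)                 # 8 elements per shard
indices = jnp.tile(jnp.array([0, 2]), n_dev)  # 2 local indices per shard
print(select(array, indices))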
+ @staticmethod
1045
+ @functools.partial(jax.jit, static_argnames=("max_logprobs", ))
1046
+ def _compute_and_gather_logprobs(logits, next_tokens, max_logprobs):
1047
+ logprobs = compute_logprobs(logits)
1048
+ return gather_logprobs(logprobs, next_tokens, max_logprobs)
1049
+
1050
+ def _prepare_dp_input_metadata(self,
1051
+ scheduler_output: "VllmSchedulerOutput"):
1052
+
1053
+ dp_size = self.dp_size
1054
+ num_reqs = self.input_batch.num_reqs
1055
+ max_num_reqs_per_dp_rank = self.max_num_reqs // dp_size
1056
+ req_ids_dp = {dp_rank: [] for dp_rank in range(dp_size)}
1057
+ req_indices_dp = {dp_rank: [] for dp_rank in range(dp_size)}
1058
+ num_scheduled_tokens_per_dp_rank = {
1059
+ dp_rank: 0
1060
+ for dp_rank in range(dp_size)
1061
+ }
1062
+ scheduled_tokens_per_dp_rank = {
1063
+ dp_rank: []
1064
+ for dp_rank in range(dp_size)
1065
+ }
1066
+ num_req_per_dp_rank = {dp_rank: 0 for dp_rank in range(dp_size)}
1067
+
1068
+ for req_id in self.input_batch.req_ids[:num_reqs]:
1069
+ dp_rank = scheduler_output.assigned_dp_rank[req_id]
1070
+ req_ids_dp[dp_rank].append(req_id)
1071
+ req_indices_dp[dp_rank].append(
1072
+ self.input_batch.req_id_to_index[req_id])
1073
+ num_scheduled_tokens_per_dp_rank[
1074
+ dp_rank] += scheduler_output.num_scheduled_tokens[req_id]
1075
+ scheduled_tokens_per_dp_rank[dp_rank].append(
1076
+ scheduler_output.num_scheduled_tokens[req_id])
1077
+ num_req_per_dp_rank[dp_rank] += 1
1078
+
1079
+ # Find maximum number of scheduled tokens across DP ranks
1080
+ max_num_scheduled_tokens_across_dp = max(
1081
+ num_scheduled_tokens_per_dp_rank.values())
1082
+
1083
+ padded_num_scheduled_tokens_per_dp_rank = runner_utils.get_padded_token_len(
1084
+ self.num_tokens_paddings_per_dp,
1085
+ max_num_scheduled_tokens_across_dp)
1086
+
1087
+ padded_total_num_scheduled_tokens = (
1088
+ padded_num_scheduled_tokens_per_dp_rank * dp_size)
1089
+
1090
+ assert max_num_scheduled_tokens_across_dp > 0
1091
+
1092
+ # Find maximum number of requests across DP ranks
1093
+ max_num_reqs_across_dp = max(
1094
+ len(req_ids) for req_ids in req_ids_dp.values())
1095
+ padded_num_reqs_per_dp_rank = runner_utils.get_padded_token_len(
1096
+ self.num_reqs_paddings_per_dp, max_num_reqs_across_dp)
1097
+ padded_num_reqs = padded_num_reqs_per_dp_rank * dp_size
1098
+
1099
+ all_req_indices = np.concatenate(
1100
+ [req_indices_dp[dp_rank] for dp_rank in range(dp_size)])
1101
+ all_positions = np.concatenate([
1102
+ np.arange(len(req_indices_dp[dp_rank])) +
1103
+ padded_num_reqs_per_dp_rank * dp_rank for dp_rank in range(dp_size)
1104
+ ])
1105
+
1106
+ # Sort positions by request indices
1107
+ sorted_indices = np.argsort(all_req_indices)
1108
+ logits_indices_selector = all_positions[sorted_indices]
1109
+
1110
+ return (req_ids_dp, req_indices_dp, num_scheduled_tokens_per_dp_rank,
1111
+ scheduled_tokens_per_dp_rank, num_req_per_dp_rank,
1112
+ padded_num_scheduled_tokens_per_dp_rank, padded_num_reqs,
1113
+ padded_total_num_scheduled_tokens, padded_num_reqs_per_dp_rank,
1114
+ logits_indices_selector, max_num_reqs_per_dp_rank)
1115
+
1116
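The selector returned above lets results computed in per-rank padded order be mapped back to the original request order: requests are laid out rank by rank with padding, and an argsort over their original batch indices recovers where each request landed. A small NumPy illustration with made-up rank assignments:

import numpy as np

padded_num_reqs_per_dp_rank = 4
dp_size = 2
req_indices_dp = {0: [2, 0], 1: [1, 3]}   # original batch index of each request, per rank

all_req_indices = np.concatenate([req_indices_dp[r] for r in range(dp_size)])
all_positions = np.concatenate([
    np.arange(len(req_indices_dp[r])) + padded_num_reqs_per_dp_rank * r
    for r in range(dp_size)
])
logits_indices_selector = all_positions[np.argsort(all_req_indices)]
print(logits_indices_selector)   # [1 4 0 5]: padded row holding each original request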
+ def _prepare_async_token_substitution_indices_dp(
1117
+ self, req_ids_dp, scheduled_tokens_per_dp_rank,
1118
+ padded_num_scheduled_tokens_per_dp_rank, dp_size):
1119
+ """Prepare token substitution indices for async scheduling in DP mode."""
1120
+ token_in_tpu_cur_input_indices_dp = {}
1121
+ token_in_tpu_pre_next_tokens_indices_dp = {}
1122
+
1123
+ for dp_rank in range(dp_size):
1124
+ token_in_tpu_cur_input_indices_dp[dp_rank] = []
1125
+ token_in_tpu_pre_next_tokens_indices_dp[dp_rank] = []
1126
+
1127
+ token_offset = padded_num_scheduled_tokens_per_dp_rank * dp_rank
1128
+ acc_cur_len = token_offset
1129
+
1130
+ for i, req_id in enumerate(req_ids_dp[dp_rank]):
1131
+ acc_cur_len += scheduled_tokens_per_dp_rank[dp_rank][i]
1132
+ if req_id not in self._pre_async_results.placeholder_req_id_to_index:
1133
+ continue
1134
+
1135
+ token_in_tpu_cur_input_indices_dp[dp_rank].append(acc_cur_len -
1136
+ 1)
1137
+ token_in_tpu_pre_next_tokens_indices_dp[dp_rank].append(
1138
+ self._pre_async_results.placeholder_req_id_to_index[req_id]
1139
+ )
1140
+
1141
+ return token_in_tpu_cur_input_indices_dp, token_in_tpu_pre_next_tokens_indices_dp
1142
+
1143
+ def _prepare_async_token_substitution_indices_non_dp(
1144
+ self, num_reqs, num_scheduled_tokens_per_req):
1145
+ """Prepare token substitution indices for async scheduling in non-DP mode."""
1146
+ token_in_tpu_cur_input_indices_list = []
1147
+ token_in_tpu_pre_next_tokens_indices_list = []
1148
+ acc_cur_len = 0
1149
+
1150
+ for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]):
1151
+ acc_cur_len += num_scheduled_tokens_per_req[i]
1152
+ assert req_id is not None
1153
+ if req_id not in self._pre_async_results.placeholder_req_id_to_index:
1154
+ continue
1155
+
1156
+ token_in_tpu_cur_input_indices_list.append(acc_cur_len - 1)
1157
+ token_in_tpu_pre_next_tokens_indices_list.append(
1158
+ self._pre_async_results.placeholder_req_id_to_index[req_id])
1159
+
1160
+ if len(token_in_tpu_cur_input_indices_list) > 0:
1161
+ return (np.array(token_in_tpu_cur_input_indices_list),
1162
+ np.array(token_in_tpu_pre_next_tokens_indices_list))
1163
+ else:
1164
+ return np.array([]), np.array([])
1165
+
1166
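Both substitution-index helpers walk the scheduled requests while accumulating token counts, so they can point at the last input slot of each request whose previous token is still a placeholder. A standalone NumPy sketch of the non-DP bookkeeping, with made-up request ids and lengths:

import numpy as np

req_ids = ["a", "b", "c"]
num_scheduled = [1, 1, 4]                  # "a"/"b" are decodes, "c" is a prefill
placeholder_req_id_to_index = {"a": 0, "b": 1}

cur_input_indices, pre_next_token_indices = [], []
acc = 0
for i, req_id in enumerate(req_ids):
    acc += num_scheduled[i]
    if req_id not in placeholder_req_id_to_index:
        continue
    cur_input_indices.append(acc - 1)      # last input slot scheduled for this request
    pre_next_token_indices.append(placeholder_req_id_to_index[req_id])

print(np.array(cur_input_indices), np.array(pre_next_token_indices))  # [0 1] [0 1]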
+ def _apply_async_token_substitution(self, input_ids,
1167
+ token_in_tpu_cur_input_indices,
1168
+ token_in_tpu_pre_next_tokens_indices):
1169
+ """Apply async token substitution if needed."""
1170
+ if len(token_in_tpu_cur_input_indices) == 0:
1171
+ return input_ids
1172
+
1173
+ idx_pad_len = len(input_ids) - len(token_in_tpu_cur_input_indices)
1174
+
1175
+ # Pad as described inside self._substitute_placeholder_token_fn
1176
+ full_range = np.arange(0, len(input_ids))
1177
+ missing_values = np.setdiff1d(full_range,
1178
+ token_in_tpu_cur_input_indices)
1179
+ padded_token_in_tpu_cur_input_indices = np.concatenate(
1180
+ (token_in_tpu_cur_input_indices, missing_values))
1181
+
1182
+ padded_token_in_tpu_pre_next_tokens_indices = np.pad(
1183
+ token_in_tpu_pre_next_tokens_indices, (0, idx_pad_len),
1184
+ mode='constant',
1185
+ constant_values=-1)
1186
+
1187
+ (padded_token_in_tpu_cur_input_indices,
1188
+ padded_token_in_tpu_pre_next_tokens_indices) = device_array(
1189
+ self.mesh, (padded_token_in_tpu_cur_input_indices,
1190
+ padded_token_in_tpu_pre_next_tokens_indices))
1191
+
1192
+ with self.maybe_forbid_compile:
1193
+ input_ids = self._substitute_placeholder_token_fn(
1194
+ input_ids, padded_token_in_tpu_cur_input_indices,
1195
+ padded_token_in_tpu_pre_next_tokens_indices,
1196
+ self._pre_async_results.next_tokens,
1197
+ len(token_in_tpu_cur_input_indices))
1198
+
1199
+ return input_ids
1200
+
1201
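The substitution itself runs on device with fixed-size index arrays: the current-input indices are padded with the unused positions and the previous-token indices with -1, so the compiled scatter shape never changes. The real kernel is self._substitute_placeholder_token_fn, which is not shown here; `substitute` below is a hypothetical stand-in that only sketches the intended effect under those padding assumptions:

import jax.numpy as jnp

def substitute(input_ids, cur_indices, pre_indices, prev_next_tokens, num_valid):
    # Gather previously sampled tokens; padded slots keep their current value.
    updates = prev_next_tokens[pre_indices]
    keep = jnp.arange(cur_indices.shape[0]) < num_valid
    updates = jnp.where(keep, updates, input_ids[cur_indices])
    return input_ids.at[cur_indices].set(updates)

input_ids = jnp.zeros(8, dtype=jnp.int32)
prev_next_tokens = jnp.array([7, 9], dtype=jnp.int32)
cur = jnp.array([0, 1, 2, 3, 4, 5, 6, 7])        # first 2 are real, rest is padding
pre = jnp.array([0, 1, -1, -1, -1, -1, -1, -1])  # -1 marks padded slots
print(substitute(input_ids, cur, pre, prev_next_tokens, 2))  # [7 9 0 0 0 0 0 0]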
+ def _prepare_inputs(self, scheduler_output: "VllmSchedulerOutput"):
1202
+ if self.dp_size > 1:
1203
+ return self._prepare_inputs_dp(scheduler_output)
1204
+ else:
1205
+ return self._prepare_inputs_non_dp(scheduler_output)
1206
+
1207
+ def _prepare_inputs_dp(self, scheduler_output: "VllmSchedulerOutput"):
1208
+ total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
1209
+ assert total_num_scheduled_tokens > 0
1210
+ num_reqs = self.input_batch.num_reqs
1211
+ assert num_reqs > 0
1212
+
1213
+ dp_size = self.dp_size
1214
+ data_parallel_attn_sharding = NamedSharding(
1215
+ self.mesh, PartitionSpec(ShardingAxisName.ATTN_DATA))
1216
+
1217
+ (req_ids_dp, req_indices_dp, num_scheduled_tokens_per_dp_rank,
1218
+ scheduled_tokens_per_dp_rank, num_req_per_dp_rank,
1219
+ padded_num_scheduled_tokens_per_dp_rank, padded_num_reqs,
1220
+ padded_total_num_scheduled_tokens, padded_num_reqs_per_dp_rank,
1221
+ logits_indices_selector, max_num_reqs_per_dp_rank
1222
+ ) = self._prepare_dp_input_metadata(scheduler_output)
1223
+ # Multi-modal support
1224
+ # Calculate M-RoPE positions.
1225
+ # Only relevant for models using M-RoPE (e.g., Qwen2-VL)
1226
+ if self.uses_mrope:
1227
+ self.mm_manager.calc_mrope_positions(scheduler_output)
1228
+
1229
+ # Async scheduling: prepare token substitution indices for DP
1230
+ token_in_tpu_cur_input_indices_dp = {}
1231
+ token_in_tpu_pre_next_tokens_indices_dp = {}
1232
+ if self.scheduler_config.async_scheduling and self._pre_async_results is not None:
1233
+ # If previous async results exist, prepare the token substitution indices here.
1234
+ # The actual substitution is performed on the TPU later in this function.
1235
+ (token_in_tpu_cur_input_indices_dp,
1236
+ token_in_tpu_pre_next_tokens_indices_dp
1237
+ ) = self._prepare_async_token_substitution_indices_dp(
1238
+ req_ids_dp, scheduled_tokens_per_dp_rank,
1239
+ padded_num_scheduled_tokens_per_dp_rank, dp_size)
1240
+
1241
+ # Populates input_ids and positions
1242
+ for dp_rank in range(dp_size):
1243
+ if num_req_per_dp_rank[dp_rank] == 0:
1244
+ continue
1245
+ token_offset = padded_num_scheduled_tokens_per_dp_rank * dp_rank
1246
+ num_scheduled_tokens_per_req = scheduled_tokens_per_dp_rank[
1247
+ dp_rank]
1248
+ total_num_scheduled_tokens = num_scheduled_tokens_per_dp_rank[
1249
+ dp_rank]
1250
+ input_ids_cpu = self.input_ids_cpu[
1251
+ token_offset:token_offset +
1252
+ padded_num_scheduled_tokens_per_dp_rank]
1253
+ positions_cpu = self.positions_cpu[
1254
+ token_offset:token_offset +
1255
+ padded_num_scheduled_tokens_per_dp_rank]
1256
+ # Get request indices.
1257
+ # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
1258
+ # For each scheduled token, what is the corresponding req index.
1259
+ req_indices = np.repeat(req_indices_dp[dp_rank],
1260
+ num_scheduled_tokens_per_req)
1261
+ # Get batched arange.
1262
+ # E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
1263
+ # For each scheduled token, what is its position in the corresponding req.
1264
+ arange = np.concatenate(
1265
+ [self.arange_cpu[:n] for n in num_scheduled_tokens_per_req])
1266
+ # Get positions.
1267
+ positions_np = positions_cpu[:total_num_scheduled_tokens]
1268
+ np.add(
1269
+ self.input_batch.num_computed_tokens_cpu[req_indices],
1270
+ arange,
1271
+ out=positions_np,
1272
+ )
1273
+ # Get token indices.
1274
+ # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
1275
+ # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
1276
+ # where M is the max_model_len.
1277
+ token_indices = (
1278
+ positions_np +
1279
+ req_indices * self.input_batch.token_ids_cpu.shape[1])
1280
+ # NOTE: np.take on the flattened CPU token buffer gathers all the
1281
+ # scheduled token ids in one vectorized call, avoiding a Python loop
1282
+ # over requests.
1283
+ np.take(
1284
+ self.input_batch.token_ids_cpu.ravel(),
1285
+ token_indices,
1286
+ out=input_ids_cpu[:total_num_scheduled_tokens],
1287
+ )
1288
+
1289
+ input_ids_cpu[total_num_scheduled_tokens:] = 0
1290
+
1291
+ # Prepare the attention metadata (query_start_loc_cpu, seq_lens_cpu)
1292
+ for dp_rank in range(dp_size):
1293
+ req_offset = dp_rank * max_num_reqs_per_dp_rank
1294
+ query_start_loc_cpu = self.query_start_loc_cpu[
1295
+ req_offset + dp_rank:req_offset + max_num_reqs_per_dp_rank +
1296
+ dp_rank + 1]
1297
+ seq_lens_cpu = self.seq_lens_cpu[req_offset:req_offset +
1298
+ max_num_reqs_per_dp_rank]
1299
+ _num_reqs = num_req_per_dp_rank[dp_rank]
1300
+ req_indices = req_indices_dp[dp_rank]
1301
+ num_scheduled_tokens_per_req = scheduled_tokens_per_dp_rank[
1302
+ dp_rank]
1303
+
1304
+ if _num_reqs == 0:
1305
+ query_start_loc_cpu[:] = 0
1306
+ seq_lens_cpu[:] = 0
1307
+ continue
1308
+
1309
+ np.cumsum(
1310
+ num_scheduled_tokens_per_req,
1311
+ out=query_start_loc_cpu[1:_num_reqs + 1],
1312
+ )
1313
+ query_start_loc_cpu[_num_reqs + 1:] = 1
1314
+
1315
+ seq_lens_cpu[:_num_reqs] = (
1316
+ self.input_batch.num_computed_tokens_cpu[req_indices] +
1317
+ num_scheduled_tokens_per_req)
1318
+ seq_lens_cpu[_num_reqs:] = 0
1319
+
1320
+ # populate logits_indices
1321
+ for dp_rank in range(dp_size):
1322
+ req_offset = dp_rank * padded_num_reqs_per_dp_rank
1323
+ query_loc_req_offset = dp_rank * (max_num_reqs_per_dp_rank + 1)
1324
+ _num_reqs = num_req_per_dp_rank[dp_rank]
1325
+
1326
+ logits_indices_cpu = self.logits_indices_cpu[
1327
+ req_offset:req_offset + padded_num_reqs_per_dp_rank]
1328
+ logits_indices_cpu[:_num_reqs] = (
1329
+ self.query_start_loc_cpu[query_loc_req_offset +
1330
+ 1:query_loc_req_offset + _num_reqs +
1331
+ 1] - 1)
1332
+ logits_indices_cpu[_num_reqs:] = -1
1333
+
1334
+ logits_indices = self.logits_indices_cpu[:padded_num_reqs]
1335
+
1336
+ # Please see runner_utils.PhasedBasedProfiler for details
1337
+ if self.phase_based_profiler:
1338
+ batch_composition_stats = runner_utils.get_batch_composition_stats(
1339
+ self.input_batch, total_num_scheduled_tokens, num_reqs,
1340
+ padded_total_num_scheduled_tokens, scheduler_output)
1341
+
1342
+ self.phase_based_profiler.step(batch_composition_stats)
1343
+
1344
+ # Inputs
1345
+ input_ids = self.input_ids_cpu[:padded_total_num_scheduled_tokens]
1346
+ positions = self.positions_cpu[:padded_total_num_scheduled_tokens]
1347
+ mrope_positions = self.mrope_positions_cpu[:, :
1348
+ padded_total_num_scheduled_tokens]
1349
+
1350
+ query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs +
1351
+ dp_size]
1352
+ seq_lens = self.seq_lens_cpu[:self.max_num_reqs]
1353
+
1354
+ _request_distribution = []
1355
+ for dp_rank in range(dp_size):
1356
+ _num_reqs = num_req_per_dp_rank[dp_rank]
1357
+ # The batch has been reordered by _reorder_batch so decode requests come first
1358
+ # Count decode requests (those with num_scheduled_tokens == 1) in this DP rank
1359
+ num_decode_in_dp_rank = 0
1360
+ for req_id in req_ids_dp[dp_rank]:
1361
+ if scheduler_output.num_scheduled_tokens[req_id] == 1:
1362
+ num_decode_in_dp_rank += 1
1363
+ _request_distribution.append(
1364
+ [num_decode_in_dp_rank, num_decode_in_dp_rank, _num_reqs])
1365
+ request_distribution = np.array(_request_distribution).ravel()
1366
+
1367
+ use_spec_decode = len(
1368
+ scheduler_output.scheduled_spec_decode_tokens) > 0
1369
+ if not use_spec_decode:
1370
+ spec_decode_metadata = None
1371
+ else:
1372
+ num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
1373
+ for (
1374
+ req_id,
1375
+ draft_token_ids,
1376
+ ) in scheduler_output.scheduled_spec_decode_tokens.items():
1377
+ req_idx = self.input_batch.req_id_to_index[req_id]
1378
+ num_draft_tokens[req_idx] = len(draft_token_ids)
1379
+
1380
+ spec_decode_metadata = (
1381
+ self.speculative_decoding_manager.get_spec_decode_metadata(
1382
+ num_draft_tokens,
1383
+ self.query_start_loc_cpu[1:num_reqs + 1],
1384
+ padded_num_reqs,
1385
+ ))
1386
+ logits_indices = spec_decode_metadata.final_logits_indices
1387
+
1388
+ # Put to device
1389
+ sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch(
1390
+ self.mesh,
1391
+ self.input_batch,
1392
+ padded_num_reqs,
1393
+ sharding=data_parallel_attn_sharding,
1394
+ )
1395
+ if self.uses_mrope:
1396
+ positions = mrope_positions
1397
+
1398
+ query_start_loc_cpu = query_start_loc
1399
+ logits_indices_cpu = logits_indices
1400
+ seq_lens_cpu = seq_lens
1401
+
1402
+ (input_ids, positions, query_start_loc, seq_lens, logits_indices,
1403
+ request_distribution) = device_array(
1404
+ self.mesh,
1405
+ (input_ids, positions, query_start_loc, seq_lens, logits_indices,
1406
+ request_distribution),
1407
+ sharding=data_parallel_attn_sharding,
1408
+ )
1409
+
1410
+ attention_metadata_per_layer: Dict[str, AttentionMetadata] = {}
1411
+ uniform_attention_metadata: AttentionMetadata = None
1412
+ for kv_cache_gid, kv_cache_group in enumerate(
1413
+ self.kv_cache_config.kv_cache_groups):
1414
+ block_tables = self.block_tables_cpu[kv_cache_gid][:self.
1415
+ max_num_reqs]
1416
+ for dp_rank in range(dp_size):
1417
+ req_offset = dp_rank * max_num_reqs_per_dp_rank
1418
+ _num_reqs = num_req_per_dp_rank[dp_rank]
1419
+
1420
+ block_tables[
1421
+ req_offset:req_offset + _num_reqs, :self.
1422
+ max_num_blocks_per_req] = self.input_batch.block_table[
1423
+ kv_cache_gid].get_cpu_tensor()[req_indices_dp[dp_rank]]
1424
+ # Convert block_tables to 1D on cpu.
1425
+ block_tables = block_tables.reshape(-1)
1426
+ block_tables = device_array(
1427
+ self.mesh,
1428
+ (block_tables),
1429
+ sharding=data_parallel_attn_sharding,
1430
+ )
1431
+
1432
+ attention_metadata_gid = AttentionMetadata(
1433
+ input_positions=positions,
1434
+ block_tables=block_tables,
1435
+ seq_lens=seq_lens,
1436
+ query_start_loc=query_start_loc,
1437
+ request_distribution=request_distribution,
1438
+ )
1439
+
1440
+ # Attach these CPU buffers here so they stay hidden during tracing
1441
+ attention_metadata_gid.query_start_loc_cpu = query_start_loc_cpu
1442
+ attention_metadata_gid.seq_lens_cpu = seq_lens_cpu
1443
+
1444
+ if not self.use_hybrid_kvcache:
1445
+ uniform_attention_metadata = attention_metadata_gid
1446
+ else:
1447
+ for layer_name in kv_cache_group.layer_names:
1448
+ attention_metadata_per_layer[
1449
+ layer_name] = attention_metadata_gid
1450
+
1451
+ # Async scheduling: substitute placeholder tokens for DP
1452
+ if self.scheduler_config.async_scheduling and self._pre_async_results is not None:
1453
+ # Collect all token indices that need substitution across all DP ranks
1454
+ all_token_indices_to_substitute = []
1455
+ all_pre_next_tokens_indices = []
1456
+
1457
+ for dp_rank in range(dp_size):
1458
+ cur_indices = token_in_tpu_cur_input_indices_dp[dp_rank]
1459
+ pre_indices = token_in_tpu_pre_next_tokens_indices_dp[dp_rank]
1460
+ all_token_indices_to_substitute.extend(cur_indices)
1461
+ all_pre_next_tokens_indices.extend(pre_indices)
1462
+
1463
+ if len(all_token_indices_to_substitute) > 0:
1464
+ token_in_tpu_cur_input_indices = np.array(
1465
+ all_token_indices_to_substitute)
1466
+ token_in_tpu_pre_next_tokens_indices = np.array(
1467
+ all_pre_next_tokens_indices)
1468
+ input_ids = self._apply_async_token_substitution(
1469
+ input_ids, token_in_tpu_cur_input_indices,
1470
+ token_in_tpu_pre_next_tokens_indices)
1471
+
1472
+ if self.lora_config is not None:
1473
+ self.lora_utils.set_active_loras(
1474
+ num_scheduled_tokens_per_req,
1475
+ total_num_scheduled_tokens,
1476
+ padded_total_num_scheduled_tokens,
1477
+ )
1478
+
1479
+ if self.use_hybrid_kvcache:
1480
+ attention_metadata = attention_metadata_per_layer
1481
+ else:
1482
+ attention_metadata = uniform_attention_metadata
1483
+ return (
1484
+ input_ids,
1485
+ positions,
1486
+ attention_metadata,
1487
+ sampling_metadata,
1488
+ logits_indices,
1489
+ spec_decode_metadata,
1490
+ logits_indices_selector,
1491
+ padded_num_reqs,
1492
+ )
1493
+
1494
+ def _prepare_inputs_non_dp(self, scheduler_output: "VllmSchedulerOutput"):
1495
+ total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
1496
+ assert total_num_scheduled_tokens > 0
1497
+ num_reqs = self.input_batch.num_reqs
1498
+ assert num_reqs > 0
1499
+
1500
+ # Get the number of scheduled tokens for each request.
1501
+ num_scheduled_tokens_per_req = []
1502
+ max_num_scheduled_tokens_all_reqs = 0
1503
+ for req_id in self.input_batch.req_ids[:num_reqs]:
1504
+ assert req_id is not None
1505
+ num_tokens = scheduler_output.num_scheduled_tokens[req_id]
1506
+ num_scheduled_tokens_per_req.append(num_tokens)
1507
+ max_num_scheduled_tokens_all_reqs = max(
1508
+ max_num_scheduled_tokens_all_reqs, num_tokens)
1509
+ num_scheduled_tokens_per_req = np.array(num_scheduled_tokens_per_req,
1510
+ dtype=np.int32)
1511
+ assert max_num_scheduled_tokens_all_reqs > 0
1512
+ padded_num_reqs = runner_utils.get_padded_num_reqs_with_upper_limit(
1513
+ num_reqs, self.max_num_reqs)
1514
+
1515
+ # Get request indices.
1516
+ # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
1517
+ # For each scheduled token, what is the corresponding req index.
1518
+ req_indices = np.repeat(self.arange_cpu[:num_reqs],
1519
+ num_scheduled_tokens_per_req)
1520
+ token_in_tpu_cur_input_indices = np.array([])
1521
+ token_in_tpu_pre_next_tokens_indices = np.array([])
1522
+ if self.scheduler_config.async_scheduling and self._pre_async_results is not None:
1523
+ # If previous async results exist, prepare the token substitution indices here.
1524
+ # The actual substitution is performed on the TPU later in this function.
1525
+ (token_in_tpu_cur_input_indices,
1526
+ token_in_tpu_pre_next_tokens_indices
1527
+ ) = self._prepare_async_token_substitution_indices_non_dp(
1528
+ num_reqs, num_scheduled_tokens_per_req)
1529
+
1530
+ # Get batched arange.
1531
+ # E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
1532
+ # For each scheduled token, what is its position in the corresponding req.
1533
+ arange = np.concatenate(
1534
+ [self.arange_cpu[:n] for n in num_scheduled_tokens_per_req])
1535
+
1536
+ # Get positions.
1537
+ positions_np = self.positions_cpu[:total_num_scheduled_tokens]
1538
+ np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
1539
+ arange,
1540
+ out=positions_np)
1541
+
1542
+ # Multi-modal support
1543
+ # Calculate M-RoPE positions.
1544
+ # Only relevant for models using M-RoPE (e.g., Qwen2-VL)
1545
+ if self.uses_mrope:
1546
+ self.mm_manager.calc_mrope_positions(scheduler_output)
1547
+
1548
+ # Get token indices.
1549
+ # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
1550
+ # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
1551
+ # where M is the max_model_len.
1552
+ token_indices = (positions_np +
1553
+ req_indices * self.input_batch.token_ids_cpu.shape[1])
1554
+
1555
+ # NOTE: np.take on the flattened CPU token buffer gathers all the
1556
+ # scheduled token ids in one vectorized call, avoiding a Python loop
1557
+ # over requests.
1558
+ np.take(self.input_batch.token_ids_cpu.ravel(),
1559
+ token_indices,
1560
+ out=self.input_ids_cpu[:total_num_scheduled_tokens])
1561
+
1562
+ # Prepare the attention metadata.
1563
+ self.query_start_loc_cpu[0] = 0
1564
+ np.cumsum(num_scheduled_tokens_per_req,
1565
+ out=self.query_start_loc_cpu[1:num_reqs + 1])
1566
+ self.query_start_loc_cpu[num_reqs + 1:] = 1
1567
+
1568
+ self.seq_lens_cpu[:num_reqs] = (
1569
+ self.input_batch.num_computed_tokens_cpu[:num_reqs] +
1570
+ num_scheduled_tokens_per_req)
1571
+
1572
+ # Do the padding and copy the tensors to the TPU.
1573
+ padded_total_num_scheduled_tokens = runner_utils.get_padded_token_len(
1574
+ self.num_tokens_paddings, total_num_scheduled_tokens)
1575
+ # Zero out to avoid spurious values from prev iteration (last cp chunk)
1576
+ self.input_ids_cpu[
1577
+ total_num_scheduled_tokens:padded_total_num_scheduled_tokens] = 0
1578
+
1579
+ # Please see runner_utils.PhasedBasedProfiler for details
1580
+ if self.phase_based_profiler:
1581
+ batch_composition_stats = runner_utils.get_batch_composition_stats(
1582
+ self.input_batch, total_num_scheduled_tokens, num_reqs,
1583
+ padded_total_num_scheduled_tokens, scheduler_output)
1584
+
1585
+ self.phase_based_profiler.step(batch_composition_stats)
1586
+
1587
+ # Inputs
1588
+ input_ids = self.input_ids_cpu[:padded_total_num_scheduled_tokens]
1589
+ positions = self.positions_cpu[:padded_total_num_scheduled_tokens]
1590
+ mrope_positions = self.mrope_positions_cpu[:, :
1591
+ padded_total_num_scheduled_tokens]
1592
+
1593
+ # TODO(pooyam): Some paddings are up to `num_reqs_paddings` (spec decoding, select hidden states, etc.) and others are up to `max_num_reqs` (block table, seq_lens). We should probably stick to one of them.
1594
+ query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs + 1]
1595
+ seq_lens = self.seq_lens_cpu[:self.max_num_reqs]
1596
+ request_distribution = np.array(self.input_batch.request_distribution)
1597
+ use_spec_decode = len(
1598
+ scheduler_output.scheduled_spec_decode_tokens) > 0
1599
+ if not use_spec_decode:
1600
+ logits_indices = self.query_start_loc_cpu[1:padded_num_reqs +
1601
+ 1] - 1
1602
+ spec_decode_metadata = None
1603
+ else:
1604
+ num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
1605
+ for req_id, draft_token_ids in (
1606
+ scheduler_output.scheduled_spec_decode_tokens.items()):
1607
+ req_idx = self.input_batch.req_id_to_index[req_id]
1608
+ num_draft_tokens[req_idx] = len(draft_token_ids)
1609
+
1610
+ spec_decode_metadata = self.speculative_decoding_manager.get_spec_decode_metadata(
1611
+ num_draft_tokens, self.query_start_loc_cpu[1:num_reqs + 1],
1612
+ padded_num_reqs)
1613
+ logits_indices = spec_decode_metadata.final_logits_indices
1614
+
1615
+ # Put to device
1616
+ sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch(
1617
+ self.mesh, self.input_batch, padded_num_reqs)
1618
+ if self.uses_mrope:
1619
+ positions = mrope_positions
1620
+ query_start_loc_cpu = query_start_loc
1621
+ seq_lens_cpu = seq_lens
1622
+
1623
+ (input_ids, positions, query_start_loc, seq_lens,
1624
+ logits_indices, request_distribution) = device_array(
1625
+ self.mesh, (input_ids, positions, query_start_loc, seq_lens,
1626
+ logits_indices, request_distribution))
1627
+
1628
+ attention_metadata_per_layer: Dict[str, AttentionMetadata] = {}
1629
+ uniform_attention_metadata: AttentionMetadata = None
1630
+ for kv_cache_gid, kv_cache_group in enumerate(
1631
+ self.kv_cache_config.kv_cache_groups):
1632
+ block_tables = self.block_tables_cpu[kv_cache_gid][:self.
1633
+ max_num_reqs]
1634
+ block_tables[:num_reqs] = (
1635
+ self.input_batch.block_table[kv_cache_gid].get_cpu_tensor()
1636
+ [:num_reqs])
1637
+ # Convert block_tables to 1D on cpu.
1638
+ block_tables = block_tables.reshape(-1)
1639
+ block_tables = device_array(self.mesh, (block_tables))
1640
+
1641
+ attention_metadata_gid = AttentionMetadata(
1642
+ input_positions=positions,
1643
+ block_tables=block_tables,
1644
+ seq_lens=seq_lens,
1645
+ query_start_loc=query_start_loc,
1646
+ request_distribution=request_distribution)
1647
+ # Attach these CPU buffers here so they stay hidden during tracing
1648
+ attention_metadata_gid.query_start_loc_cpu = query_start_loc_cpu
1649
+ attention_metadata_gid.seq_lens_cpu = seq_lens_cpu
1650
+
1651
+ if not self.use_hybrid_kvcache:
1652
+ # all layers share the same attention metadata
1653
+ uniform_attention_metadata = attention_metadata_gid
1654
+ else:
1655
+ for layer_name in kv_cache_group.layer_names:
1656
+ attention_metadata_per_layer[
1657
+ layer_name] = attention_metadata_gid
1658
+
1659
+ if self.scheduler_config.async_scheduling and len(
1660
+ token_in_tpu_cur_input_indices) > 0:
1661
+ assert self._pre_async_results is not None
1662
+ input_ids = self._apply_async_token_substitution(
1663
+ input_ids, token_in_tpu_cur_input_indices,
1664
+ token_in_tpu_pre_next_tokens_indices)
1665
+
1666
+ if self.lora_config is not None:
1667
+ self.lora_utils.set_active_loras(
1668
+ num_scheduled_tokens_per_req, total_num_scheduled_tokens,
1669
+ padded_total_num_scheduled_tokens)
1670
+ logits_indices_selector = None
1671
+
1672
+ if self.use_hybrid_kvcache:
1673
+ attention_metadata = attention_metadata_per_layer
1674
+ else:
1675
+ attention_metadata = uniform_attention_metadata
1676
+ return (input_ids, positions, attention_metadata, sampling_metadata,
1677
+ logits_indices, spec_decode_metadata, logits_indices_selector,
1678
+ padded_num_reqs)
1679
+
1680
+ def _get_input_ids_embeds(self, input_ids: jax.Array,
1681
+ mm_embeds: list[jax.Array]):
1682
+ if self.is_multimodal_model:
1683
+ inputs_embeds = self.get_input_embeddings_fn(
1684
+ self.state,
1685
+ input_ids,
1686
+ mm_embeds,
1687
+ )
1688
+ return None, inputs_embeds
1689
+ else:
1690
+ return input_ids, None
1691
+
1692
+ def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
1693
+ return self.speculative_decoding_manager.take_draft_token_ids()
1694
+
1695
+ ###### Local disagg utilities ######
1696
+
1697
+ def get_kv_cache_for_block_ids(
1698
+ self,
1699
+ block_ids: List[int],
1700
+ ) -> List[jax.Array]:
1701
+ return self.kv_cache_manager.get_kv_cache_for_block_ids(block_ids)
1702
+
1703
+ def transfer_kv_cache(self,
1704
+ kv_cache_slices: List[jax.Array]) -> List[jax.Array]:
1705
+ return self.kv_cache_manager.transfer_kv_cache(kv_cache_slices)
1706
+
1707
+ def insert_request_with_kv_cache(
1708
+ self,
1709
+ request: "Request",
1710
+ kv_cache_slices: List[jax.Array],
1711
+ block_ids: List[List[int]],
1712
+ ):
1713
+ return self.kv_cache_manager.insert_request_with_kv_cache(
1714
+ request, kv_cache_slices, block_ids)
1715
+
1716
+ ###### RL framework integration ######
1717
+
1718
+ def _sync_weights(
1719
+ self,
1720
+ updated_weights: jaxtyping.PyTree,
1721
+ mappings: Dict[str, Tuple[str, Tuple[str]]],
1722
+ transpose_keys: Dict[str, Tuple[int]],
1723
+ reshard_fn: Callable[[jaxtyping.PyTree, jaxtyping.PyTree],
1724
+ jaxtyping.PyTree] = None
1725
+ ) -> None:
1726
+ """For RL framework integration."""
1727
+ if reshard_fn is not None:
1728
+ updated_weights = reshard_fn(updated_weights, self.state)
1729
+ shard = None
1730
+ else:
1731
+ shard = functools.partial(shard_put, mesh=self.mesh)
1732
+ self.state = transfer_state_with_mappings(
1733
+ src_state=updated_weights,
1734
+ tgt_state=self.state,
1735
+ mappings=mappings,
1736
+ transpose_keys=transpose_keys,
1737
+ shard=shard)
1738
+
1739
+ def get_intermediate_tensor_spec(self, num_tokens: int):
1740
+ jax_dtype = to_jax_dtype(self.dtype)
1741
+ num_padded_tokens = runner_utils.get_padded_token_len(
1742
+ self.num_tokens_paddings, num_tokens)
1743
+ sharding = NamedSharding(self.mesh, PartitionSpec())
1744
+ hidden_size = self.model_config.get_hidden_size()
1745
+ spec = jax.ShapeDtypeStruct(shape=(num_padded_tokens, hidden_size),
1746
+ dtype=jax_dtype,
1747
+ sharding=sharding)
1748
+ tensor_spec = {"hidden_states": spec, "residual": spec}
1749
+ return tensor_spec
1750
+
1751
+ def get_uuid_for_jax_transfer(self,
1752
+ scheduler_output: "VllmSchedulerOutput",
1753
+ rank: int, step: int) -> int:
1754
+ '''
1755
+ Get a uuid for jax.transfer. The uuid is derived from the hash of the
1756
+ scheduler_output, the step counter, and the sender's rank.
1757
+ '''
1758
+ scheduler_output_str = ""
1759
+ if not scheduler_output.num_scheduled_tokens:
1760
+ scheduler_output_str = "empty_batch"
1761
+ else:
1762
+ scheduler_output_str = str(
1763
+ sorted(scheduler_output.num_scheduled_tokens.items()))
1764
+ unique_str = f'{scheduler_output_str} {step} {rank}'
1765
+ import hashlib
1766
+ hasher = hashlib.sha1()
1767
+ hasher.update(unique_str.encode('utf-8'))
1768
+ return int.from_bytes(hasher.digest()[:8], 'big')
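The uuid only needs to be reproducible on both ends of the transfer for the same step, so the method hashes the sorted per-request schedule together with the step counter and the sender's rank and keeps the first 8 bytes of the SHA-1 digest. A standalone illustration with made-up values:

import hashlib

num_scheduled_tokens = {"req-1": 17, "req-0": 1}    # made-up schedule
scheduler_output_str = str(sorted(num_scheduled_tokens.items()))
step, rank = 3, 0
unique_str = f"{scheduler_output_str} {step} {rank}"
digest = hashlib.sha1(unique_str.encode("utf-8")).digest()
transfer_uuid = int.from_bytes(digest[:8], "big")   # fits in an unsigned 64-bit int
print(transfer_uuid)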