tpu-inference 0.12.0.dev20251222__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (260)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +67 -0
  6. tests/core/test_dp_scheduler.py +724 -0
  7. tests/core/test_init.py +63 -0
  8. tests/distributed/__init__.py +13 -0
  9. tests/distributed/test_distributed_utils.py +120 -0
  10. tests/distributed/test_tpu_connector.py +478 -0
  11. tests/e2e/__init__.py +13 -0
  12. tests/e2e/test_async_scheduler.py +211 -0
  13. tests/e2e/test_data_parallel.py +393 -0
  14. tests/e2e/test_local_disagg.py +257 -0
  15. tests/e2e/test_model_loader.py +268 -0
  16. tests/e2e/test_multi_modal_inference.py +111 -0
  17. tests/e2e/test_pipeline_parallel.py +265 -0
  18. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  19. tests/e2e/test_sampling_params.py +269 -0
  20. tests/e2e/test_speculative_decoding.py +291 -0
  21. tests/e2e/test_structured_decoding.py +46 -0
  22. tests/executors/__init__.py +13 -0
  23. tests/executors/test_ray_distributed_executor.py +199 -0
  24. tests/experimental/__init__.py +13 -0
  25. tests/experimental/test_llama3_jax_stashed.py +208 -0
  26. tests/kernels/__init__.py +13 -0
  27. tests/kernels/collectives/__init__.py +13 -0
  28. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  29. tests/kernels/fused_moe_v1_test.py +388 -0
  30. tests/kernels/gmm_test.py +205 -0
  31. tests/kernels/mla_v1_test.py +498 -0
  32. tests/kernels/quantized_matmul_kernel_test.py +159 -0
  33. tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
  34. tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
  35. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
  36. tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
  37. tests/layers/__init__.py +13 -0
  38. tests/layers/common/__init__.py +13 -0
  39. tests/layers/common/test_attention_interface.py +156 -0
  40. tests/layers/common/test_quantization.py +149 -0
  41. tests/layers/jax/__init__.py +13 -0
  42. tests/layers/jax/attention/__init__.py +13 -0
  43. tests/layers/jax/attention/test_common_attention.py +103 -0
  44. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  45. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  46. tests/layers/jax/moe/__init__.py +13 -0
  47. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  48. tests/layers/jax/sample/__init__.py +13 -0
  49. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  50. tests/layers/jax/sample/test_sampling.py +115 -0
  51. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  52. tests/layers/jax/test_layers.py +155 -0
  53. tests/layers/jax/test_qwix.py +969 -0
  54. tests/layers/jax/test_rope.py +93 -0
  55. tests/layers/jax/test_sharding.py +159 -0
  56. tests/layers/jax/test_transformer_block.py +152 -0
  57. tests/layers/vllm/__init__.py +13 -0
  58. tests/layers/vllm/test_attention.py +363 -0
  59. tests/layers/vllm/test_awq.py +405 -0
  60. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
  62. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
  63. tests/layers/vllm/test_fp8.py +17 -0
  64. tests/layers/vllm/test_mxfp4.py +297 -0
  65. tests/layers/vllm/test_unquantized.py +621 -0
  66. tests/layers/vllm/utils.py +72 -0
  67. tests/lora/__init__.py +13 -0
  68. tests/lora/conftest.py +46 -0
  69. tests/lora/test_bgmv.py +57 -0
  70. tests/lora/test_layers.py +666 -0
  71. tests/lora/test_lora.py +147 -0
  72. tests/lora/test_lora_perf.py +67 -0
  73. tests/lora/utils.py +88 -0
  74. tests/models/__init__.py +13 -0
  75. tests/models/common/__init__.py +13 -0
  76. tests/models/common/test_model_loader.py +455 -0
  77. tests/models/jax/__init__.py +13 -0
  78. tests/models/jax/test_deepseek_v3.py +401 -0
  79. tests/models/jax/test_llama3.py +184 -0
  80. tests/models/jax/test_llama4.py +298 -0
  81. tests/models/jax/test_llama_eagle3.py +197 -0
  82. tests/models/jax/test_llama_guard_4.py +242 -0
  83. tests/models/jax/test_qwen2.py +172 -0
  84. tests/models/jax/test_qwen2_5_vl.py +606 -0
  85. tests/models/jax/test_qwen3.py +169 -0
  86. tests/models/jax/test_weight_loading.py +180 -0
  87. tests/models/jax/utils/__init__.py +13 -0
  88. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  89. tests/platforms/__init__.py +13 -0
  90. tests/platforms/test_tpu_platform.py +54 -0
  91. tests/runner/__init__.py +13 -0
  92. tests/runner/test_block_table.py +395 -0
  93. tests/runner/test_input_batch.py +226 -0
  94. tests/runner/test_kv_cache.py +220 -0
  95. tests/runner/test_kv_cache_manager.py +498 -0
  96. tests/runner/test_multimodal_manager.py +429 -0
  97. tests/runner/test_persistent_batch_manager.py +84 -0
  98. tests/runner/test_speculative_decoding_manager.py +368 -0
  99. tests/runner/test_structured_decoding_manager.py +220 -0
  100. tests/runner/test_tpu_runner.py +202 -0
  101. tests/runner/test_tpu_runner_dp.py +1033 -0
  102. tests/runner/test_tpu_runner_mesh.py +200 -0
  103. tests/runner/test_utils.py +411 -0
  104. tests/spec_decode/__init__.py +13 -0
  105. tests/spec_decode/test_eagle3.py +311 -0
  106. tests/test_base.py +215 -0
  107. tests/test_envs.py +280 -0
  108. tests/test_tpu_info.py +134 -0
  109. tests/test_utils.py +193 -0
  110. tests/worker/__init__.py +13 -0
  111. tests/worker/tpu_worker_test.py +414 -0
  112. tpu_inference/__init__.py +67 -0
  113. tpu_inference/core/__init__.py +13 -0
  114. tpu_inference/core/core_tpu.py +786 -0
  115. tpu_inference/core/disagg_executor.py +118 -0
  116. tpu_inference/core/disagg_utils.py +49 -0
  117. tpu_inference/core/sched/__init__.py +13 -0
  118. tpu_inference/core/sched/dp_scheduler.py +814 -0
  119. tpu_inference/distributed/__init__.py +13 -0
  120. tpu_inference/distributed/jax_parallel_state.py +81 -0
  121. tpu_inference/distributed/tpu_connector.py +732 -0
  122. tpu_inference/distributed/utils.py +112 -0
  123. tpu_inference/env_override.py +9 -0
  124. tpu_inference/envs.py +191 -0
  125. tpu_inference/executors/__init__.py +13 -0
  126. tpu_inference/executors/ray_distributed_executor.py +399 -0
  127. tpu_inference/experimental/__init__.py +13 -0
  128. tpu_inference/experimental/llama3_jax_stashed.py +272 -0
  129. tpu_inference/kernels/__init__.py +13 -0
  130. tpu_inference/kernels/collectives/__init__.py +13 -0
  131. tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
  132. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
  133. tpu_inference/kernels/collectives/util.py +47 -0
  134. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  135. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  136. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  137. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  138. tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
  139. tpu_inference/kernels/megablox/__init__.py +13 -0
  140. tpu_inference/kernels/megablox/common.py +54 -0
  141. tpu_inference/kernels/megablox/gmm.py +646 -0
  142. tpu_inference/kernels/mla/__init__.py +13 -0
  143. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  144. tpu_inference/kernels/mla/v1/kernel.py +1340 -0
  145. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  146. tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
  147. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  148. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  149. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  150. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  151. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
  152. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
  153. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  154. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  155. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
  156. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
  157. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
  158. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
  159. tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
  160. tpu_inference/layers/__init__.py +13 -0
  161. tpu_inference/layers/common/__init__.py +13 -0
  162. tpu_inference/layers/common/attention_interface.py +403 -0
  163. tpu_inference/layers/common/attention_metadata.py +48 -0
  164. tpu_inference/layers/common/binary_search.py +295 -0
  165. tpu_inference/layers/common/quant_methods.py +23 -0
  166. tpu_inference/layers/common/quantization.py +270 -0
  167. tpu_inference/layers/common/sharding.py +600 -0
  168. tpu_inference/layers/jax/__init__.py +13 -0
  169. tpu_inference/layers/jax/attention/__init__.py +13 -0
  170. tpu_inference/layers/jax/attention/attention.py +268 -0
  171. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
  172. tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
  173. tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
  174. tpu_inference/layers/jax/base.py +165 -0
  175. tpu_inference/layers/jax/constants.py +101 -0
  176. tpu_inference/layers/jax/layers.py +315 -0
  177. tpu_inference/layers/jax/misc.py +30 -0
  178. tpu_inference/layers/jax/moe/__init__.py +13 -0
  179. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
  180. tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
  181. tpu_inference/layers/jax/moe/moe.py +249 -0
  182. tpu_inference/layers/jax/pp_utils.py +53 -0
  183. tpu_inference/layers/jax/rope.py +294 -0
  184. tpu_inference/layers/jax/rope_interface.py +228 -0
  185. tpu_inference/layers/jax/sample/__init__.py +13 -0
  186. tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
  187. tpu_inference/layers/jax/sample/sampling.py +110 -0
  188. tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
  189. tpu_inference/layers/jax/transformer_block.py +121 -0
  190. tpu_inference/layers/vllm/__init__.py +13 -0
  191. tpu_inference/layers/vllm/attention.py +221 -0
  192. tpu_inference/layers/vllm/fused_moe.py +502 -0
  193. tpu_inference/layers/vllm/linear_common.py +221 -0
  194. tpu_inference/layers/vllm/quantization/__init__.py +55 -0
  195. tpu_inference/layers/vllm/quantization/awq.py +221 -0
  196. tpu_inference/layers/vllm/quantization/common.py +124 -0
  197. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  198. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
  199. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
  200. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  201. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
  202. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
  203. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  204. tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
  205. tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
  206. tpu_inference/layers/vllm/sharding.py +244 -0
  207. tpu_inference/logger.py +10 -0
  208. tpu_inference/lora/__init__.py +13 -0
  209. tpu_inference/lora/torch_lora_ops.py +98 -0
  210. tpu_inference/lora/torch_punica_tpu.py +310 -0
  211. tpu_inference/models/__init__.py +13 -0
  212. tpu_inference/models/common/__init__.py +13 -0
  213. tpu_inference/models/common/model_loader.py +520 -0
  214. tpu_inference/models/jax/__init__.py +13 -0
  215. tpu_inference/models/jax/deepseek_v3.py +978 -0
  216. tpu_inference/models/jax/gpt_oss.py +508 -0
  217. tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
  218. tpu_inference/models/jax/llama3.py +436 -0
  219. tpu_inference/models/jax/llama4.py +643 -0
  220. tpu_inference/models/jax/llama_eagle3.py +350 -0
  221. tpu_inference/models/jax/llama_guard_4.py +375 -0
  222. tpu_inference/models/jax/qwen2.py +390 -0
  223. tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
  224. tpu_inference/models/jax/qwen3.py +318 -0
  225. tpu_inference/models/jax/utils/__init__.py +13 -0
  226. tpu_inference/models/jax/utils/file_utils.py +110 -0
  227. tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
  228. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  229. tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
  230. tpu_inference/models/jax/utils/weight_utils.py +621 -0
  231. tpu_inference/models/vllm/__init__.py +13 -0
  232. tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
  233. tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
  234. tpu_inference/platforms/__init__.py +16 -0
  235. tpu_inference/platforms/tpu_platform.py +258 -0
  236. tpu_inference/runner/__init__.py +13 -0
  237. tpu_inference/runner/block_table.py +122 -0
  238. tpu_inference/runner/compilation_manager.py +890 -0
  239. tpu_inference/runner/input_batch.py +435 -0
  240. tpu_inference/runner/kv_cache.py +166 -0
  241. tpu_inference/runner/kv_cache_manager.py +508 -0
  242. tpu_inference/runner/lora_utils.py +106 -0
  243. tpu_inference/runner/multimodal_manager.py +231 -0
  244. tpu_inference/runner/persistent_batch_manager.py +296 -0
  245. tpu_inference/runner/speculative_decoding_manager.py +262 -0
  246. tpu_inference/runner/structured_decoding_manager.py +101 -0
  247. tpu_inference/runner/tpu_runner.py +1768 -0
  248. tpu_inference/runner/utils.py +426 -0
  249. tpu_inference/spec_decode/__init__.py +13 -0
  250. tpu_inference/spec_decode/jax/__init__.py +13 -0
  251. tpu_inference/spec_decode/jax/eagle3.py +430 -0
  252. tpu_inference/tpu_info.py +92 -0
  253. tpu_inference/utils.py +345 -0
  254. tpu_inference/worker/__init__.py +13 -0
  255. tpu_inference/worker/tpu_worker.py +468 -0
  256. tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
  257. tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
  258. tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
  259. tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
  260. tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
@@ -0,0 +1,262 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ from dataclasses import dataclass
18
+ from typing import TYPE_CHECKING, Optional
19
+
20
+ import jax.numpy as jnp
21
+ import numpy as np
22
+ from vllm.v1.core.sched.output import SchedulerOutput as VllmSchedulerOutput
23
+ from vllm.v1.outputs import DraftTokenIds
24
+ from vllm.v1.spec_decode.ngram_proposer import NgramProposer
25
+
26
+ from tpu_inference.runner import utils as runner_utils
27
+ from tpu_inference.spec_decode.jax.eagle3 import Eagle3Proposer
28
+ from tpu_inference.utils import device_array
29
+
30
+ if TYPE_CHECKING:
31
+ from tpu_inference.layers.common.attention_metadata import \
32
+ AttentionMetadata
33
+ from tpu_inference.runner.tpu_runner import TPUModelRunner
34
+
35
+
36
@dataclass
class SpecDecodeMetadata:
    """Padded index arrays describing one speculative-decoding verification step.

    Every field except ``draft_lengths_cpu`` is a device (JAX) array;
    ``draft_lengths_cpu`` keeps a host-side NumPy copy of the per-request
    draft counts so Python code can read them without a device round trip.
    """
    # Draft token ids of all requests, flattened and zero-padded.
    draft_token_ids: jnp.ndarray
    # Number of draft tokens proposed for each (padded) request slot.
    draft_lengths: jnp.ndarray
    # Host copy of ``draft_lengths``.
    draft_lengths_cpu: np.ndarray
    # Rows of the sampled logits used to verify each draft token.
    target_logits_indices: jnp.ndarray
    # Row of the sampled logits holding each request's bonus token.
    bonus_logits_indices: jnp.ndarray
    # Token positions whose logits are gathered from the model output.
    final_logits_indices: jnp.ndarray
45
+
46
+
47
class SpeculativeDecodingManager:
    """Proposes, caches and indexes draft tokens for speculative decoding.

    Proposal is delegated to ``runner.drafter``. The proposer kind is
    selected by ``runner.speculative_config.method`` ("ngram" or "eagle3").
    The most recent proposals are cached until ``take_draft_token_ids`` is
    called. ``get_spec_decode_metadata`` builds the padded index arrays used
    to verify drafted tokens against the target model's logits.
    """

    def __init__(self, runner: TPUModelRunner):
        self.runner = runner
        # Proposals from the latest step; cleared by take_draft_token_ids().
        self._draft_token_ids: Optional[list[list[int]]] = None

    def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
        """Hand over the cached proposals and clear the cache.

        Returns:
            ``DraftTokenIds`` pairing the current batch's request ids with
            the cached proposals, or ``None`` if nothing is cached.
        """
        cached = self._draft_token_ids
        if cached is None:
            return None
        self._draft_token_ids = None
        return DraftTokenIds(self.runner.input_batch.req_ids, cached)

    def propose_draft_token_ids(
        self,
        sampled_token_ids: list[list[int]],
        aux_hidden_states: Optional[tuple[jnp.ndarray, ...]],
        attn_metadata: AttentionMetadata,
        spec_decode_metadata: Optional[SpecDecodeMetadata],
        scheduler_output: Optional[VllmSchedulerOutput] = None,
        input_ids: Optional[jnp.ndarray] = None,
    ) -> None:
        """Run the configured drafter and cache its proposals.

        Raises:
            NotImplementedError: if the speculative method is neither
                "ngram" nor "eagle3".
        """
        method = self.runner.speculative_config.method

        if method == "eagle3":
            self._draft_token_ids = self.propose_eagle3_draft_token_ids(
                sampled_token_ids,
                aux_hidden_states,
                attn_metadata,
                spec_decode_metadata,
                scheduler_output,
                input_ids,
            )
            return

        if method != "ngram":
            raise NotImplementedError(
                f"Speculative decoding method "
                f"'{method}' is not supported.")

        drafter = self.runner.drafter
        assert isinstance(drafter, NgramProposer)
        batch = self.runner.input_batch
        # Only the first num_reqs rows correspond to live requests.
        self._draft_token_ids = drafter.propose(
            sampled_token_ids[:batch.num_reqs],
            batch.req_ids,
            batch.num_tokens_no_spec,
            batch.token_ids_cpu,
            batch.spec_decode_unsupported_reqs)

    def propose_eagle3_draft_token_ids(
        self,
        sampled_token_ids: list[list[int]],
        aux_hidden_states: Optional[tuple[jnp.ndarray, ...]],
        attn_metadata: AttentionMetadata,
        spec_decode_metadata: Optional[SpecDecodeMetadata],
        scheduler_output: VllmSchedulerOutput,
        input_ids: jnp.ndarray,
    ) -> list[list[int]]:
        """Propose draft tokens with the EAGLE-3 drafter.

        The drafter's KV caches are written back to ``runner.kv_caches``.

        Returns:
            The drafter's proposals as nested lists. A 1-D result is
            treated as a single draft token per row.
        """
        assert isinstance(self.runner.drafter, Eagle3Proposer)

        # TODO(woosuk): Refactor the loop.
        req_ids = self.runner.input_batch.req_ids
        last_tokens: list[int] = []
        for row, sampled in enumerate(sampled_token_ids):
            if sampled:
                # Common case: continue from the last sampled token.
                last_tokens.append(sampled[-1])
                continue
            # Partial prefill (rare case): nothing was sampled, so read the
            # token at the end of the scheduled span from request state.
            req_id = req_ids[row]
            state = self.runner.requests[req_id]
            end_pos = (state.num_computed_tokens +
                       scheduler_output.num_scheduled_tokens[req_id])
            last_tokens.append(state.get_token_id(end_pos))

        # Pad up to the target model's (padded) batch size.
        num_pad = attn_metadata.seq_lens.shape[0] - len(last_tokens)
        assert num_pad >= 0
        last_tokens.extend([0] * num_pad)
        next_token_ids = device_array(
            self.runner.mesh, np.array(last_tokens, dtype=jnp.int32))

        num_rejected_tokens = None
        if spec_decode_metadata is not None:
            # A request that drafted n tokens and sampled k tokens rejected
            # n + 1 - k of them; requests without drafts rejected nothing.
            rejected = [
                int(n) + 1 - len(sampled_token_ids[row]) if n > 0 else 0
                for row, n in enumerate(spec_decode_metadata.draft_lengths_cpu)
            ]
            rejected.extend([0] * (self.runner.max_num_reqs - len(rejected)))
            num_rejected_tokens = device_array(
                self.runner.mesh, np.array(rejected, dtype=jnp.int32))

        drafter = self.runner.drafter
        (target_hidden_states, draft_input_ids, last_token_indices,
         draft_attn_metadata) = drafter.prepare_inputs(
             attn_metadata,
             input_ids,
             aux_hidden_states,
             next_token_ids,
             num_rejected_tokens,
         )

        self.runner.kv_caches, proposed = drafter.propose(
            kv_caches=self.runner.kv_caches,
            input_ids=draft_input_ids,
            attn_metadata=draft_attn_metadata,
            last_token_indices=last_token_indices,
            target_hidden_states=target_hidden_states,
        )
        proposed = np.array(proposed)
        if proposed.ndim == 1:
            # One draft token per row: lift to shape (rows, 1).
            proposed = proposed[:, None]
        return proposed.tolist()

    def get_spec_decode_metadata(
        self,
        num_draft_tokens: np.ndarray,
        cu_num_scheduled_tokens: np.ndarray,
        padded_num_reqs: int,
    ) -> SpecDecodeMetadata:
        """Build padded device index arrays for verifying draft tokens.

        Worked example:
            cu_num_scheduled_tokens: [ 4, 104, 107, 207, 209]
            num_draft_tokens:        [ 3,   0,   2,   0,   1]
          yields
            logits_indices:        [0, 1, 2, 3, 103, 104, 105, 106,
                                    206, 207, 208]
            target_logits_indices: [0, 1, 2, 5, 6, 9]
            bonus_logits_indices:  [3, 4, 7, 8, 10]
            draft token positions: [1, 2, 3, 105, 106, 208]

        Per-request arrays are zero-padded to ``padded_num_reqs``. Per-token
        arrays are zero-padded to the bucket picked by
        ``runner_utils.get_padded_token_len``.
        """
        arange_cpu = self.runner.arange_cpu

        def pad_to(values: np.ndarray, length: int) -> np.ndarray:
            # Right-pad with int32 zeros up to `length` entries.
            return np.concatenate(
                [values, np.zeros(length - values.shape[0], dtype=np.int32)])

        # Each request samples one logit per draft token plus a bonus one:
        # [4, 1, 3, 1, 2] -> cumulative [4, 5, 8, 9, 11].
        num_sampled_tokens = num_draft_tokens + 1
        cu_num_sampled_tokens = np.cumsum(num_sampled_tokens)
        # Offset of each request's first sampled logit: [0, 4, 5, 8, 9].
        sampled_starts = cu_num_sampled_tokens - num_sampled_tokens

        # Positions of the last num_sampled_tokens[i] tokens of request i:
        # [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207]
        #   + [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
        logits_indices = np.repeat(cu_num_scheduled_tokens - num_sampled_tokens,
                                   num_sampled_tokens)
        logits_indices += np.concatenate(
            [arange_cpu[:count] for count in num_sampled_tokens])

        # The bonus logit closes each request's sampled span.
        bonus_logits_indices = cu_num_sampled_tokens - 1

        # Draft logits: [0, 0, 0, 5, 5, 9] + [0, 1, 2, 0, 1, 0].
        target_logits_indices = np.repeat(sampled_starts, num_draft_tokens)
        target_logits_indices += np.concatenate(
            [arange_cpu[:count] for count in num_draft_tokens])

        # Each draft token sits one position after the logit that verifies it.
        draft_token_ids = self.runner.input_ids_cpu[logits_indices][
            target_logits_indices + 1]

        padded_logits_length = runner_utils.get_padded_token_len(
            self.runner.num_logits_paddings, logits_indices.shape[0])
        padded_logits_indices = pad_to(logits_indices, padded_logits_length)

        assert bonus_logits_indices.shape[0] <= padded_num_reqs, (
            f"bonus_logits_indices.shape[0]={bonus_logits_indices.shape[0]} "
            f"padded_num_reqs={padded_num_reqs}")

        padded_bonus_logits_indices = pad_to(bonus_logits_indices,
                                             padded_num_reqs)
        padded_num_draft_tokens_cpu = pad_to(num_draft_tokens, padded_num_reqs)
        padded_draft_token_ids = pad_to(draft_token_ids, padded_logits_length)
        padded_target_logits_indices = pad_to(target_logits_indices,
                                              padded_logits_length)

        # Single host -> device transfer for all index arrays.
        (device_num_draft_tokens, device_draft_token_ids,
         device_logits_indices, device_target_logits_indices,
         device_bonus_logits_indices) = device_array(
             self.runner.mesh,
             (padded_num_draft_tokens_cpu, padded_draft_token_ids,
              padded_logits_indices, padded_target_logits_indices,
              padded_bonus_logits_indices))

        return SpecDecodeMetadata(
            draft_token_ids=device_draft_token_ids,
            draft_lengths=device_num_draft_tokens,
            draft_lengths_cpu=padded_num_draft_tokens_cpu,
            target_logits_indices=device_target_logits_indices,
            bonus_logits_indices=device_bonus_logits_indices,
            final_logits_indices=device_logits_indices,
        )
@@ -0,0 +1,101 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import functools
16
+ from typing import TYPE_CHECKING, Tuple
17
+
18
+ import jax
19
+ import jax.numpy as jnp
20
+
21
+ from tpu_inference.utils import device_array
22
+
23
+ if TYPE_CHECKING:
24
+ from vllm.v1.core.sched.output import GrammarOutput
25
+
26
+ from tpu_inference.runner.tpu_runner import TPUModelRunner
27
+
28
+
29
class StructuredDecodingManager:
    """Applies structured-output (grammar) constraints to TPU logits.

    The scheduler sends one packed bitmask row per structured-output request.
    This manager does three things:

    * scatters those rows into the runner's per-batch-slot CPU buffers;
    * copies the buffers to device;
    * masks tokens whose bit is 0 to ``-inf``.
    """

    def __init__(self, runner: "TPUModelRunner"):
        self.runner = runner

    @functools.partial(jax.jit, static_argnums=(0, ))
    def structured_decode_fn(self, require_struct_decoding: jax.Array,
                             grammar_bitmask: jax.Array, logits: jax.Array,
                             arange: jax.Array) -> jax.Array:
        """Return ``logits`` masked by the grammar bitmask where required.

        ``jax.lax.cond`` runs the masking branch only when at least one row
        of ``require_struct_decoding`` is set. Otherwise ``logits`` is
        returned unchanged.
        """
        return jax.lax.cond(
            jnp.any(require_struct_decoding),
            lambda: self._apply_grammar_bitmask_kernel(
                logits, grammar_bitmask, require_struct_decoding, arange),
            lambda: logits)

    @functools.partial(jax.jit, static_argnums=(0, ))
    def _apply_grammar_bitmask_kernel(self, logits: jax.Array,
                                      grammar_bitmask: jax.Array,
                                      require_struct_decoding: jax.Array,
                                      arange: jax.Array) -> jax.Array:
        """Unpack the packed bitmask and set disallowed logits to ``-inf``.

        Args:
            logits: (B, vocab) logits.
            grammar_bitmask: (B, N) packed bitmask, N = cdiv(vocab_size, 32).
                A 1 bit allows a token; a 0 bit disallows it.
            require_struct_decoding: per-row flag, shape (B,) or (B, 1).
                Rows with a false flag are returned unmasked.
            arange: presumably ``[0, 1, ..., 31]``, the bit offsets within
                one word (TODO confirm against the runner).
        """
        # (B, N, 1) >> (1, 1, 32) broadcasts to (B, N, 32); `&` binds tighter
        # than `==`, so this is ((word >> bit) & 1) == 0, i.e. "disallowed".
        unpacked_bitmask = jnp.right_shift(grammar_bitmask[:, :, None],
                                           arange[None, None, :]) & 1 == 0

        # (B, N * 32) -> (B, vocab_size): drop the padding bits of the
        # final word.
        unpacked_bitmask = unpacked_bitmask.reshape(
            logits.shape[0], -1)[:, :self.runner.vocab_size]

        masked_logits = jnp.where(unpacked_bitmask, -jnp.inf, logits)

        # Make the per-row flag (B, 1) so it broadcasts along the vocab axis.
        # A (B,) flag would otherwise be aligned with the vocab dimension.
        require_rows = require_struct_decoding.reshape(logits.shape[0], 1)
        return jnp.where(require_rows, masked_logits, logits)

    def _fill_structured_decoding_buffers(
            self, grammar_output: "GrammarOutput") -> None:
        """Scatter the scheduler's bitmask rows into batch-slot order.

        Resets ``runner.grammar_bitmask_cpu`` and
        ``runner.require_structured_out_cpu``. Then, for every structured
        request present in the current batch, writes its bitmask row into
        its batch slot and sets its flag.

        Row ``i`` of ``grammar_output.grammar_bitmask`` belongs to
        ``grammar_output.structured_output_request_ids[i]``. The list is
        therefore walked in its given order; sorting it would pair requests
        with the wrong rows. Requests absent from the batch are skipped
        without shifting the row index of the requests after them.

        NOTE(review): with speculative decoding vLLM emits 1 + num_spec_tokens
        rows per request; this assumes exactly one row per request — confirm
        structured output + spec decode is unsupported on this path.
        """
        grammar_bitmask = grammar_output.grammar_bitmask
        assert grammar_bitmask is not None

        # Reset pre-allocated tensors so stale rows from earlier steps vanish.
        self.runner.grammar_bitmask_cpu.fill(0)
        self.runner.require_structured_out_cpu.fill(0)

        req_id_to_index = self.runner.input_batch.req_id_to_index
        for mask_idx, req_id in enumerate(
                grammar_output.structured_output_request_ids):
            if req_id not in req_id_to_index:
                continue
            batch_index = req_id_to_index[req_id]
            self.runner.grammar_bitmask_cpu[batch_index] = grammar_bitmask[
                mask_idx]
            # Not every request in the batch needs structured output, so a
            # per-slot flag records which ones do.
            self.runner.require_structured_out_cpu[batch_index] = True

    def prepare_structured_decoding_input(
        self, logits: jax.Array, grammar_output: "GrammarOutput"
    ) -> Tuple[jax.Array, jax.Array, jax.Array]:
        """Build device inputs for ``structured_decode_fn``.

        Returns:
            ``(require_structured_out, grammar_bitmask, arange)`` device
            arrays. The first two are sliced to the ``num_reqs`` rows of
            ``logits``.
        """
        self._fill_structured_decoding_buffers(grammar_output)
        num_reqs, _ = logits.shape

        (require_structured_out, grammar_bitmask,
         structured_decode_arange) = device_array(
             self.runner.mesh,
             (self.runner.require_structured_out_cpu[:num_reqs],
              self.runner.grammar_bitmask_cpu[:num_reqs],
              self.runner.structured_decode_arange))

        return (require_structured_out, grammar_bitmask,
                structured_decode_arange)