tpu-inference 0.12.0.dev20251222__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (260)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +67 -0
  6. tests/core/test_dp_scheduler.py +724 -0
  7. tests/core/test_init.py +63 -0
  8. tests/distributed/__init__.py +13 -0
  9. tests/distributed/test_distributed_utils.py +120 -0
  10. tests/distributed/test_tpu_connector.py +478 -0
  11. tests/e2e/__init__.py +13 -0
  12. tests/e2e/test_async_scheduler.py +211 -0
  13. tests/e2e/test_data_parallel.py +393 -0
  14. tests/e2e/test_local_disagg.py +257 -0
  15. tests/e2e/test_model_loader.py +268 -0
  16. tests/e2e/test_multi_modal_inference.py +111 -0
  17. tests/e2e/test_pipeline_parallel.py +265 -0
  18. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  19. tests/e2e/test_sampling_params.py +269 -0
  20. tests/e2e/test_speculative_decoding.py +291 -0
  21. tests/e2e/test_structured_decoding.py +46 -0
  22. tests/executors/__init__.py +13 -0
  23. tests/executors/test_ray_distributed_executor.py +199 -0
  24. tests/experimental/__init__.py +13 -0
  25. tests/experimental/test_llama3_jax_stashed.py +208 -0
  26. tests/kernels/__init__.py +13 -0
  27. tests/kernels/collectives/__init__.py +13 -0
  28. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  29. tests/kernels/fused_moe_v1_test.py +388 -0
  30. tests/kernels/gmm_test.py +205 -0
  31. tests/kernels/mla_v1_test.py +498 -0
  32. tests/kernels/quantized_matmul_kernel_test.py +159 -0
  33. tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
  34. tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
  35. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
  36. tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
  37. tests/layers/__init__.py +13 -0
  38. tests/layers/common/__init__.py +13 -0
  39. tests/layers/common/test_attention_interface.py +156 -0
  40. tests/layers/common/test_quantization.py +149 -0
  41. tests/layers/jax/__init__.py +13 -0
  42. tests/layers/jax/attention/__init__.py +13 -0
  43. tests/layers/jax/attention/test_common_attention.py +103 -0
  44. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  45. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  46. tests/layers/jax/moe/__init__.py +13 -0
  47. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  48. tests/layers/jax/sample/__init__.py +13 -0
  49. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  50. tests/layers/jax/sample/test_sampling.py +115 -0
  51. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  52. tests/layers/jax/test_layers.py +155 -0
  53. tests/layers/jax/test_qwix.py +969 -0
  54. tests/layers/jax/test_rope.py +93 -0
  55. tests/layers/jax/test_sharding.py +159 -0
  56. tests/layers/jax/test_transformer_block.py +152 -0
  57. tests/layers/vllm/__init__.py +13 -0
  58. tests/layers/vllm/test_attention.py +363 -0
  59. tests/layers/vllm/test_awq.py +405 -0
  60. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
  62. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
  63. tests/layers/vllm/test_fp8.py +17 -0
  64. tests/layers/vllm/test_mxfp4.py +297 -0
  65. tests/layers/vllm/test_unquantized.py +621 -0
  66. tests/layers/vllm/utils.py +72 -0
  67. tests/lora/__init__.py +13 -0
  68. tests/lora/conftest.py +46 -0
  69. tests/lora/test_bgmv.py +57 -0
  70. tests/lora/test_layers.py +666 -0
  71. tests/lora/test_lora.py +147 -0
  72. tests/lora/test_lora_perf.py +67 -0
  73. tests/lora/utils.py +88 -0
  74. tests/models/__init__.py +13 -0
  75. tests/models/common/__init__.py +13 -0
  76. tests/models/common/test_model_loader.py +455 -0
  77. tests/models/jax/__init__.py +13 -0
  78. tests/models/jax/test_deepseek_v3.py +401 -0
  79. tests/models/jax/test_llama3.py +184 -0
  80. tests/models/jax/test_llama4.py +298 -0
  81. tests/models/jax/test_llama_eagle3.py +197 -0
  82. tests/models/jax/test_llama_guard_4.py +242 -0
  83. tests/models/jax/test_qwen2.py +172 -0
  84. tests/models/jax/test_qwen2_5_vl.py +606 -0
  85. tests/models/jax/test_qwen3.py +169 -0
  86. tests/models/jax/test_weight_loading.py +180 -0
  87. tests/models/jax/utils/__init__.py +13 -0
  88. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  89. tests/platforms/__init__.py +13 -0
  90. tests/platforms/test_tpu_platform.py +54 -0
  91. tests/runner/__init__.py +13 -0
  92. tests/runner/test_block_table.py +395 -0
  93. tests/runner/test_input_batch.py +226 -0
  94. tests/runner/test_kv_cache.py +220 -0
  95. tests/runner/test_kv_cache_manager.py +498 -0
  96. tests/runner/test_multimodal_manager.py +429 -0
  97. tests/runner/test_persistent_batch_manager.py +84 -0
  98. tests/runner/test_speculative_decoding_manager.py +368 -0
  99. tests/runner/test_structured_decoding_manager.py +220 -0
  100. tests/runner/test_tpu_runner.py +202 -0
  101. tests/runner/test_tpu_runner_dp.py +1033 -0
  102. tests/runner/test_tpu_runner_mesh.py +200 -0
  103. tests/runner/test_utils.py +411 -0
  104. tests/spec_decode/__init__.py +13 -0
  105. tests/spec_decode/test_eagle3.py +311 -0
  106. tests/test_base.py +215 -0
  107. tests/test_envs.py +280 -0
  108. tests/test_tpu_info.py +134 -0
  109. tests/test_utils.py +193 -0
  110. tests/worker/__init__.py +13 -0
  111. tests/worker/tpu_worker_test.py +414 -0
  112. tpu_inference/__init__.py +67 -0
  113. tpu_inference/core/__init__.py +13 -0
  114. tpu_inference/core/core_tpu.py +786 -0
  115. tpu_inference/core/disagg_executor.py +118 -0
  116. tpu_inference/core/disagg_utils.py +49 -0
  117. tpu_inference/core/sched/__init__.py +13 -0
  118. tpu_inference/core/sched/dp_scheduler.py +814 -0
  119. tpu_inference/distributed/__init__.py +13 -0
  120. tpu_inference/distributed/jax_parallel_state.py +81 -0
  121. tpu_inference/distributed/tpu_connector.py +732 -0
  122. tpu_inference/distributed/utils.py +112 -0
  123. tpu_inference/env_override.py +9 -0
  124. tpu_inference/envs.py +191 -0
  125. tpu_inference/executors/__init__.py +13 -0
  126. tpu_inference/executors/ray_distributed_executor.py +399 -0
  127. tpu_inference/experimental/__init__.py +13 -0
  128. tpu_inference/experimental/llama3_jax_stashed.py +272 -0
  129. tpu_inference/kernels/__init__.py +13 -0
  130. tpu_inference/kernels/collectives/__init__.py +13 -0
  131. tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
  132. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
  133. tpu_inference/kernels/collectives/util.py +47 -0
  134. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  135. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  136. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  137. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  138. tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
  139. tpu_inference/kernels/megablox/__init__.py +13 -0
  140. tpu_inference/kernels/megablox/common.py +54 -0
  141. tpu_inference/kernels/megablox/gmm.py +646 -0
  142. tpu_inference/kernels/mla/__init__.py +13 -0
  143. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  144. tpu_inference/kernels/mla/v1/kernel.py +1340 -0
  145. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  146. tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
  147. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  148. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  149. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  150. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  151. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
  152. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
  153. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  154. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  155. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
  156. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
  157. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
  158. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
  159. tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
  160. tpu_inference/layers/__init__.py +13 -0
  161. tpu_inference/layers/common/__init__.py +13 -0
  162. tpu_inference/layers/common/attention_interface.py +403 -0
  163. tpu_inference/layers/common/attention_metadata.py +48 -0
  164. tpu_inference/layers/common/binary_search.py +295 -0
  165. tpu_inference/layers/common/quant_methods.py +23 -0
  166. tpu_inference/layers/common/quantization.py +270 -0
  167. tpu_inference/layers/common/sharding.py +600 -0
  168. tpu_inference/layers/jax/__init__.py +13 -0
  169. tpu_inference/layers/jax/attention/__init__.py +13 -0
  170. tpu_inference/layers/jax/attention/attention.py +268 -0
  171. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
  172. tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
  173. tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
  174. tpu_inference/layers/jax/base.py +165 -0
  175. tpu_inference/layers/jax/constants.py +101 -0
  176. tpu_inference/layers/jax/layers.py +315 -0
  177. tpu_inference/layers/jax/misc.py +30 -0
  178. tpu_inference/layers/jax/moe/__init__.py +13 -0
  179. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
  180. tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
  181. tpu_inference/layers/jax/moe/moe.py +249 -0
  182. tpu_inference/layers/jax/pp_utils.py +53 -0
  183. tpu_inference/layers/jax/rope.py +294 -0
  184. tpu_inference/layers/jax/rope_interface.py +228 -0
  185. tpu_inference/layers/jax/sample/__init__.py +13 -0
  186. tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
  187. tpu_inference/layers/jax/sample/sampling.py +110 -0
  188. tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
  189. tpu_inference/layers/jax/transformer_block.py +121 -0
  190. tpu_inference/layers/vllm/__init__.py +13 -0
  191. tpu_inference/layers/vllm/attention.py +221 -0
  192. tpu_inference/layers/vllm/fused_moe.py +502 -0
  193. tpu_inference/layers/vllm/linear_common.py +221 -0
  194. tpu_inference/layers/vllm/quantization/__init__.py +55 -0
  195. tpu_inference/layers/vllm/quantization/awq.py +221 -0
  196. tpu_inference/layers/vllm/quantization/common.py +124 -0
  197. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  198. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
  199. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
  200. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  201. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
  202. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
  203. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  204. tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
  205. tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
  206. tpu_inference/layers/vllm/sharding.py +244 -0
  207. tpu_inference/logger.py +10 -0
  208. tpu_inference/lora/__init__.py +13 -0
  209. tpu_inference/lora/torch_lora_ops.py +98 -0
  210. tpu_inference/lora/torch_punica_tpu.py +310 -0
  211. tpu_inference/models/__init__.py +13 -0
  212. tpu_inference/models/common/__init__.py +13 -0
  213. tpu_inference/models/common/model_loader.py +520 -0
  214. tpu_inference/models/jax/__init__.py +13 -0
  215. tpu_inference/models/jax/deepseek_v3.py +978 -0
  216. tpu_inference/models/jax/gpt_oss.py +508 -0
  217. tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
  218. tpu_inference/models/jax/llama3.py +436 -0
  219. tpu_inference/models/jax/llama4.py +643 -0
  220. tpu_inference/models/jax/llama_eagle3.py +350 -0
  221. tpu_inference/models/jax/llama_guard_4.py +375 -0
  222. tpu_inference/models/jax/qwen2.py +390 -0
  223. tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
  224. tpu_inference/models/jax/qwen3.py +318 -0
  225. tpu_inference/models/jax/utils/__init__.py +13 -0
  226. tpu_inference/models/jax/utils/file_utils.py +110 -0
  227. tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
  228. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  229. tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
  230. tpu_inference/models/jax/utils/weight_utils.py +621 -0
  231. tpu_inference/models/vllm/__init__.py +13 -0
  232. tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
  233. tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
  234. tpu_inference/platforms/__init__.py +16 -0
  235. tpu_inference/platforms/tpu_platform.py +258 -0
  236. tpu_inference/runner/__init__.py +13 -0
  237. tpu_inference/runner/block_table.py +122 -0
  238. tpu_inference/runner/compilation_manager.py +890 -0
  239. tpu_inference/runner/input_batch.py +435 -0
  240. tpu_inference/runner/kv_cache.py +166 -0
  241. tpu_inference/runner/kv_cache_manager.py +508 -0
  242. tpu_inference/runner/lora_utils.py +106 -0
  243. tpu_inference/runner/multimodal_manager.py +231 -0
  244. tpu_inference/runner/persistent_batch_manager.py +296 -0
  245. tpu_inference/runner/speculative_decoding_manager.py +262 -0
  246. tpu_inference/runner/structured_decoding_manager.py +101 -0
  247. tpu_inference/runner/tpu_runner.py +1768 -0
  248. tpu_inference/runner/utils.py +426 -0
  249. tpu_inference/spec_decode/__init__.py +13 -0
  250. tpu_inference/spec_decode/jax/__init__.py +13 -0
  251. tpu_inference/spec_decode/jax/eagle3.py +430 -0
  252. tpu_inference/tpu_info.py +92 -0
  253. tpu_inference/utils.py +345 -0
  254. tpu_inference/worker/__init__.py +13 -0
  255. tpu_inference/worker/tpu_worker.py +468 -0
  256. tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
  257. tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
  258. tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
  259. tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
  260. tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
--- /dev/null
+++ tpu_inference/layers/jax/sample/rejection_sampler.py
@@ -0,0 +1,528 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """
+ JAX-based rejection sampler for speculative decoding on TPU.
+
+ This implementation follows the same algorithm as the GPU version but is
+ designed for JAX/TPU compatibility. It supports greedy validation and, when
+ a PRNG key is provided, temperature-based random sampling.
+ """
+
+ import functools
+ from typing import Optional
+
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+
+ from tpu_inference.layers.common.binary_search import topk_mask, topp_mask
+ from tpu_inference.layers.jax.sample.sampling_metadata import \
+     TPUSupportedSamplingMetadata
+
+ # Placeholder token ID for rejected tokens.
+ PLACEHOLDER_TOKEN_ID = -1
+ GREEDY_TEMPERATURE = -1
+
+
+ class RejectionSampler:
+     """
+     JAX-based rejection sampler for speculative decoding.
+
+     The implementation follows the algorithm described in
+     https://arxiv.org/abs/2211.17192.
+     """
+
+     def __init__(self):
+         pass
+
+     def __call__(
+         self,
+         # [num_tokens] - flattened format
+         draft_token_ids: jnp.ndarray,
+         # [batch_size] - number of draft tokens per request
+         num_draft_tokens: jnp.ndarray,
+         # [num_tokens, vocab_size] - flattened format
+         draft_probs: Optional[jnp.ndarray],
+         # [num_tokens, vocab_size] - flattened format
+         target_logits: jnp.ndarray,
+         # [batch_size]
+         bonus_token_ids: jnp.ndarray,
+         sampling_metadata: TPUSupportedSamplingMetadata,
+         key: Optional[jax.random.PRNGKey] = None,
+     ) -> jnp.ndarray:
+         """
+         Perform rejection sampling on draft tokens with flattened inputs.
+
+         Args:
+             draft_token_ids: Draft token IDs in flattened format [num_tokens].
+             num_draft_tokens: Number of draft tokens per request [batch_size].
+             draft_probs: Draft probabilities in flattened format [num_tokens, vocab_size].
+             target_logits: Target logits in flattened format [num_tokens, vocab_size].
+             bonus_token_ids: Bonus token IDs [batch_size].
+             sampling_metadata: Additional metadata needed for sampling.
+             key: JAX random key for non-greedy sampling.
+
+         Returns:
+             output_token_ids: A tensor containing the final output token IDs.
+         """
+         return self.forward(
+             draft_token_ids=draft_token_ids,
+             num_draft_tokens=num_draft_tokens,
+             draft_probs=draft_probs,
+             target_logits=target_logits,
+             bonus_token_ids=bonus_token_ids,
+             sampling_metadata=sampling_metadata,
+             key=key,
+         )
+
+     @functools.partial(jax.jit, static_argnums=(0, ))
+     def forward(
+         self,
+         # [num_tokens] - flattened format
+         draft_token_ids: jnp.ndarray,
+         # [batch_size] - number of draft tokens per request
+         num_draft_tokens: jnp.ndarray,
+         # [num_tokens, vocab_size] - flattened format
+         draft_probs: Optional[jnp.ndarray],
+         # [num_tokens, vocab_size] - flattened format
+         target_logits: jnp.ndarray,
+         # [batch_size]
+         bonus_token_ids: jnp.ndarray,
+         sampling_metadata: TPUSupportedSamplingMetadata,
+         key: Optional[jax.random.PRNGKey] = None,
+     ) -> jnp.ndarray:
+         """
+         Perform rejection sampling on draft tokens with flattened inputs.
+
+         Args:
+             draft_token_ids: Draft token IDs in flattened format [num_tokens].
+             num_draft_tokens: Number of draft tokens per request [batch_size].
+             draft_probs: Draft probabilities in flattened format [num_tokens, vocab_size].
+             target_logits: Target logits in flattened format [num_tokens, vocab_size].
+             bonus_token_ids: Bonus token IDs [batch_size].
+             sampling_metadata: Additional metadata needed for sampling.
+             key: JAX random key for non-greedy sampling.
+
+         Returns:
+             output_token_ids: A tensor containing the final output token IDs.
+         """
+         if sampling_metadata.do_sampling:
+             target_probs = _compute_probs(target_logits, num_draft_tokens,
+                                           sampling_metadata)
+         else:
+             target_probs = target_logits
+
+         output_token_ids = rejection_sample(
+             draft_token_ids,
+             num_draft_tokens,
+             draft_probs,
+             target_probs,
+             bonus_token_ids,
+             sampling_metadata,
+             key=key,
+         )
+         return output_token_ids
+
+     @staticmethod
+     def parse_output(
+         output_token_ids: jnp.ndarray,
+         vocab_size: int,
+         num_draft_tokens_cpu: np.ndarray,
+         batch_size: int,
+         padded_tokens_length: int,
+     ) -> list[list[int]]:
+         """Parse the output of the rejection sampler.
+
+         Args:
+             output_token_ids: The sampled token IDs in shape
+                 [num_tokens + batch_size]. The first num_tokens elements are
+                 the main tokens, and the last batch_size elements are bonus
+                 tokens. Rejected tokens are replaced with `PLACEHOLDER_TOKEN_ID`.
+             vocab_size: The size of the vocabulary.
+             num_draft_tokens_cpu: Number of draft tokens per request
+                 [batch_size] as a numpy array on CPU.
+             batch_size: The number of requests in the batch.
+             padded_tokens_length: The padded length of the main tokens in the
+                 output.
+
+         Returns:
+             A list of lists of token IDs.
+         """
+         # Convert the JAX array to numpy for easier manipulation.
+         output_token_ids_np = np.asarray(output_token_ids)
+
+         # Split main tokens and bonus tokens.
+         main_tokens = output_token_ids_np[:padded_tokens_length]  # [num_tokens]
+         bonus_tokens = output_token_ids_np[padded_tokens_length:]  # [batch_size]
+
+         # Reconstruct per-sequence outputs.
+         outputs = []
+         start_idx = 0
+
+         for i in range(batch_size):
+             seq_length = int(num_draft_tokens_cpu[i])
+             end_idx = start_idx + seq_length
+
+             # Get the main tokens for this sequence.
+             seq_main_tokens = main_tokens[start_idx:end_idx]
+
+             # Filter out placeholder tokens.
+             valid_main_tokens = seq_main_tokens[
+                 (seq_main_tokens != PLACEHOLDER_TOKEN_ID)
+                 & (seq_main_tokens < vocab_size)]
+
+             # Add the bonus token if it is valid.
+             bonus_token = bonus_tokens[i]
+             if bonus_token != PLACEHOLDER_TOKEN_ID and bonus_token < vocab_size:
+                 seq_tokens = np.concatenate([valid_main_tokens, [bonus_token]])
+             else:
+                 seq_tokens = valid_main_tokens
+
+             outputs.append(seq_tokens.tolist())
+             start_idx = end_idx
+
+         return outputs
+
+
+ def _compute_probs(
+     logits: jnp.ndarray,
+     num_draft_tokens: jnp.ndarray,
+     sampling_metadata: TPUSupportedSamplingMetadata,
+ ) -> jnp.ndarray:
+     """
+     Apply top-k, top-p, and temperature to logits and compute probabilities.
+     """
+     total_tokens = logits.shape[0]
+     segment_ids, _ = _get_segment_info(num_draft_tokens, total_tokens)
+
+     # Expand sampling params from [batch_size] to [num_tokens].
+     top_k = sampling_metadata.top_k[segment_ids]
+     top_p = sampling_metadata.top_p[segment_ids]
+     temperatures = sampling_metadata.temperature[segment_ids]
+
+     # Apply top-k and top-p masking.
+     logits = logits.astype(jnp.float32)
+     # Only apply top-k masking if k > 0 for each token.
+     should_apply_topk = jnp.expand_dims(top_k > 0, axis=-1)
+     topk_masked = topk_mask(logits, top_k, replace_val=-jnp.inf)
+     logits = jnp.where(should_apply_topk, topk_masked, logits)
+
+     # Only apply top-p masking if p < 1.0 for each token.
+     should_apply_topp = jnp.expand_dims(top_p < 1.0, axis=-1)
+     topp_masked = topp_mask(logits, top_p, replace_val=-jnp.inf)
+     logits = jnp.where(should_apply_topp, topp_masked, logits)
+
+     # Apply temperature scaling.
+     temperatures = jnp.expand_dims(temperatures, axis=-1)
+     # Add epsilon to avoid division by zero.
+     logits /= (temperatures + 1e-9)
+
+     return jax.nn.softmax(logits, axis=-1)
+
+
+ def _get_segment_info(num_draft_tokens: jax.Array, total_tokens: int):
+     """Helper to create segment IDs and per-segment indices."""
+     batch_size = num_draft_tokens.shape[0]
+
+     # `segment_ids` assigns a unique ID to each token, corresponding to its
+     # sequence in the batch. E.g., [0, 0, 0, 1, 1, 2, 2, 2, 2] for
+     # sequences [3, 2, 4].
+     segment_ids = jnp.repeat(jnp.arange(batch_size),
+                              num_draft_tokens,
+                              total_repeat_length=total_tokens)
+
+     # `group_indices` creates a within-segment index for each token.
+     # E.g., [0, 1, 2, 0, 1, 0, 1, 2, 3] for the example above.
+     segment_starts = jnp.concatenate(
+         [jnp.array([0]), jnp.cumsum(num_draft_tokens)[:-1]])
+     broadcast_starts = jnp.repeat(segment_starts,
+                                   num_draft_tokens,
+                                   total_repeat_length=total_tokens)
+     group_indices = jnp.arange(total_tokens) - broadcast_starts
+     return segment_ids, group_indices
+
+
+ def _sample_recovered_tokens(
+     draft_token_ids: jax.Array,
+     draft_probs: Optional[jax.Array],
+     target_probs: jax.Array,
+     key: jax.random.PRNGKey,
+ ) -> jax.Array:
+     """
+     Sample recovered tokens using the Gumbel-Max trick.
+     This is used when a draft token is rejected in random sampling.
+     """
+     if draft_probs is not None:
+         # The new distribution is p' = max(p_target - p_draft, 0).
+         new_dist = jnp.maximum(target_probs - draft_probs, 0)
+     else:
+         # If no draft probs, the new distribution is the target distribution
+         # with the draft token's probability zeroed out.
+         vocab_size = target_probs.shape[-1]
+         mask = jax.nn.one_hot(draft_token_ids, vocab_size, dtype=jnp.bool)
+         new_dist = target_probs * ~mask
+
+     # Gumbel-Max trick to sample from the new distribution:
+     # y = argmax(log(p') + g) where g ~ Gumbel(0, 1).
+     # This is equivalent to argmax(p' / q) where q ~ Exponential(1).
+     q = jax.random.exponential(key, shape=new_dist.shape)
+
+     # Add a small epsilon to avoid division by zero.
+     recovered_token_ids = jnp.argmax(new_dist / (q + 1e-9), axis=-1)
+     return recovered_token_ids
+
+
+ def rejection_sample(
+     # [num_tokens] - flattened format
+     draft_token_ids: jnp.ndarray,
+     # [batch_size] - JAX array
+     num_draft_tokens: jnp.ndarray,
+     # [num_tokens, vocab_size] - flattened format
+     draft_probs: Optional[jnp.ndarray],
+     # [num_tokens, vocab_size] - flattened format
+     target_probs: jnp.ndarray,
+     # [batch_size]
+     bonus_token_ids: jnp.ndarray,
+     sampling_metadata: TPUSupportedSamplingMetadata,
+     key: Optional[jax.random.PRNGKey] = None,
+ ) -> jnp.ndarray:
+     """
+     Perform rejection sampling on draft tokens with flattened inputs.
+
+     Args:
+         draft_token_ids: Draft token IDs in flattened format [num_tokens].
+         num_draft_tokens: Number of draft tokens per request [batch_size].
+         draft_probs: Draft probabilities in flattened format [num_tokens, vocab_size].
+         target_probs: Target probabilities in flattened format [num_tokens, vocab_size].
+         bonus_token_ids: Bonus token IDs [batch_size].
+         sampling_metadata: Sampling metadata.
+         key: JAX random key for non-greedy sampling.
+
+     Returns:
+         output_token_ids: Output token IDs [num_tokens + batch_size].
+     """
+     if not sampling_metadata.do_sampling:
+         greedy_output = _greedy_rejection_sample_with_segment(
+             draft_token_ids, target_probs, num_draft_tokens, bonus_token_ids)
+         return greedy_output
+
+     # Random path.
+     if key is None:
+         raise ValueError(
+             "A random key must be provided for non-greedy sampling.")
+
+     random_output = _random_rejection_sample_with_segment(
+         draft_token_ids,
+         draft_probs,
+         target_probs,
+         num_draft_tokens,
+         bonus_token_ids,
+         key,
+     )
+
+     return random_output
+
+
+ def _random_rejection_sample_with_segment(
+     draft_token_ids: jax.Array,
+     draft_probs: Optional[jax.Array],
+     target_probs: jax.Array,
+     num_draft_tokens: jax.Array,
+     bonus_token_ids: jax.Array,
+     key: jax.random.PRNGKey,
+ ) -> jax.Array:
+     """
+     Performs random speculative decoding validation in a vectorized, jittable manner.
+     """
+     total_tokens = draft_token_ids.shape[0]
+     batch_size = num_draft_tokens.shape[0]
+
+     # Split the random key.
+     uniform_key, recover_key = jax.random.split(key)
+
+     # --- Step 1: Get Segment Info ---
+     segment_ids, group_indices = _get_segment_info(num_draft_tokens,
+                                                    total_tokens)
+
+     # --- Step 2: Acceptance/Rejection Logic ---
+     if draft_probs is not None:
+         draft_token_probs = jnp.take_along_axis(draft_probs,
+                                                 draft_token_ids[:, None],
+                                                 axis=-1).squeeze(-1)
+     else:
+         draft_token_probs = 1.0
+
+     target_token_probs = jnp.take_along_axis(target_probs,
+                                              draft_token_ids[:, None],
+                                              axis=-1).squeeze(-1)
+
+     uniform_probs = jax.random.uniform(uniform_key, shape=(total_tokens, ))
+
+     # Acceptance condition: p_target(d) / p_draft(d) >= u
+     ratio = target_token_probs / (draft_token_probs + 1e-9)
+     accepted = ratio >= uniform_probs
+
+     # --- Step 3: Find First Rejection ---
+     rejections = ~accepted
+     large_value = total_tokens
+     rejection_indices = jnp.where(rejections, group_indices, large_value)
+
+     first_rejection_idx_per_segment = jax.ops.segment_min(
+         data=rejection_indices.astype(jnp.int32),
+         segment_ids=segment_ids,
+         num_segments=batch_size,
+         indices_are_sorted=True,
+     )
+
+     max_int = jnp.iinfo(jnp.int32).max
+     first_rejection_idx_per_segment = jnp.where(
+         first_rejection_idx_per_segment == max_int, large_value,
+         first_rejection_idx_per_segment)
+
+     # --- Step 4: Sample Recovered Tokens ---
+     recovered_token_ids = _sample_recovered_tokens(draft_token_ids,
+                                                    draft_probs, target_probs,
+                                                    recover_key)
+
+     # --- Step 5: Generate Main Token Output ---
+     first_rejection_idx_broadcast = jnp.repeat(
+         first_rejection_idx_per_segment,
+         num_draft_tokens,
+         total_repeat_length=total_tokens)
+
+     main_tokens = jnp.where(
+         group_indices < first_rejection_idx_broadcast, draft_token_ids,
+         jnp.where(group_indices == first_rejection_idx_broadcast,
+                   recovered_token_ids, PLACEHOLDER_TOKEN_ID))
+
+     # --- Step 6: Handle Bonus Tokens ---
+     all_accepted = first_rejection_idx_per_segment == large_value
+     no_draft_tokens = num_draft_tokens == 0
+     should_get_bonus = all_accepted | no_draft_tokens
+     bonus_tokens = jnp.where(should_get_bonus, bonus_token_ids,
+                              PLACEHOLDER_TOKEN_ID)
+
+     # --- Step 7: Concatenate ---
+     return jnp.concatenate([main_tokens, bonus_tokens])
+
+
+ # TODO(pooyam): Optimize/profile this implementation further. Currently, I
+ # just want a working e2e path. There might be overheads with `parse_output`
+ # that can be optimized on TPU.
+ # I should benchmark against the following approaches:
+ # - Using `jax.lax.segment_xyz` to work with flattened inputs instead of batched inputs.
+ # - A vectorized implementation using `cumprod` and other masking tricks.
+ # - A Pallas kernel similar to the Triton implementation.
+ # - A scan-based approach.
+ # Overall, I expect XLA to optimize the scan-based approach pretty well, but
+ # it would be good to compare performance against other methods.
+ def _greedy_rejection_sample_with_segment(
+     draft_token_ids: jax.Array,
+     target_probs: jax.Array,
+     num_draft_tokens: jax.Array,
+     bonus_token_ids: jax.Array,
+ ) -> jax.Array:
+     """
+     Performs greedy speculative decoding validation in a vectorized, jittable manner.
+
+     This function compares draft tokens with the target model's outputs. For each
+     sequence in the batch, it accepts tokens as long as the draft and target match.
+     When a mismatch occurs, it takes the target model's token and invalidates the
+     rest of the tokens in that sequence by setting them to -1.
+
+     Args:
+         draft_token_ids: A 1D JAX array (num_tokens,) of integers representing the
+             concatenated draft tokens for all sequences in the batch.
+         target_probs: A 2D JAX array (num_tokens, vocab_size) of floats representing
+             the concatenated target model's probabilities.
+         num_draft_tokens: A 1D JAX array (batch_size,) of integers specifying the
+             number of draft tokens for each sequence in the batch.
+         bonus_token_ids: A 1D JAX array (batch_size,) of integers representing the
+             bonus token for each sequence.
+
+     Returns:
+         A 1D JAX array (num_tokens + batch_size,) containing the validated token
+         sequence followed by bonus tokens (or -1 if not accepted).
+     """
+     # Get the target argmax.
+     target_logits_argmax = jnp.argmax(target_probs, axis=-1)
+
+     # --- Step 1: Create Segment IDs and Per-Segment Indices ---
+     total_tokens = draft_token_ids.shape[0]
+     batch_size = num_draft_tokens.shape[0]
+     segment_ids, group_indices = _get_segment_info(num_draft_tokens,
+                                                    total_tokens)
+
+     # --- Step 2: Find the First Mismatch in Each Segment ---
+
+     # Find all mismatches between draft and target tokens.
+     mismatches = draft_token_ids != target_logits_argmax
+
+     # To find the *first* mismatch, we use a trick with segment_min.
+     # We create an array where mismatched positions hold their `group_index`
+     # and matched positions hold a large value.
+     large_value = total_tokens
+     mismatch_indices = jnp.where(mismatches, group_indices, large_value)
+
+     # `segment_min` finds the minimum `mismatch_index` for each segment. This
+     # effectively gives us the `group_index` of the first mismatch.
+     # For sequences with no mismatches, the result will be `large_value`.
+     first_mismatch_idx_per_segment = jax.ops.segment_min(
+         data=mismatch_indices.astype(jnp.int32),
+         segment_ids=segment_ids,
+         num_segments=batch_size,
+         indices_are_sorted=True,
+     )
+
+     # Handle empty segments (where num_draft_tokens is 0). `segment_min`
+     # returns the dtype's max value for empty segments; we replace it with our
+     # large_value for consistency.
+     max_int = jnp.iinfo(jnp.int32).max
+     first_mismatch_idx_per_segment = jnp.where(
+         first_mismatch_idx_per_segment == max_int, large_value,
+         first_mismatch_idx_per_segment)
+
+     # --- Step 3: Broadcast Mismatch Info and Generate Main Token Output ---
+
+     # Broadcast the first mismatch index back to the original token dimension.
+     first_mismatch_idx_broadcast = jnp.repeat(first_mismatch_idx_per_segment,
+                                               num_draft_tokens,
+                                               total_repeat_length=total_tokens)
+
+     # The final logic for main tokens:
+     # A token is valid if its `group_index` is less than or equal to the
+     # index of the first mismatch in its segment.
+     # - If `group_index < first_mismatch_idx`, the draft was correct.
+     # - If `group_index == first_mismatch_idx`, this is the correction token.
+     # - If `group_index > first_mismatch_idx`, the token is invalid (-1).
+     main_tokens = jnp.where(group_indices <= first_mismatch_idx_broadcast,
+                             target_logits_argmax, PLACEHOLDER_TOKEN_ID)
+
+     # --- Step 4: Handle Bonus Tokens ---
+
+     # A sequence gets its bonus token if there were no mismatches
+     # (first_mismatch_idx_per_segment == large_value).
+     all_accepted = first_mismatch_idx_per_segment == large_value
+
+     # Sequences with no draft tokens should still get the bonus token,
+     # since there is nothing to reject.
+     no_draft_tokens = num_draft_tokens == 0
+     should_get_bonus = all_accepted | no_draft_tokens
+
+     bonus_tokens = jnp.where(should_get_bonus, bonus_token_ids,
+                              PLACEHOLDER_TOKEN_ID)
+
+     # --- Step 5: Concatenate Main Tokens and Bonus Tokens ---
+
+     output = jnp.concatenate([main_tokens, bonus_tokens])
+
+     return output
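
The greedy path above is easiest to follow on a toy batch. The sketch below is illustrative only: the vocabulary size and token values are invented here, not taken from the package, and it simply calls `_greedy_rejection_sample_with_segment` and `RejectionSampler.parse_output` as defined in the diff above (it runs eagerly on CPU).

import jax.numpy as jnp
import numpy as np

from tpu_inference.layers.jax.sample.rejection_sampler import (
    RejectionSampler, _greedy_rejection_sample_with_segment)

# Two requests with 3 and 2 draft tokens, flattened to [5] (toy values).
vocab_size = 8
draft_token_ids = jnp.array([1, 2, 3, 4, 5])
num_draft_tokens = jnp.array([3, 2])
bonus_token_ids = jnp.array([6, 7])

# One-hot target "probabilities": the argmax agrees with request 0 on all
# three positions and disagrees with request 1 at its first position.
target_probs = jnp.eye(vocab_size)[jnp.array([1, 2, 3, 0, 5])]

out = _greedy_rejection_sample_with_segment(
    draft_token_ids, target_probs, num_draft_tokens, bonus_token_ids)
# Request 1's mismatch is corrected to token 0, the rest of its drafts are
# invalidated, and it loses its bonus token.
print(out)  # [1 2 3 0 -1 6 -1]

parsed = RejectionSampler.parse_output(
    out,
    vocab_size=vocab_size,
    num_draft_tokens_cpu=np.array([3, 2]),
    batch_size=2,
    padded_tokens_length=5)
print(parsed)  # [[1, 2, 3, 6], [0]]

Request 0 keeps all three drafts plus its bonus token because its segment has no mismatch, matching the bonus-token rule in Step 4.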
--- /dev/null
+++ tpu_inference/layers/jax/sample/sampling.py
@@ -0,0 +1,110 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import functools
+
+ import jax
+ import jax.numpy as jnp
+ from jax.sharding import Mesh, NamedSharding
+ from jax.sharding import PartitionSpec as P
+ from vllm.v1.outputs import LogprobsTensors
+
+ from tpu_inference.layers.common.binary_search import topk_mask, topp_mask
+ from tpu_inference.layers.common.sharding import ShardingAxisName
+ from tpu_inference.layers.jax.sample.sampling_metadata import \
+     TPUSupportedSamplingMetadata
+
+ _SAMPLING_EPS = 1e-5
+
+
+ @functools.partial(
+     jax.jit,
+     static_argnames=["mesh"],
+ )
+ def sample(
+     rng: jax.Array,
+     mesh: Mesh,
+     logits: jax.Array,
+     tpu_sampling_metadata: TPUSupportedSamplingMetadata,
+ ) -> jax.Array:
+     # (B, vocab_size)
+     if tpu_sampling_metadata.do_sampling:
+         # Unshard the logits explicitly to avoid a latency increase.
+         logits = jax.lax.with_sharding_constraint(
+             logits, NamedSharding(mesh, P(ShardingAxisName.ATTN_DATA, None)))
+     greedy_sampled = jnp.argmax(logits, axis=-1)
+     if not tpu_sampling_metadata.do_sampling:
+         return greedy_sampled
+
+     logits = logits.astype(jnp.float32)
+     logits = topk_mask(logits, tpu_sampling_metadata.top_k, replace_val=-1e12)
+     logits = topp_mask(logits, tpu_sampling_metadata.top_p, replace_val=-1e12)
+
+     temperatures = tpu_sampling_metadata.temperature.astype(logits.dtype)
+     temperatures = jnp.expand_dims(temperatures, axis=-1)
+     logits /= temperatures
+
+     # (batch_size,)
+     next_tokens = jax.random.categorical(rng, logits)
+     # Note: avoid using the sampled result when temperature < _SAMPLING_EPS.
+     # If temperature < 0, `logits /= temperatures` flips the logits,
+     # which would produce incorrect samples.
+     return jnp.where(tpu_sampling_metadata.temperature < _SAMPLING_EPS,
+                      greedy_sampled, next_tokens)
+
+
+ def compute_logprobs(logits: jax.Array) -> jax.Array:
+     return jax.nn.log_softmax(logits, axis=-1)
+
+
+ def gather_logprobs(
+     logprobs: jax.Array,
+     token_ids: jax.Array,
+     num_logprobs: int,
+ ) -> LogprobsTensors:
+     """
+     Gather logprobs for the top-k and the sampled/prompt tokens.
+
+     Args:
+         logprobs: (num tokens) x (vocab) tensor.
+         token_ids: 1D token ID tensor with (num tokens) elements; prompt
+             tokens (if prompt logprobs) or sampled tokens (if sampled
+             logprobs).
+         num_logprobs: minimum number of logprobs to retain per token.
+
+     Returns:
+         Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
+         Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
+         Sampled token rank tensor, (num tokens)
+     """
+     # Find the top-k values.
+     topk_logprobs, topk_indices = jax.lax.top_k(logprobs, k=num_logprobs)
+
+     # Gather the logprob of the prompt or sampled token.
+     token_ids = jnp.expand_dims(token_ids, axis=-1)
+     token_logprobs = jnp.take_along_axis(logprobs, token_ids, axis=-1)
+
+     # Compute the rank of the actual token.
+     token_ranks = jnp.sum(logprobs >= token_logprobs, axis=-1)
+
+     # Concatenate together with the top-k.
+     indices = jnp.concatenate((token_ids, topk_indices), axis=1)
+     logprobs = jnp.concatenate((token_logprobs, topk_logprobs), axis=1)
+
+     # Use int32 to reduce the tensor size.
+     indices = jnp.int32(indices)
+
+     return LogprobsTensors(indices, logprobs, token_ranks)
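
The shape contract of `gather_logprobs` is easiest to see on a small input. The sketch below is illustrative only: the logits are invented here, and since `LogprobsTensors` is constructed positionally above, the result is unpacked as a plain tuple (this assumes it behaves like a tuple, e.g. a NamedTuple) rather than relying on its field names.

import jax.numpy as jnp

from tpu_inference.layers.jax.sample.sampling import (compute_logprobs,
                                                      gather_logprobs)

# Two token positions over a 4-word vocabulary (toy values).
logits = jnp.array([
    [2.0, 1.0, 0.0, -1.0],
    [0.0, 3.0, 1.0, 0.5],
])
logprobs = compute_logprobs(logits)  # log_softmax over the vocab axis
sampled = jnp.array([0, 2])          # token actually chosen at each position

indices, lps, ranks = gather_logprobs(logprobs, sampled, num_logprobs=2)
print(indices.shape)  # (2, 3): the sampled id followed by the top-2 ids
print(lps.shape)      # (2, 3): the matching log-probabilities
print(ranks)          # [1 2]: row 0 sampled the argmax, row 1 the second-best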