tpu-inference 0.12.0.dev20251222__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (260)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +67 -0
  6. tests/core/test_dp_scheduler.py +724 -0
  7. tests/core/test_init.py +63 -0
  8. tests/distributed/__init__.py +13 -0
  9. tests/distributed/test_distributed_utils.py +120 -0
  10. tests/distributed/test_tpu_connector.py +478 -0
  11. tests/e2e/__init__.py +13 -0
  12. tests/e2e/test_async_scheduler.py +211 -0
  13. tests/e2e/test_data_parallel.py +393 -0
  14. tests/e2e/test_local_disagg.py +257 -0
  15. tests/e2e/test_model_loader.py +268 -0
  16. tests/e2e/test_multi_modal_inference.py +111 -0
  17. tests/e2e/test_pipeline_parallel.py +265 -0
  18. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  19. tests/e2e/test_sampling_params.py +269 -0
  20. tests/e2e/test_speculative_decoding.py +291 -0
  21. tests/e2e/test_structured_decoding.py +46 -0
  22. tests/executors/__init__.py +13 -0
  23. tests/executors/test_ray_distributed_executor.py +199 -0
  24. tests/experimental/__init__.py +13 -0
  25. tests/experimental/test_llama3_jax_stashed.py +208 -0
  26. tests/kernels/__init__.py +13 -0
  27. tests/kernels/collectives/__init__.py +13 -0
  28. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  29. tests/kernels/fused_moe_v1_test.py +388 -0
  30. tests/kernels/gmm_test.py +205 -0
  31. tests/kernels/mla_v1_test.py +498 -0
  32. tests/kernels/quantized_matmul_kernel_test.py +159 -0
  33. tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
  34. tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
  35. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
  36. tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
  37. tests/layers/__init__.py +13 -0
  38. tests/layers/common/__init__.py +13 -0
  39. tests/layers/common/test_attention_interface.py +156 -0
  40. tests/layers/common/test_quantization.py +149 -0
  41. tests/layers/jax/__init__.py +13 -0
  42. tests/layers/jax/attention/__init__.py +13 -0
  43. tests/layers/jax/attention/test_common_attention.py +103 -0
  44. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  45. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  46. tests/layers/jax/moe/__init__.py +13 -0
  47. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  48. tests/layers/jax/sample/__init__.py +13 -0
  49. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  50. tests/layers/jax/sample/test_sampling.py +115 -0
  51. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  52. tests/layers/jax/test_layers.py +155 -0
  53. tests/layers/jax/test_qwix.py +969 -0
  54. tests/layers/jax/test_rope.py +93 -0
  55. tests/layers/jax/test_sharding.py +159 -0
  56. tests/layers/jax/test_transformer_block.py +152 -0
  57. tests/layers/vllm/__init__.py +13 -0
  58. tests/layers/vllm/test_attention.py +363 -0
  59. tests/layers/vllm/test_awq.py +405 -0
  60. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
  62. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
  63. tests/layers/vllm/test_fp8.py +17 -0
  64. tests/layers/vllm/test_mxfp4.py +297 -0
  65. tests/layers/vllm/test_unquantized.py +621 -0
  66. tests/layers/vllm/utils.py +72 -0
  67. tests/lora/__init__.py +13 -0
  68. tests/lora/conftest.py +46 -0
  69. tests/lora/test_bgmv.py +57 -0
  70. tests/lora/test_layers.py +666 -0
  71. tests/lora/test_lora.py +147 -0
  72. tests/lora/test_lora_perf.py +67 -0
  73. tests/lora/utils.py +88 -0
  74. tests/models/__init__.py +13 -0
  75. tests/models/common/__init__.py +13 -0
  76. tests/models/common/test_model_loader.py +455 -0
  77. tests/models/jax/__init__.py +13 -0
  78. tests/models/jax/test_deepseek_v3.py +401 -0
  79. tests/models/jax/test_llama3.py +184 -0
  80. tests/models/jax/test_llama4.py +298 -0
  81. tests/models/jax/test_llama_eagle3.py +197 -0
  82. tests/models/jax/test_llama_guard_4.py +242 -0
  83. tests/models/jax/test_qwen2.py +172 -0
  84. tests/models/jax/test_qwen2_5_vl.py +606 -0
  85. tests/models/jax/test_qwen3.py +169 -0
  86. tests/models/jax/test_weight_loading.py +180 -0
  87. tests/models/jax/utils/__init__.py +13 -0
  88. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  89. tests/platforms/__init__.py +13 -0
  90. tests/platforms/test_tpu_platform.py +54 -0
  91. tests/runner/__init__.py +13 -0
  92. tests/runner/test_block_table.py +395 -0
  93. tests/runner/test_input_batch.py +226 -0
  94. tests/runner/test_kv_cache.py +220 -0
  95. tests/runner/test_kv_cache_manager.py +498 -0
  96. tests/runner/test_multimodal_manager.py +429 -0
  97. tests/runner/test_persistent_batch_manager.py +84 -0
  98. tests/runner/test_speculative_decoding_manager.py +368 -0
  99. tests/runner/test_structured_decoding_manager.py +220 -0
  100. tests/runner/test_tpu_runner.py +202 -0
  101. tests/runner/test_tpu_runner_dp.py +1033 -0
  102. tests/runner/test_tpu_runner_mesh.py +200 -0
  103. tests/runner/test_utils.py +411 -0
  104. tests/spec_decode/__init__.py +13 -0
  105. tests/spec_decode/test_eagle3.py +311 -0
  106. tests/test_base.py +215 -0
  107. tests/test_envs.py +280 -0
  108. tests/test_tpu_info.py +134 -0
  109. tests/test_utils.py +193 -0
  110. tests/worker/__init__.py +13 -0
  111. tests/worker/tpu_worker_test.py +414 -0
  112. tpu_inference/__init__.py +67 -0
  113. tpu_inference/core/__init__.py +13 -0
  114. tpu_inference/core/core_tpu.py +786 -0
  115. tpu_inference/core/disagg_executor.py +118 -0
  116. tpu_inference/core/disagg_utils.py +49 -0
  117. tpu_inference/core/sched/__init__.py +13 -0
  118. tpu_inference/core/sched/dp_scheduler.py +814 -0
  119. tpu_inference/distributed/__init__.py +13 -0
  120. tpu_inference/distributed/jax_parallel_state.py +81 -0
  121. tpu_inference/distributed/tpu_connector.py +732 -0
  122. tpu_inference/distributed/utils.py +112 -0
  123. tpu_inference/env_override.py +9 -0
  124. tpu_inference/envs.py +191 -0
  125. tpu_inference/executors/__init__.py +13 -0
  126. tpu_inference/executors/ray_distributed_executor.py +399 -0
  127. tpu_inference/experimental/__init__.py +13 -0
  128. tpu_inference/experimental/llama3_jax_stashed.py +272 -0
  129. tpu_inference/kernels/__init__.py +13 -0
  130. tpu_inference/kernels/collectives/__init__.py +13 -0
  131. tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
  132. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
  133. tpu_inference/kernels/collectives/util.py +47 -0
  134. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  135. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  136. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  137. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  138. tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
  139. tpu_inference/kernels/megablox/__init__.py +13 -0
  140. tpu_inference/kernels/megablox/common.py +54 -0
  141. tpu_inference/kernels/megablox/gmm.py +646 -0
  142. tpu_inference/kernels/mla/__init__.py +13 -0
  143. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  144. tpu_inference/kernels/mla/v1/kernel.py +1340 -0
  145. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  146. tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
  147. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  148. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  149. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  150. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  151. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
  152. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
  153. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  154. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  155. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
  156. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
  157. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
  158. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
  159. tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
  160. tpu_inference/layers/__init__.py +13 -0
  161. tpu_inference/layers/common/__init__.py +13 -0
  162. tpu_inference/layers/common/attention_interface.py +403 -0
  163. tpu_inference/layers/common/attention_metadata.py +48 -0
  164. tpu_inference/layers/common/binary_search.py +295 -0
  165. tpu_inference/layers/common/quant_methods.py +23 -0
  166. tpu_inference/layers/common/quantization.py +270 -0
  167. tpu_inference/layers/common/sharding.py +600 -0
  168. tpu_inference/layers/jax/__init__.py +13 -0
  169. tpu_inference/layers/jax/attention/__init__.py +13 -0
  170. tpu_inference/layers/jax/attention/attention.py +268 -0
  171. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
  172. tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
  173. tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
  174. tpu_inference/layers/jax/base.py +165 -0
  175. tpu_inference/layers/jax/constants.py +101 -0
  176. tpu_inference/layers/jax/layers.py +315 -0
  177. tpu_inference/layers/jax/misc.py +30 -0
  178. tpu_inference/layers/jax/moe/__init__.py +13 -0
  179. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
  180. tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
  181. tpu_inference/layers/jax/moe/moe.py +249 -0
  182. tpu_inference/layers/jax/pp_utils.py +53 -0
  183. tpu_inference/layers/jax/rope.py +294 -0
  184. tpu_inference/layers/jax/rope_interface.py +228 -0
  185. tpu_inference/layers/jax/sample/__init__.py +13 -0
  186. tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
  187. tpu_inference/layers/jax/sample/sampling.py +110 -0
  188. tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
  189. tpu_inference/layers/jax/transformer_block.py +121 -0
  190. tpu_inference/layers/vllm/__init__.py +13 -0
  191. tpu_inference/layers/vllm/attention.py +221 -0
  192. tpu_inference/layers/vllm/fused_moe.py +502 -0
  193. tpu_inference/layers/vllm/linear_common.py +221 -0
  194. tpu_inference/layers/vllm/quantization/__init__.py +55 -0
  195. tpu_inference/layers/vllm/quantization/awq.py +221 -0
  196. tpu_inference/layers/vllm/quantization/common.py +124 -0
  197. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  198. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
  199. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
  200. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  201. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
  202. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
  203. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  204. tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
  205. tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
  206. tpu_inference/layers/vllm/sharding.py +244 -0
  207. tpu_inference/logger.py +10 -0
  208. tpu_inference/lora/__init__.py +13 -0
  209. tpu_inference/lora/torch_lora_ops.py +98 -0
  210. tpu_inference/lora/torch_punica_tpu.py +310 -0
  211. tpu_inference/models/__init__.py +13 -0
  212. tpu_inference/models/common/__init__.py +13 -0
  213. tpu_inference/models/common/model_loader.py +520 -0
  214. tpu_inference/models/jax/__init__.py +13 -0
  215. tpu_inference/models/jax/deepseek_v3.py +978 -0
  216. tpu_inference/models/jax/gpt_oss.py +508 -0
  217. tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
  218. tpu_inference/models/jax/llama3.py +436 -0
  219. tpu_inference/models/jax/llama4.py +643 -0
  220. tpu_inference/models/jax/llama_eagle3.py +350 -0
  221. tpu_inference/models/jax/llama_guard_4.py +375 -0
  222. tpu_inference/models/jax/qwen2.py +390 -0
  223. tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
  224. tpu_inference/models/jax/qwen3.py +318 -0
  225. tpu_inference/models/jax/utils/__init__.py +13 -0
  226. tpu_inference/models/jax/utils/file_utils.py +110 -0
  227. tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
  228. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  229. tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
  230. tpu_inference/models/jax/utils/weight_utils.py +621 -0
  231. tpu_inference/models/vllm/__init__.py +13 -0
  232. tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
  233. tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
  234. tpu_inference/platforms/__init__.py +16 -0
  235. tpu_inference/platforms/tpu_platform.py +258 -0
  236. tpu_inference/runner/__init__.py +13 -0
  237. tpu_inference/runner/block_table.py +122 -0
  238. tpu_inference/runner/compilation_manager.py +890 -0
  239. tpu_inference/runner/input_batch.py +435 -0
  240. tpu_inference/runner/kv_cache.py +166 -0
  241. tpu_inference/runner/kv_cache_manager.py +508 -0
  242. tpu_inference/runner/lora_utils.py +106 -0
  243. tpu_inference/runner/multimodal_manager.py +231 -0
  244. tpu_inference/runner/persistent_batch_manager.py +296 -0
  245. tpu_inference/runner/speculative_decoding_manager.py +262 -0
  246. tpu_inference/runner/structured_decoding_manager.py +101 -0
  247. tpu_inference/runner/tpu_runner.py +1768 -0
  248. tpu_inference/runner/utils.py +426 -0
  249. tpu_inference/spec_decode/__init__.py +13 -0
  250. tpu_inference/spec_decode/jax/__init__.py +13 -0
  251. tpu_inference/spec_decode/jax/eagle3.py +430 -0
  252. tpu_inference/tpu_info.py +92 -0
  253. tpu_inference/utils.py +345 -0
  254. tpu_inference/worker/__init__.py +13 -0
  255. tpu_inference/worker/tpu_worker.py +468 -0
  256. tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
  257. tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
  258. tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
  259. tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
  260. tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
tpu_inference/spec_decode/jax/eagle3.py
@@ -0,0 +1,430 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Implements the Eagle3 proposer for speculative decoding on JAX/TPU."""
+ import functools
+ from dataclasses import replace
+ from typing import Any, Optional
+
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+ from flax import nnx
+ from jax import lax
+ from jax.sharding import NamedSharding, PartitionSpec
+ from vllm.config import VllmConfig
+
+ from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+ from tpu_inference.logger import init_logger
+ from tpu_inference.models.common.model_loader import get_model
+ from tpu_inference.runner import utils as runner_utils
+ from tpu_inference.utils import device_array
+
+ logger = init_logger(__name__)
+
+
+ class Eagle3Proposer:
+     """A proposer for speculative decoding using the Eagle3 method.
+
+     This class is responsible for loading the draft model and generating
+     draft tokens based on the target model's outputs.
+     """
+
+     def __init__(
+         self,
+         vllm_config: VllmConfig,
+         runner: Any,  # TPUModelRunner
+     ):
+         """Initializes the Eagle3Proposer.
+
+         Args:
+             vllm_config: The vLLM configuration.
+             runner: The TPUModelRunner instance.
+         """
+         self.vllm_config = vllm_config
+         self.speculative_config = vllm_config.speculative_config
+         assert self.speculative_config is not None
+         self.draft_model_config = self.speculative_config.draft_model_config
+         self.method = self.speculative_config.method
+
+         self.runner = runner
+         self.mesh = runner.mesh
+         self.num_speculative_tokens = (
+             self.speculative_config.num_speculative_tokens)
+         self.block_size = vllm_config.cache_config.block_size
+         self.rng_key = jax.random.key(self.vllm_config.model_config.seed)
+         self.max_num_tokens = runner.max_num_tokens
+         self.token_arange = jnp.arange(self.max_num_tokens)
+
+     def load_model(self, target_model: Any) -> None:
+         """Loads the draft model and resolves embedding sharing."""
+         self.model_fn, self.compute_logits_fn, self.combine_hidden_states_fn, _, self.state, _, _ = get_model(
+             self.vllm_config, self.rng_key, self.mesh, is_draft_model=True)
+
+         draft_embed_tokens = getattr(self.state.model, 'embed_tokens', None)
+         if draft_embed_tokens is None or not jnp.any(
+                 draft_embed_tokens.embedding):
+             logger.info(
+                 "Draft model has no embedding weights. Setting the draft "
+                 "model's embed_tokens to the target model's embed.")
+             self.state.model.embed_tokens = target_model.model.embed
+         elif jnp.array_equal(draft_embed_tokens.embedding,
+                              target_model.model.embed.embedding):
+             logger.info(
+                 "Draft model's embed_tokens is identical to the target "
+                 "model's embed. Sharing the embedding.")
+             self.state.model.embed_tokens = target_model.model.embed
+         else:
+             logger.info("Draft model has its own embed_tokens.")
+
+     @functools.partial(jax.jit, static_argnums=(0, ))
+     def _prepare_input_ids(
+             self, query_start_loc: jax.Array, target_token_ids: jax.Array,
+             next_token_ids: jax.Array,
+             num_reqs: jax.Array) -> tuple[jnp.ndarray, jnp.ndarray]:
+         """JIT-compiled helper for preparing the input IDs for the draft model."""
+
+         last_token_indices = query_start_loc[1:] - 1
+         # Shift the input ids left by one token.
+         rolled_input_ids = jnp.roll(target_token_ids, -1, axis=0)
+
+         # To make the update JIT-compatible with a dynamic `num_reqs`, we
+         # perform a scatter update of a static size, using a mask to handle
+         # the dynamic part.
+         max_num_reqs = last_token_indices.shape[0]
+         mask = jnp.arange(max_num_reqs) < num_reqs
+
+         # For padded requests (where mask is False), we use the original
+         # value from the rolled array, making the update a no-op for them.
+         original_values_at_indices = rolled_input_ids[last_token_indices]
+         values_to_set = jnp.where(mask, next_token_ids,
+                                   original_values_at_indices)
+
+         input_ids = rolled_input_ids.at[last_token_indices].set(values_to_set)
+
+         return input_ids, last_token_indices
+
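To illustrate the JIT-friendly masked scatter above, here is a minimal standalone sketch with made-up shapes (two active requests padded to four slots); it is an editorial restatement of the same trick, not part of the package:

    import jax.numpy as jnp

    # Hypothetical padded batch: 4 request slots, only the first 2 are active.
    query_start_loc = jnp.array([0, 3, 5, 6, 7])
    target_token_ids = jnp.array([11, 12, 13, 21, 22, 0, 0])
    next_token_ids = jnp.array([100, 200, 0, 0])
    num_reqs = jnp.array(2)

    last_token_indices = query_start_loc[1:] - 1                # [2, 4, 5, 6]
    rolled = jnp.roll(target_token_ids, -1, axis=0)
    mask = jnp.arange(last_token_indices.shape[0]) < num_reqs   # [T, T, F, F]
    # Padded slots write back their current values, so the scatter keeps a
    # static shape but is a no-op for inactive requests.
    values = jnp.where(mask, next_token_ids, rolled[last_token_indices])
    input_ids = rolled.at[last_token_indices].set(values)
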
+     @functools.partial(jax.jit, static_argnums=(0, ))
+     def _update_inputs_for_loop_speculation(
+         self, positions: jax.Array, seq_lens: jax.Array,
+         block_tables: jax.Array
+     ) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]:
+         """JIT-compiled helper for updating inputs inside the speculation loop."""
+
+         positions += 1
+         exceeds_max_model_len = positions >= self.runner.max_model_len
+         clamped_positions = jnp.where(exceeds_max_model_len, 0, positions)
+
+         new_seq_lens = seq_lens + 1
+         new_seq_lens = jnp.minimum(new_seq_lens, self.runner.max_model_len)
+         new_seq_lens = jnp.where(exceeds_max_model_len, 1, new_seq_lens)
+
+         num_reqs = seq_lens.shape[0]
+         query_start_loc = jnp.arange(num_reqs + 1)
+
+         # Compute the slot mapping.
+         # NOTE(woosuk): We should handle the case where the draft model
+         # generates tokens beyond the max model length. Since it is complex
+         # to remove such requests from the batch, we keep them in the batch
+         # but adjust the position ids and slot mappings to avoid the
+         # out-of-range access during the model execution. The draft tokens
+         # generated with this adjustment should be ignored.
+         max_num_blocks_per_req = block_tables.shape[0] // num_reqs
+         expanded_exceeds_mask = jnp.repeat(exceeds_max_model_len,
+                                            max_num_blocks_per_req)
+         new_block_tables = jnp.where(expanded_exceeds_mask, -1, block_tables)
+
+         positions = lax.with_sharding_constraint(
+             positions, NamedSharding(self.mesh, PartitionSpec(None, )))
+         clamped_positions = lax.with_sharding_constraint(
+             clamped_positions, NamedSharding(self.mesh, PartitionSpec(None, )))
+         new_seq_lens = lax.with_sharding_constraint(
+             new_seq_lens, NamedSharding(self.mesh, PartitionSpec(None, )))
+         query_start_loc = lax.with_sharding_constraint(
+             query_start_loc, NamedSharding(self.mesh, PartitionSpec()))
+         new_block_tables = lax.with_sharding_constraint(
+             new_block_tables, NamedSharding(self.mesh, PartitionSpec(None, )))
+
+         return positions, clamped_positions, new_seq_lens, query_start_loc, new_block_tables
+
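A small numeric sketch of the clamping described in the NOTE above (invented values, not package code): a request that has hit the max model length keeps its batch slot, but its position is clamped to 0, its sequence length reset to 1, and its block-table entries set to -1, so no out-of-range access occurs; the draft tokens it produces are discarded later.

    import jax.numpy as jnp

    max_model_len = 8
    positions = jnp.array([3, 7]) + 1                     # -> [4, 8]
    exceeds = positions >= max_model_len                  # [False, True]
    clamped_positions = jnp.where(exceeds, 0, positions)  # [4, 0]
    new_seq_lens = jnp.minimum(jnp.array([4, 8]) + 1, max_model_len)
    new_seq_lens = jnp.where(exceeds, 1, new_seq_lens)    # [5, 1]
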
+     @functools.partial(jax.jit, static_argnums=(0, ))
+     def _stack_draft_token_ids(
+             self, draft_token_ids_list: list[jax.Array]) -> jnp.ndarray:
+         """JIT-compiled helper for stacking draft token IDs."""
+         return jnp.stack(draft_token_ids_list, axis=1)
+
+     @functools.partial(jax.jit, static_argnums=(0, ))
+     def _prepare_hidden_states_and_input_ids(
+         self,
+         state: nnx.State,
+         aux_hidden_states: tuple[jax.Array, ...],
+         query_start_loc: jax.Array,
+         target_token_ids: jax.Array,
+         next_token_ids: jax.Array,
+         num_reqs: jax.Array,
+     ) -> tuple[jax.Array, jax.Array, jax.Array]:
+         target_hidden_states = jnp.concatenate(aux_hidden_states, axis=-1)
+         target_hidden_states = self.combine_hidden_states_fn(
+             state, target_hidden_states)
+
+         input_ids, last_token_indices = self._prepare_input_ids(
+             query_start_loc, target_token_ids, next_token_ids, num_reqs)
+         # NOTE(pooyam): For now, we don't support multimodal.
+
+         return target_hidden_states, input_ids, last_token_indices
+
+     def prepare_inputs(
+         self,
+         attn_metadata: AttentionMetadata,
+         input_ids: jax.Array,
+         aux_hidden_states: tuple[jax.Array, ...],
+         next_token_ids: jax.Array,
+         num_rejected_tokens: Optional[jax.Array] = None,
+     ) -> tuple[jax.Array, jax.Array, jax.Array, AttentionMetadata]:
+         """Prepares drafter inputs based on the target model's forward outputs.
+
+         Mirrors the GPU reference logic, adapted to TPU/JAX types:
+         - When no rejection happened, select the first N scheduled tokens.
+         - When rejections happened, trim the per-request tail tokens and
+           update the attention metadata accordingly.
+         - Build the EAGLE3 hidden input by concatenating auxiliary hidden
+           states along the last dimension.
+
+         Returns updated AttentionMetadata (positions, query_start_loc,
+         seq_lens) along with the selected `target_token_ids` and
+         `target_hidden_states`.
+         """
+         assert aux_hidden_states is not None and len(aux_hidden_states) > 0, (
+             "EAGLE3 requires auxiliary hidden states from the target model.")
+
+         # The last KV cache group is for the draft model.
+         num_kv_cache_groups = len(self.runner.kv_cache_config.kv_cache_groups)
+         draft_kv_cache_group_id = num_kv_cache_groups - 1
+         block_tables = self.runner.input_batch.block_table[
+             draft_kv_cache_group_id].get_cpu_tensor().reshape(-1)
+         # Number of active requests in this step (un-padded count).
+         num_reqs = self.runner.input_batch.num_reqs
+
+         if num_rejected_tokens is None:
+             num_reqs = device_array(self.mesh,
+                                     np.asarray([num_reqs], dtype=jnp.int32))
+             attn_metadata = replace(attn_metadata,
+                                     block_tables=device_array(
+                                         self.mesh, block_tables))
+             target_hidden_states, input_ids, last_token_indices = self._prepare_hidden_states_and_input_ids(
+                 self.state, aux_hidden_states, attn_metadata.query_start_loc,
+                 input_ids, next_token_ids, num_reqs)
+             return target_hidden_states, input_ids, last_token_indices, attn_metadata
+
+         # Host copies from the metadata prepared by the runner.
+         query_start_loc_cpu = attn_metadata.query_start_loc_cpu
+         seq_lens_cpu = attn_metadata.seq_lens_cpu
+         assert query_start_loc_cpu is not None and seq_lens_cpu is not None
+
+         # Rejection-aware path: compute new per-request lengths and token
+         # indices. Convert to host numpy for efficient prefix-sum and
+         # repeat ops.
+         nrt_cpu = jax.device_get(num_rejected_tokens).astype("int32")
+
+         # query_len_per_req = [q1, q2, ...]
+         query_len_per_req = (query_start_loc_cpu[1:] -
+                              query_start_loc_cpu[:-1])
+
+         # query_start_loc_cpu, and consequently query_len_per_req, are
+         # padded; padded requests are assigned a query length of 1.
+         query_len_per_req[num_reqs:] = 1
+         # num_tokens_per_req = [q1 - n1, q2 - n2, ...]
+         num_tokens_per_req = (query_len_per_req - nrt_cpu)
+         assert (num_tokens_per_req >= 0).all(), (
+             "num_tokens_per_req must be non-negative")
+
+         # new_query_start_loc = [0, q1 - n1, q1 + q2 - n1 - n2, ...]
+         # Use numpy for cumsum and then convert back.
+         new_query_start_loc_cpu = np.zeros_like(query_start_loc_cpu)
+         np.cumsum(num_tokens_per_req, out=new_query_start_loc_cpu[1:])
+
+         # Build token indices selecting the kept tokens from each request.
+         total_num_tokens = int(new_query_start_loc_cpu[-1])
+
+         # Pad the total token count up to the next padding bucket.
+         padded_total_num_tokens = runner_utils.get_padded_token_len(
+             self.runner.num_tokens_paddings, total_num_tokens)
+         pad_width = padded_total_num_tokens - total_num_tokens
+         assert pad_width >= 0, (
+             f"total_num_tokens {total_num_tokens} exceeds "
+             f"num_tokens_paddings {self.runner.num_tokens_paddings}")
+
+         # Expand request starts: [0, 0, q1 - n1, ...]
+         expanded_new_query_start_loc = np.repeat(new_query_start_loc_cpu[:-1],
+                                                  num_tokens_per_req)
+         # Offsets within each request window: [0, 1, 2, 0, 1, 2, 3, ...]
+         token_offsets = np.arange(total_num_tokens, dtype=np.int32)
+         token_offsets -= expanded_new_query_start_loc
+         # Map into old flat indices by adding original request starts.
+         old_query_start_loc_expanded = np.repeat(query_start_loc_cpu[:-1],
+                                                  num_tokens_per_req)
+
+         token_indices_cpu = token_offsets + old_query_start_loc_expanded
+         token_indices_cpu = np.pad(token_indices_cpu, (0, pad_width),
+                                    "constant",
+                                    constant_values=0)
+         # Update seq_lens for active requests only: new_seq_lens = s - n.
+         new_seq_lens_cpu = seq_lens_cpu - nrt_cpu
+
+         query_start_loc, seq_lens, token_indices, num_reqs, block_tables = device_array(
+             self.mesh,
+             (new_query_start_loc_cpu, new_seq_lens_cpu, token_indices_cpu,
+              np.asarray([num_reqs], dtype=jnp.int32), block_tables))
+
+         attn_metadata = replace(attn_metadata, block_tables=block_tables)
+         return self._filter_token_and_prepare_initial_inputs(
+             self.state, token_indices, query_start_loc, seq_lens, input_ids,
+             aux_hidden_states, attn_metadata, next_token_ids, num_reqs)
+
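The cumsum/repeat construction of `token_indices_cpu` above is easiest to see with concrete numbers. A hedged NumPy walkthrough, assuming two requests with 3 and 4 scheduled tokens of which 1 and 2 are rejected (so the kept tokens are the first 2 of each request):

    import numpy as np

    query_start_loc = np.array([0, 3, 7])            # query lengths q = [3, 4]
    num_rejected = np.array([1, 2], dtype=np.int32)

    query_len = query_start_loc[1:] - query_start_loc[:-1]    # [3, 4]
    num_tokens = query_len - num_rejected                     # [2, 2] kept
    new_start = np.zeros_like(query_start_loc)
    np.cumsum(num_tokens, out=new_start[1:])                  # [0, 2, 4]

    total = int(new_start[-1])
    # Offsets within each new request window, then shift by old starts.
    offsets = np.arange(total) - np.repeat(new_start[:-1], num_tokens)
    token_indices = offsets + np.repeat(query_start_loc[:-1], num_tokens)
    # token_indices == [0, 1, 3, 4]: the surviving prefix of each request.
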
+     @functools.partial(jax.jit, static_argnums=(0, ))
+     def _filter_token_and_prepare_initial_inputs(
+         self,
+         state: nnx.State,
+         token_indices: jax.Array,
+         query_start_loc: jax.Array,
+         seq_lens: jax.Array,
+         input_ids: jax.Array,
+         aux_hidden_states: tuple[jax.Array, ...],
+         attn_metadata: AttentionMetadata,
+         next_token_ids: jax.Array,
+         num_reqs: jax.Array,
+     ) -> tuple[jax.Array, jax.Array, jax.Array, AttentionMetadata]:
+
+         # Select tokens and hidden states.
+         target_token_ids = input_ids[token_indices]
+         # Update positions to match the selected tokens.
+         if attn_metadata.input_positions.ndim == 2:
+             input_positions = attn_metadata.input_positions[:, token_indices]
+         else:
+             input_positions = attn_metadata.input_positions[token_indices]
+
+         attn_metadata = AttentionMetadata(
+             input_positions=input_positions,
+             block_tables=attn_metadata.block_tables,
+             seq_lens=seq_lens,
+             query_start_loc=query_start_loc,
+             request_distribution=attn_metadata.request_distribution,
+         )
+
+         target_hidden_states, input_ids, last_token_indices = self._prepare_hidden_states_and_input_ids(
+             state, [h[token_indices] for h in aux_hidden_states],
+             query_start_loc, target_token_ids, next_token_ids, num_reqs)
+
+         return target_hidden_states, input_ids, last_token_indices, attn_metadata
+
+     @functools.partial(jax.jit, static_argnums=(0, ))
+     def _select_draft_token_ids(
+         self,
+         state: nnx.State,
+         hidden_states: jax.Array,
+         last_token_indices: jax.Array,
+     ) -> jax.Array:
+         sample_hidden_states = hidden_states[last_token_indices]
+         sample_hidden_states = lax.with_sharding_constraint(
+             sample_hidden_states,
+             NamedSharding(self.mesh, PartitionSpec(None, None)))
+         return self._get_draft_token_ids(state, sample_hidden_states)
+
+     @functools.partial(jax.jit, static_argnums=(0, ))
+     def _get_draft_token_ids(self, state: nnx.State,
+                              hidden_states: jax.Array) -> jax.Array:
+         lora_metadata = None
+         logits = self.compute_logits_fn(state, hidden_states, lora_metadata)
+         draft_token_ids = jnp.argmax(logits, axis=-1)
+         return lax.with_sharding_constraint(
+             draft_token_ids, NamedSharding(self.mesh, PartitionSpec()))
+
+     @functools.partial(jax.jit, static_argnums=(0, ))
+     def _select_inputs_for_loop_speculation(
+         self, state: nnx.State, positions: jax.Array, residual: jax.Array,
+         hidden_states: jax.Array, last_token_indices: jax.Array
+     ) -> tuple[jax.Array, jax.Array, jax.Array]:
+         positions = positions[last_token_indices]
+         residual = residual[last_token_indices]
+         draft_token_ids = self._select_draft_token_ids(state, hidden_states,
+                                                        last_token_indices)
+
+         positions = lax.with_sharding_constraint(
+             positions, NamedSharding(self.mesh, PartitionSpec(None, )))
+         residual = lax.with_sharding_constraint(
+             residual, NamedSharding(self.mesh, PartitionSpec(None, None)))
+         draft_token_ids = lax.with_sharding_constraint(
+             draft_token_ids, NamedSharding(self.mesh, PartitionSpec()))
+
+         return positions, residual, draft_token_ids
+
+     def propose(
+         self,
+         kv_caches: list[jax.Array],
+         input_ids: jax.Array,
+         attn_metadata: AttentionMetadata,
+         last_token_indices: jax.Array,
+         target_hidden_states: jax.Array,
+     ) -> tuple[list[jax.Array], jnp.ndarray]:
+         """Proposes draft tokens using the draft model.
+
+         Returns:
+             A tuple containing the updated KV caches and a tensor of proposed
+             draft token IDs.
+         """
+
+         # input_ids and target_hidden_states for the first speculation step
+         # were already prepared in prepare_inputs() to improve performance.
+         kv_caches, hidden_states, residual = self.model_fn(
+             self.state,
+             kv_caches,
+             input_ids,
+             target_hidden_states,
+             attn_metadata,
+         )
+
+         if self.num_speculative_tokens == 1:
+             return kv_caches, self._select_draft_token_ids(
+                 self.state, hidden_states, last_token_indices)
+
+         positions, hidden_states, draft_token_ids = self._select_inputs_for_loop_speculation(
+             self.state, attn_metadata.input_positions, residual[0],
+             hidden_states, last_token_indices)
+
+         draft_token_ids_list = [draft_token_ids]
+
+         for _ in range(self.num_speculative_tokens - 1):
+             input_ids_loop = draft_token_ids_list[-1]
+
+             positions, clamped_positions, new_seq_lens, query_start_loc, new_block_tables = self._update_inputs_for_loop_speculation(
+                 positions, attn_metadata.seq_lens, attn_metadata.block_tables)
+
+             attn_metadata = replace(
+                 attn_metadata,
+                 input_positions=clamped_positions,
+                 seq_lens=new_seq_lens,
+                 query_start_loc=query_start_loc,
+                 block_tables=new_block_tables,
+             )
+             kv_caches, new_hidden_states, residual = self.model_fn(
+                 self.state,
+                 kv_caches,
+                 input_ids_loop,
+                 hidden_states,  # The hidden states from the previous step.
+                 attn_metadata,
+             )
+             hidden_states = residual[0]
+             draft_token_ids = self._get_draft_token_ids(
+                 self.state, new_hidden_states)
+             draft_token_ids_list.append(draft_token_ids)
+
+         # [batch_size, num_speculative_tokens]
+         draft_token_ids = self._stack_draft_token_ids(draft_token_ids_list)
+
+         return kv_caches, draft_token_ids
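
End to end, `propose()` runs the draft model once on the inputs built by `prepare_inputs()`, then `num_speculative_tokens - 1` more times, feeding each step's greedy pick back in as the next input. The stacked result has one row of draft tokens per request. A toy sketch of the final stacking, with invented values:

    import jax.numpy as jnp

    # Greedy picks from k = 3 draft steps for 2 requests.
    steps = [jnp.array([7, 8]), jnp.array([4, 5]), jnp.array([1, 2])]
    draft_token_ids = jnp.stack(steps, axis=1)
    # shape [2, 3]: row i holds request i's speculative tokens in order.
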
tpu_inference/tpu_info.py
@@ -0,0 +1,92 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import glob
+ import os
+ from typing import Optional
+
+ import requests
+
+ from tpu_inference import envs
+ from tpu_inference.logger import init_logger
+
+ logger = init_logger(__name__)
+
+ GCE_TPU_ACCELERATOR_ENDPOINT = (
+     "http://metadata.google.internal/computeMetadata/v1/instance/attributes/")
+ GCE_TPU_HEADERS = {"Metadata-Flavor": "Google"}
+
+
+ def get_tpu_metadata(key: str = "") -> Optional[str]:
+     try:
+         accelerator_type_request = requests.get(
+             os.path.join(GCE_TPU_ACCELERATOR_ENDPOINT, key),
+             headers=GCE_TPU_HEADERS,
+         )
+         if (accelerator_type_request.status_code == 200
+                 and accelerator_type_request.text):
+             return accelerator_type_request.text
+         else:
+             logger.error(
+                 "Unable to poll TPU GCE Metadata. Got "
+                 f"status code: {accelerator_type_request.status_code} and "
+                 f"content: {accelerator_type_request.text}")
+     except requests.RequestException as e:
+         logger.error("Unable to poll the TPU GCE Metadata: %s", e)
+     return None
+
+
+ def get_tpu_type() -> Optional[str]:
+     tpu_type = envs.TPU_ACCELERATOR_TYPE
+     if tpu_type is None:
+         tpu_type = get_tpu_metadata(key="accelerator-type")
+     return tpu_type
+
+
+ def get_node_name() -> Optional[str]:
+     tpu_name = envs.TPU_NAME
+     if not tpu_name:
+         tpu_name = get_tpu_metadata(key="instance-id")
+     return tpu_name
+
+
+ def get_node_worker_id() -> int:
+     """For a multi-host TPU VM, returns the worker id of the current node."""
+     worker_id = envs.TPU_WORKER_ID
+     if worker_id is None:
+         worker_id = get_tpu_metadata(key="agent-worker-number")
+     if worker_id is None:
+         return 0
+     return int(worker_id)
+
+
+ def get_num_cores_per_chip() -> int:
+     tpu_type = get_tpu_type()
+     if tpu_type and tpu_type.startswith(("v5litepod", "v6e")):
+         return 1
+     return 2
+
+
+ def get_num_chips() -> int:
+     accel_files = glob.glob("/dev/accel*")
+     if accel_files:
+         return len(accel_files)
+     try:
+         vfio_entries = os.listdir("/dev/vfio")
+         numeric_entries = [
+             int(entry) for entry in vfio_entries if entry.isdigit()
+         ]
+         return len(numeric_entries)
+     except FileNotFoundError as e:
+         logger.error("Failed to detect number of TPUs: %s", e)
+         return 0
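
A hedged usage sketch of these helpers. It assumes a GCE TPU VM, where `metadata.google.internal` is reachable; off GCE, the metadata calls log an error and the getters fall back to environment variables or `None`:

    from tpu_inference import tpu_info

    tpu_type = tpu_info.get_tpu_type()     # e.g. "v6e-8", from env or metadata
    num_chips = tpu_info.get_num_chips()   # counts /dev/accel* (or /dev/vfio) entries
    if tpu_type is not None:
        total_cores = num_chips * tpu_info.get_num_cores_per_chip()
        print(f"{tpu_type}: {num_chips} chips, {total_cores} cores on this host")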