tpu-inference 0.11.1.dev202511270815__py3-none-any.whl → 0.13.0rc2.post7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (251)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +405 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +312 -0
  64. tests/layers/vllm/test_unquantized.py +651 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +21 -3
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +110 -12
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +2 -45
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +15 -10
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +92 -8
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +22 -1
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +167 -97
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +26 -19
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/quant_methods.py +15 -0
  154. tpu_inference/layers/common/quantization.py +270 -0
  155. tpu_inference/layers/common/sharding.py +31 -9
  156. tpu_inference/layers/jax/__init__.py +13 -0
  157. tpu_inference/layers/jax/attention/__init__.py +13 -0
  158. tpu_inference/layers/jax/attention/attention.py +19 -6
  159. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  160. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  161. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  162. tpu_inference/layers/jax/base.py +14 -0
  163. tpu_inference/layers/jax/constants.py +13 -0
  164. tpu_inference/layers/jax/layers.py +14 -0
  165. tpu_inference/layers/jax/misc.py +14 -0
  166. tpu_inference/layers/jax/moe/__init__.py +13 -0
  167. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  168. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  169. tpu_inference/layers/jax/moe/moe.py +43 -3
  170. tpu_inference/layers/jax/pp_utils.py +53 -0
  171. tpu_inference/layers/jax/rope.py +14 -0
  172. tpu_inference/layers/jax/rope_interface.py +14 -0
  173. tpu_inference/layers/jax/sample/__init__.py +13 -0
  174. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  175. tpu_inference/layers/jax/sample/sampling.py +15 -1
  176. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  177. tpu_inference/layers/jax/transformer_block.py +14 -0
  178. tpu_inference/layers/vllm/__init__.py +13 -0
  179. tpu_inference/layers/vllm/attention.py +4 -4
  180. tpu_inference/layers/vllm/fused_moe.py +210 -260
  181. tpu_inference/layers/vllm/linear_common.py +57 -22
  182. tpu_inference/layers/vllm/quantization/__init__.py +16 -0
  183. tpu_inference/layers/vllm/quantization/awq.py +15 -1
  184. tpu_inference/layers/vllm/quantization/common.py +33 -18
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
  188. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  189. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
  191. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  192. tpu_inference/layers/vllm/quantization/mxfp4.py +280 -210
  193. tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
  194. tpu_inference/layers/vllm/sharding.py +21 -4
  195. tpu_inference/lora/__init__.py +13 -0
  196. tpu_inference/lora/torch_lora_ops.py +8 -13
  197. tpu_inference/models/__init__.py +13 -0
  198. tpu_inference/models/common/__init__.py +13 -0
  199. tpu_inference/models/common/model_loader.py +77 -36
  200. tpu_inference/models/jax/__init__.py +13 -0
  201. tpu_inference/models/jax/deepseek_v3.py +267 -157
  202. tpu_inference/models/jax/gpt_oss.py +26 -10
  203. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  204. tpu_inference/models/jax/llama3.py +99 -36
  205. tpu_inference/models/jax/llama4.py +14 -0
  206. tpu_inference/models/jax/llama_eagle3.py +14 -0
  207. tpu_inference/models/jax/llama_guard_4.py +15 -1
  208. tpu_inference/models/jax/qwen2.py +17 -2
  209. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  210. tpu_inference/models/jax/qwen3.py +17 -2
  211. tpu_inference/models/jax/utils/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/file_utils.py +14 -0
  213. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  214. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  215. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +91 -31
  216. tpu_inference/models/jax/utils/weight_utils.py +39 -2
  217. tpu_inference/models/vllm/__init__.py +13 -0
  218. tpu_inference/models/vllm/vllm_model_wrapper.py +20 -4
  219. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  220. tpu_inference/platforms/__init__.py +14 -0
  221. tpu_inference/platforms/tpu_platform.py +47 -71
  222. tpu_inference/runner/__init__.py +13 -0
  223. tpu_inference/runner/compilation_manager.py +158 -63
  224. tpu_inference/runner/kv_cache.py +54 -20
  225. tpu_inference/runner/kv_cache_manager.py +53 -30
  226. tpu_inference/runner/lora_utils.py +14 -0
  227. tpu_inference/runner/multimodal_manager.py +15 -1
  228. tpu_inference/runner/persistent_batch_manager.py +54 -2
  229. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  230. tpu_inference/runner/structured_decoding_manager.py +14 -0
  231. tpu_inference/runner/tpu_runner.py +105 -57
  232. tpu_inference/runner/utils.py +2 -2
  233. tpu_inference/spec_decode/__init__.py +13 -0
  234. tpu_inference/spec_decode/jax/__init__.py +13 -0
  235. tpu_inference/spec_decode/jax/eagle3.py +65 -19
  236. tpu_inference/tpu_info.py +14 -0
  237. tpu_inference/utils.py +72 -44
  238. tpu_inference/worker/__init__.py +13 -0
  239. tpu_inference/worker/tpu_worker.py +65 -52
  240. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/METADATA +11 -9
  241. tpu_inference-0.13.0rc2.post7.dist-info/RECORD +261 -0
  242. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  243. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  244. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  245. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  246. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  247. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  248. tpu_inference-0.11.1.dev202511270815.dist-info/RECORD +0 -174
  249. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/WHEEL +0 -0
  250. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/licenses/LICENSE +0 -0
  251. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/top_level.txt +0 -0
tpu_inference/runner/tpu_runner.py CHANGED
@@ -1,6 +1,20 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import copy
 import functools
-import os
+import logging
 import random
 from contextlib import nullcontext
 from dataclasses import dataclass
@@ -10,17 +24,15 @@ import jax
 import jax.numpy as jnp
 import jaxtyping
 import numpy as np
-import torch
-import vllm.envs as envs
+import vllm.envs as vllm_envs
 from flax import nnx
 from jax.experimental import mesh_utils
 from jax.sharding import NamedSharding, PartitionSpec
-from torchax.ops.mappings import j2t_dtype
 from vllm.config import VllmConfig
+from vllm.distributed import get_pp_group
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group)
 from vllm.forward_context import set_forward_context
-from vllm.sequence import IntermediateTensors
 from vllm.tasks import SupportedTask
 from vllm.utils.math_utils import cdiv
 from vllm.v1.core.sched.output import GrammarOutput
@@ -35,6 +47,7 @@ from vllm.v1.worker.kv_connector_model_runner_mixin import \
     KVConnectorModelRunnerMixin
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 
+import tpu_inference.envs as envs
 from tpu_inference import utils as common_utils
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
 from tpu_inference.layers.common.sharding import (MESH_AXIS_NAMES,
@@ -48,6 +61,8 @@ from tpu_inference.layers.jax.sample.sampling_metadata import \
     TPUSupportedSamplingMetadata
 from tpu_inference.logger import init_logger
 from tpu_inference.models.common.model_loader import get_model
+from tpu_inference.models.jax.jax_intermediate_tensor import \
+    JaxIntermediateTensors
 from tpu_inference.models.jax.utils.weight_utils import (
     shard_put, transfer_state_with_mappings)
 from tpu_inference.runner import utils as runner_utils
@@ -64,10 +79,12 @@ from tpu_inference.runner.structured_decoding_manager import \
     StructuredDecodingManager
 from tpu_inference.spec_decode.jax.eagle3 import Eagle3Proposer
 from tpu_inference.utils import (device_array, make_optimized_mesh,
-                                 time_function)
+                                 time_function, to_jax_dtype, to_torch_dtype)
 
 logger = init_logger(__name__)
 
+logging.getLogger("torchax.tensor").setLevel(logging.ERROR)
+
 INVALID_TOKEN_ID = -1
 # Smallest output size
 MIN_NUM_SEQS = 8
@@ -78,17 +95,6 @@ DUMMY_METADATA = AttentionMetadata(
     request_distribution=[0, 0, 0],
 )
 
-TPU_STR_DTYPE_TO_TORCH_DTYPE = {
-    "half": torch.half,
-    "bfloat16": torch.bfloat16,
-    "float": torch.float,
-    "fp8": torch.float8_e4m3fn,
-    "fp8_e4m3": torch.float8_e4m3fn,
-    "fp8_e5m2": torch.float8_e5m2,
-    "int8": torch.int8,
-    "uint8": torch.uint8,
-}
-
 
 class AsyncTPUModelRunnerOutput(AsyncModelRunnerOutput):
     """Holds asynchronous model output specifically from a TPU runner.
@@ -243,6 +249,9 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         self.maybe_forbid_compile = runner_utils.ForbidCompile(
         ) if envs.VLLM_XLA_CHECK_RECOMPILATION else nullcontext()
         self.dp_size = self.vllm_config.sharding_config.total_dp_size
+        self.rank = rank
+        self.is_first_rank = is_first_rank
+        self.is_last_rank = is_last_rank
 
         self._init_random()
         self._init_mesh()
@@ -253,31 +262,21 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
 
         # Delegate functions to specific manager classes.
         self.compilation_manager = CompilationManager(self)
-        self.speculative_decoding_manager = SpeculativeDecodingManager(self)
-        self.structured_decoding_manager = StructuredDecodingManager(self)
+        if self.is_last_rank:
+            self.speculative_decoding_manager = SpeculativeDecodingManager(
+                self)
+            self.structured_decoding_manager = StructuredDecodingManager(self)
         self.kv_cache_manager = KVCacheManager(self)
         self.mm_manager = MultiModalManager(self)
         self.persistent_batch_manager = PersistentBatchManager(
            self.requests, self.input_batch, self.encoder_cache,
-            self.uses_mrope, self.model_config)
+            self.uses_mrope, self.model_config, self.is_last_rank)
         self.lora_utils = LoraUtils(self)
 
-        cache_config = self.cache_config
-        if cache_config.cache_dtype == "auto":
-            model_dtype = self.dtype
-            if isinstance(model_dtype, str):
-                self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
-            elif isinstance(getattr(model_dtype, 'dtype', None), jnp.dtype):
-                self.kv_cache_dtype = j2t_dtype(model_dtype.dtype)
-            elif isinstance(model_dtype, torch.dtype):
-                self.kv_cache_dtype = model_dtype
-            else:
-                raise ValueError(
-                    "KV cache is unsupported for model_dtype of %s",
-                    model_dtype)
-        else:
-            self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[
-                cache_config.cache_dtype]
+        cache_dtype = self.cache_config.cache_dtype
+        if cache_dtype == "auto":
+            cache_dtype = self.dtype
+        self.kv_cache_dtype = to_torch_dtype(cache_dtype)
 
         self._pre_async_results: AsyncPreResults | None = None
         self._substitute_placeholder_token_fn = _substitute_placeholder_token
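
Note on the hunk above: the hand-rolled kv-cache dtype branching (string table, j2t_dtype for jnp dtypes, torch passthrough) collapses into a single to_torch_dtype call imported from tpu_inference.utils. The diff does not show that helper, so the following is only a hypothetical sketch of what it presumably consolidates, reusing the removed mapping; the real helper (tpu_inference/utils.py, +72 -44 in this diff) may differ.

    # Hypothetical reconstruction of a to_torch_dtype-style normalizer.
    import torch
    from torchax.ops.mappings import j2t_dtype  # mapping used by the old code

    _STR_TO_TORCH = {
        "half": torch.half,
        "bfloat16": torch.bfloat16,
        "float": torch.float,
        "fp8": torch.float8_e4m3fn,
        "fp8_e4m3": torch.float8_e4m3fn,
        "fp8_e5m2": torch.float8_e5m2,
        "int8": torch.int8,
        "uint8": torch.uint8,
    }

    def to_torch_dtype(dtype) -> torch.dtype:
        """Normalize a str / jnp dtype / torch dtype to a torch.dtype."""
        if isinstance(dtype, torch.dtype):
            return dtype
        if isinstance(dtype, str):
            return _STR_TO_TORCH[dtype]
        # jnp dtypes, or objects exposing one via a .dtype attribute.
        return j2t_dtype(getattr(dtype, "dtype", dtype))
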
@@ -291,7 +290,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         self.rng_key = jax.random.key(self.model_config.seed)
 
     def _init_mesh(self) -> None:
-        if os.getenv("NEW_MODEL_DESIGN", False):
+        if envs.NEW_MODEL_DESIGN:
             self.mesh = self._create_new_model_mesh()
         else:
             # NOTE(wenxindongwork): The new MoE kernel expects a 2D mesh, so we need
@@ -302,7 +301,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         logger.info(f"Init mesh | mesh={self.mesh}")
 
     def _create_new_model_mesh(self) -> jax.sharding.Mesh:
-        num_slices = int(os.environ.get('NUM_SLICES', 1))
+        num_slices = envs.NUM_SLICES
 
         logger.info(f"Creating new model mesh | devices={len(self.devices)}, "
                     f"num_slices={num_slices}")
@@ -371,7 +370,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                                devices=self.devices)
 
     def _init_phased_profiling(self) -> None:
-        self.phased_profiling_dir = os.getenv("PHASED_PROFILING_DIR", "")
+        self.phased_profiling_dir = envs.PHASED_PROFILING_DIR
         self.phase_based_profiler = None
         if self.phased_profiling_dir:
             self.phase_based_profiler = runner_utils.PhasedBasedProfiler(
@@ -413,7 +412,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             min_token_size=max(16, self.dp_size),
             max_token_size=scheduler_config.max_num_batched_tokens *
             self.dp_size,
-            padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP)
+            padding_gap=vllm_envs.VLLM_TPU_BUCKET_PADDING_GAP)
         self.num_tokens_paddings_per_dp = [
             padding // self.dp_size for padding in self.num_tokens_paddings
         ]
@@ -509,10 +508,10 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         multimodal_fns = multimodal_fns or {}
         self.precompile_vision_encoder_fn = multimodal_fns.get(
             "precompile_vision_encoder_fn", None)
-        self.get_multimodal_embeddings_fn = multimodal_fns.get(
-            "get_multimodal_embeddings_fn", None)
-        self.get_input_embeddings_fn = multimodal_fns.get(
-            "get_input_embeddings_fn", None)
+        self.embed_multimodal_fn = multimodal_fns.get("embed_multimodal_fn",
+                                                      None)
+        self.embed_input_ids_fn = multimodal_fns.get("embed_input_ids_fn",
+                                                     None)
         self.get_mrope_input_positions_fn = multimodal_fns.get(
             "get_mrope_input_positions_fn", None)
 
@@ -524,7 +523,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             jax.random.key(self.model_config.seed)).params()
         self.is_multimodal_model = (
             self.model_config.is_multimodal_model
-            and self.get_multimodal_embeddings_fn is not None and hasattr(
+            and self.embed_multimodal_fn is not None and hasattr(
                 self.model_config.hf_config, "architectures"
             )  #TODO: Remove Llama Guard 4 specific condition once the LG4 Vision portion is implemented
             and len(self.model_config.hf_config.architectures) >= 1
@@ -540,7 +539,10 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
     def get_kv_cache_spec(self):
         return self.kv_cache_manager.get_kv_cache_spec()
 
-    def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
+    def initialize_kv_cache(self,
+                            kv_cache_config: KVCacheConfig,
+                            topology_order_id: int = 0) -> None:
+        self.topology_order_id = topology_order_id
         self.kv_cache_config = kv_cache_config
         self.use_hybrid_kvcache = len(kv_cache_config.kv_cache_groups) > 1
         self.kv_caches = []
@@ -555,12 +557,12 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
     def execute_model(
         self,
         scheduler_output: "VllmSchedulerOutput",
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-    ) -> ModelRunnerOutput | None:
+        intermediate_tensors: Optional[JaxIntermediateTensors] = None,
+    ) -> ModelRunnerOutput | JaxIntermediateTensors | None:
         if self.execute_model_state is not None:
             raise RuntimeError("State error: sample_tokens() must be called "
                                "after execute_model() returns None.")
-        _, output = self._execute_model(scheduler_output)
+        _, output = self._execute_model(scheduler_output, intermediate_tensors)
         return output
 
     def sample_tokens(
@@ -686,7 +688,9 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
     def _execute_model(
         self,
         scheduler_output: "VllmSchedulerOutput",
-    ) -> tuple[AttentionMetadata, ModelRunnerOutput | None]:
+        intermediate_tensors: Optional[JaxIntermediateTensors] = None,
+    ) -> tuple[AttentionMetadata, JaxIntermediateTensors | ModelRunnerOutput
+               | None]:
         self.persistent_batch_manager.update_states(
             scheduler_output, self.get_mrope_input_positions_fn)
         if not scheduler_output.total_num_scheduled_tokens:
@@ -764,7 +768,6 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                 scheduler_output) as kv_connector_output:
             # NOTE(Wenlong): It takes both `input_ids` and `inputs_embeds`,
             # but one of them would be `None`
-
             (self.kv_caches, hidden_states,
              aux_hidden_states) = self.model_fn(
                  self.state,
@@ -775,8 +778,14 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                  input_positions,
                  tuple(self.layer_name_to_kvcache_index.items()),
                  lora_metadata,
+                 intermediate_tensors,
+                 self.is_first_rank,
+                 self.is_last_rank,
              )
-
+        if not get_pp_group().is_last_rank:
+            assert isinstance(hidden_states, JaxIntermediateTensors)
+            hidden_states.kv_connector_output = kv_connector_output
+            return attn_metadata, hidden_states
         hidden_states = self._select_from_array_fn(hidden_states,
                                                    logits_indices)
         logits = self.compute_logits_fn(
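
With these hunks, execute_model can return either sampled output (last pipeline-parallel rank) or a JaxIntermediateTensors bundle of hidden_states/residual for the next stage. An in-process toy of that staged dataflow, with illustrative names only (no real transport and no tpu-inference APIs):

    import jax.numpy as jnp

    def first_stage(tokens):
        # Produces the intermediate bundle a non-last rank would return.
        hidden = jnp.ones((tokens.shape[0], 4)) * tokens[:, None].astype(jnp.float32)
        return {"hidden_states": hidden, "residual": hidden}

    def last_stage(intermediate):
        # Consumes the bundle and produces the final per-token output.
        h = intermediate["hidden_states"] + intermediate["residual"]
        return h.sum(axis=-1)  # stands in for logits/sampling

    print(last_stage(first_stage(jnp.arange(3))))
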
@@ -818,7 +827,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         sharding = None
         if self.dp_size > 1:
             sharding = NamedSharding(self.mesh,
-                                     PartitionSpec(ShardingAxisName.ATTN_DATA))
+                                     PartitionSpec(ShardingAxisName.MLP_DATA))
 
         tpu_sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch(
             self.mesh, self.input_batch, padded_num_reqs, sharding=sharding)
@@ -1345,7 +1354,14 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         _request_distribution = []
         for dp_rank in range(dp_size):
             _num_reqs = num_req_per_dp_rank[dp_rank]
-            _request_distribution.append([0, 0, _num_reqs])
+            # The batch has been reordered by _reorder_batch so decode requests come first
+            # Count decode requests (those with num_scheduled_tokens == 1) in this DP rank
+            num_decode_in_dp_rank = 0
+            for req_id in req_ids_dp[dp_rank]:
+                if scheduler_output.num_scheduled_tokens[req_id] == 1:
+                    num_decode_in_dp_rank += 1
+            _request_distribution.append(
+                [num_decode_in_dp_rank, num_decode_in_dp_rank, _num_reqs])
         request_distribution = np.array(_request_distribution).ravel()
 
         use_spec_decode = len(
@@ -1374,7 +1390,8 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             self.mesh,
             self.input_batch,
             padded_num_reqs,
-            sharding=data_parallel_attn_sharding,
+            sharding=NamedSharding(self.mesh,
+                                   PartitionSpec(ShardingAxisName.MLP_DATA)),
         )
         if self.uses_mrope:
             positions = mrope_positions
@@ -1404,7 +1421,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                 block_tables[
                     req_offset:req_offset + _num_reqs, :self.
                     max_num_blocks_per_req] = self.input_batch.block_table[
-                        0].get_cpu_tensor()[req_indices_dp[dp_rank]]
+                        kv_cache_gid].get_cpu_tensor()[req_indices_dp[dp_rank]]
             # Convert block_tables to 1D on cpu.
             block_tables = block_tables.reshape(-1)
             block_tables = device_array(
@@ -1664,7 +1681,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
     def _get_input_ids_embeds(self, input_ids: jax.Array,
                               mm_embeds: list[jax.Array]):
         if self.is_multimodal_model:
-            inputs_embeds = self.get_input_embeddings_fn(
+            inputs_embeds = self.embed_input_ids_fn(
                 self.state,
                 input_ids,
                 mm_embeds,
@@ -1719,3 +1736,34 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             mappings=mappings,
             transpose_keys=transpose_keys,
             shard=shard)
+
+    def get_intermediate_tensor_spec(self, num_tokens: int):
+        jax_dtype = to_jax_dtype(self.dtype)
+        num_padded_tokens = runner_utils.get_padded_token_len(
+            self.num_tokens_paddings, num_tokens)
+        sharding = NamedSharding(self.mesh, PartitionSpec())
+        hidden_size = self.model_config.get_hidden_size()
+        spec = jax.ShapeDtypeStruct(shape=(num_padded_tokens, hidden_size),
+                                    dtype=jax_dtype,
+                                    sharding=sharding)
+        tensor_spec = {"hidden_states": spec, "residual": spec}
+        return tensor_spec
+
+    def get_uuid_for_jax_transfer(self,
+                                  scheduler_output: "VllmSchedulerOutput",
+                                  rank: int, step: int) -> int:
+        '''
+        Get a uuid for jax.transfer, here we use the hash of
+        scheduler_output + counter_step + sender's rank
+        '''
+        scheduler_output_str = ""
+        if not scheduler_output.num_scheduled_tokens:
+            scheduler_output_str = "empty_batch"
+        else:
+            scheduler_output_str = str(
+                sorted(scheduler_output.num_scheduled_tokens.items()))
+        unique_str = f'{scheduler_output_str} {step} {rank}'
+        import hashlib
+        hasher = hashlib.sha1()
+        hasher.update(unique_str.encode('utf-8'))
+        return int.from_bytes(hasher.digest()[:8], 'big')
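
The new get_uuid_for_jax_transfer above is deterministic on both ends of a transfer: sender and receiver hash the same scheduler output, step counter, and sender rank, so they agree on the 64-bit id without exchanging it. A standalone restatement of that derivation:

    import hashlib

    def uuid_for_transfer(num_scheduled_tokens: dict, step: int, rank: int) -> int:
        # Sorting the items makes the id independent of dict ordering.
        desc = ("empty_batch" if not num_scheduled_tokens else
                str(sorted(num_scheduled_tokens.items())))
        digest = hashlib.sha1(f"{desc} {step} {rank}".encode("utf-8")).digest()
        return int.from_bytes(digest[:8], "big")

    # Both peers derive the same id from the same scheduling state:
    assert (uuid_for_transfer({"req-0": 17, "req-1": 1}, step=3, rank=0) ==
            uuid_for_transfer({"req-1": 1, "req-0": 17}, step=3, rank=0))
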
tpu_inference/runner/utils.py CHANGED
@@ -15,6 +15,7 @@ import jax
 from jax._src.interpreters import pxla
 from vllm.v1.core.sched.output import SchedulerOutput as VllmSchedulerOutput
 
+from tpu_inference import envs
 from tpu_inference.logger import init_logger
 from tpu_inference.runner.input_batch import InputBatch
 
@@ -306,8 +307,7 @@ class PhasedBasedProfiler:
             InferencePhase.BALANCED: False
         }
         self.default_profiling_options = jax.profiler.ProfileOptions()
-        self.default_profiling_options.python_tracer_level = os.getenv(
-            "PYTHON_TRACER_LEVEL", 0)
+        self.default_profiling_options.python_tracer_level = envs.PYTHON_TRACER_LEVEL
 
         self.current_phase: str = ""
 
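These hunks are part of a broader move from scattered os.getenv calls to a typed tpu_inference.envs module (envs.py grows by +92 -8 in this release). That module is not shown here, so the following is only a plausible sketch of the pattern, with defaults taken from the call sites deleted in this diff; the real module may use lazy lookup.

    # Assumed shape of tpu_inference/envs.py: one typed accessor per variable.
    import os

    NEW_MODEL_DESIGN: bool = os.getenv("NEW_MODEL_DESIGN", "") not in ("", "0", "false")
    NUM_SLICES: int = int(os.getenv("NUM_SLICES", "1"))
    PHASED_PROFILING_DIR: str = os.getenv("PHASED_PROFILING_DIR", "")
    PYTHON_TRACER_LEVEL: int = int(os.getenv("PYTHON_TRACER_LEVEL", "0"))

One benefit is visible in the PhasedBasedProfiler hunk: python_tracer_level now receives an int, whereas the old os.getenv("PYTHON_TRACER_LEVEL", 0) returned a str whenever the variable was actually set.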
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
tpu_inference/spec_decode/jax/eagle3.py CHANGED
@@ -1,3 +1,16 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 """Implements the Eagle3 proposer for speculative decoding on JAX/TPU."""
 import functools
 from dataclasses import replace
@@ -6,6 +19,9 @@ from typing import Any, Optional
 import jax
 import jax.numpy as jnp
 import numpy as np
+from flax import nnx
+from jax import lax
+from jax.sharding import NamedSharding, PartitionSpec
 from vllm.config import VllmConfig
 
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
@@ -127,6 +143,17 @@ class Eagle3Proposer:
                                  max_num_blocks_per_req)
         new_block_tables = jnp.where(expanded_exceeds_mask, -1, block_tables)
 
+        positions = lax.with_sharding_constraint(
+            positions, NamedSharding(self.mesh, PartitionSpec(None, )))
+        clamped_positions = lax.with_sharding_constraint(
+            clamped_positions, NamedSharding(self.mesh, PartitionSpec(None, )))
+        new_seq_lens = lax.with_sharding_constraint(
+            new_seq_lens, NamedSharding(self.mesh, PartitionSpec(None, )))
+        query_start_loc = lax.with_sharding_constraint(
+            query_start_loc, NamedSharding(self.mesh, PartitionSpec()))
+        new_block_tables = lax.with_sharding_constraint(
+            new_block_tables, NamedSharding(self.mesh, PartitionSpec(None, )))
+
         return positions, clamped_positions, new_seq_lens, query_start_loc, new_block_tables
 
     @functools.partial(jax.jit, static_argnums=(0, ))
@@ -138,6 +165,7 @@ class Eagle3Proposer:
     @functools.partial(jax.jit, static_argnums=(0, ))
     def _prepare_hidden_states_and_input_ids(
         self,
+        state: nnx.State,
         aux_hidden_states: tuple[jax.Array, ...],
         query_start_loc: jax.Array,
         target_token_ids: jax.Array,
@@ -146,7 +174,7 @@ class Eagle3Proposer:
     ) -> tuple[jax.Array, jax.Array, jax.Array]:
         target_hidden_states = jnp.concatenate(aux_hidden_states, axis=-1)
         target_hidden_states = self.combine_hidden_states_fn(
-            self.state, target_hidden_states)
+            state, target_hidden_states)
 
         input_ids, last_token_indices = self._prepare_input_ids(
             query_start_loc, target_token_ids, next_token_ids, num_reqs)
@@ -193,8 +221,8 @@ class Eagle3Proposer:
                                      block_tables=device_array(
                                          self.mesh, block_tables))
         target_hidden_states, input_ids, last_token_indices = self._prepare_hidden_states_and_input_ids(
-            aux_hidden_states, attn_metadata.query_start_loc, input_ids,
-            next_token_ids, num_reqs)
+            self.state, aux_hidden_states, attn_metadata.query_start_loc,
+            input_ids, next_token_ids, num_reqs)
         return target_hidden_states, input_ids, last_token_indices, attn_metadata
 
     # Host copies from the metadata prepared by the runner.
@@ -258,12 +286,13 @@ class Eagle3Proposer:
 
         attn_metadata = replace(attn_metadata, block_tables=block_tables)
         return self._filter_token_and_prepare_initial_inputs(
-            token_indices, query_start_loc, seq_lens, input_ids,
+            self.state, token_indices, query_start_loc, seq_lens, input_ids,
             aux_hidden_states, attn_metadata, next_token_ids, num_reqs)
 
     @functools.partial(jax.jit, static_argnums=(0, ))
     def _filter_token_and_prepare_initial_inputs(
         self,
+        state: nnx.State,
         token_indices: jax.Array,
         query_start_loc: jax.Array,
         seq_lens: jax.Array,
@@ -291,35 +320,51 @@ class Eagle3Proposer:
         )
 
         target_hidden_states, input_ids, last_token_indices = self._prepare_hidden_states_and_input_ids(
-            [h[token_indices] for h in aux_hidden_states], query_start_loc,
-            target_token_ids, next_token_ids, num_reqs)
+            state, [h[token_indices] for h in aux_hidden_states],
+            query_start_loc, target_token_ids, next_token_ids, num_reqs)
 
         return target_hidden_states, input_ids, last_token_indices, attn_metadata
 
     @functools.partial(jax.jit, static_argnums=(0, ))
     def _select_draft_token_ids(
         self,
+        state: nnx.State,
         hidden_states: jax.Array,
         last_token_indices: jax.Array,
     ) -> jax.Array:
         sample_hidden_states = hidden_states[last_token_indices]
-        return self._get_draft_token_ids(sample_hidden_states)
+        sample_hidden_states = lax.with_sharding_constraint(
+            sample_hidden_states,
+            NamedSharding(self.mesh, PartitionSpec(None, None)))
+        return self._get_draft_token_ids(state, sample_hidden_states)
 
     @functools.partial(jax.jit, static_argnums=(0, ))
-    def _get_draft_token_ids(self, hidden_states: jax.Array) -> jax.Array:
+    def _get_draft_token_ids(self, state: nnx.State,
+                             hidden_states: jax.Array) -> jax.Array:
         lora_metadata = None
-        logits = self.compute_logits_fn(self.state, hidden_states,
-                                        lora_metadata)
-        return jnp.argmax(logits, axis=-1)
+        logits = self.compute_logits_fn(state, hidden_states, lora_metadata)
+        draft_token_ids = jnp.argmax(logits, axis=-1)
+        return lax.with_sharding_constraint(
+            draft_token_ids, NamedSharding(self.mesh, PartitionSpec()))
 
     @functools.partial(jax.jit, static_argnums=(0, ))
     def _select_inputs_for_loop_speculation(
-            self, positions: jax.Array, residual: jax.Array,
+            self, state: nnx.State, positions: jax.Array, residual: jax.Array,
             hidden_states: jax.Array,
             last_token_indices: jax.Array) -> tuple[jax.Array, jax.Array]:
-        return positions[last_token_indices], residual[
-            last_token_indices], self._select_draft_token_ids(
-                hidden_states, last_token_indices)
+        positions = positions[last_token_indices]
+        residual = residual[last_token_indices]
+        draft_token_ids = self._select_draft_token_ids(state, hidden_states,
+                                                       last_token_indices)
+
+        positions = lax.with_sharding_constraint(
+            positions, NamedSharding(self.mesh, PartitionSpec(None, )))
+        residual = lax.with_sharding_constraint(
+            residual, NamedSharding(self.mesh, PartitionSpec(None, None)))
+        draft_token_ids = lax.with_sharding_constraint(
+            draft_token_ids, NamedSharding(self.mesh, PartitionSpec()))
+
+        return positions, residual, draft_token_ids
 
     def propose(
         self,
@@ -346,11 +391,11 @@ class Eagle3Proposer:
 
         if self.num_speculative_tokens == 1:
             return kv_caches, self._select_draft_token_ids(
-                hidden_states, last_token_indices)
+                self.state, hidden_states, last_token_indices)
 
         positions, hidden_states, draft_token_ids = self._select_inputs_for_loop_speculation(
-            attn_metadata.input_positions, residual[0], hidden_states,
-            last_token_indices)
+            self.state, attn_metadata.input_positions, residual[0],
+            hidden_states, last_token_indices)
 
         draft_token_ids_list = [draft_token_ids]
 
@@ -375,7 +420,8 @@ class Eagle3Proposer:
                 attn_metadata,
             )
             hidden_states = residual[0]
-            draft_token_ids = self._get_draft_token_ids(new_hidden_states)
+            draft_token_ids = self._get_draft_token_ids(
+                self.state, new_hidden_states)
            draft_token_ids_list.append(draft_token_ids)
 
         # [batch_size, num_speculative_tokens]
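
The recurring lax.with_sharding_constraint(..., NamedSharding(self.mesh, PartitionSpec(...))) calls in the eagle3 hunks pin small loop-carried arrays (positions, sequence lengths, draft token ids) to explicit, mostly replicated layouts inside jit, rather than letting the partitioner infer them. A self-contained example of the primitive on a trivial single-device mesh (mesh construction here is for demonstration only):

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax import lax
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    mesh = Mesh(np.array(jax.devices()[:1]), ("data",))

    @jax.jit
    def double_replicated(x):
        y = x * 2
        # PartitionSpec() replicates y across the mesh, as in the diff above.
        return lax.with_sharding_constraint(
            y, NamedSharding(mesh, PartitionSpec()))

    print(double_replicated(jnp.arange(8)).sharding)
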
tpu_inference/tpu_info.py CHANGED
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import glob
 import os