tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (257)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +317 -34
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +406 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +320 -0
  64. tests/layers/vllm/test_unquantized.py +662 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +26 -6
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +110 -12
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +2 -45
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +15 -10
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +92 -8
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +25 -4
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +807 -230
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +218 -137
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +25 -12
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/fused_moe_gmm.py +506 -0
  154. tpu_inference/layers/common/quant_methods.py +15 -0
  155. tpu_inference/layers/common/quantization.py +282 -0
  156. tpu_inference/layers/common/sharding.py +32 -9
  157. tpu_inference/layers/common/utils.py +94 -0
  158. tpu_inference/layers/jax/__init__.py +13 -0
  159. tpu_inference/layers/jax/attention/__init__.py +13 -0
  160. tpu_inference/layers/jax/attention/attention.py +19 -6
  161. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  162. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  163. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  164. tpu_inference/layers/jax/base.py +14 -0
  165. tpu_inference/layers/jax/constants.py +13 -0
  166. tpu_inference/layers/jax/layers.py +14 -0
  167. tpu_inference/layers/jax/misc.py +14 -0
  168. tpu_inference/layers/jax/moe/__init__.py +13 -0
  169. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  170. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  171. tpu_inference/layers/jax/moe/moe.py +43 -3
  172. tpu_inference/layers/jax/pp_utils.py +53 -0
  173. tpu_inference/layers/jax/rope.py +14 -0
  174. tpu_inference/layers/jax/rope_interface.py +14 -0
  175. tpu_inference/layers/jax/sample/__init__.py +13 -0
  176. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  177. tpu_inference/layers/jax/sample/sampling.py +15 -1
  178. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  179. tpu_inference/layers/jax/transformer_block.py +14 -0
  180. tpu_inference/layers/vllm/__init__.py +13 -0
  181. tpu_inference/layers/vllm/attention.py +4 -4
  182. tpu_inference/layers/vllm/fused_moe.py +101 -494
  183. tpu_inference/layers/vllm/linear.py +64 -0
  184. tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
  185. tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
  186. tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
  187. tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
  188. tpu_inference/layers/vllm/quantization/__init__.py +19 -3
  189. tpu_inference/layers/vllm/quantization/awq.py +96 -82
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  191. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +23 -8
  192. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +172 -176
  193. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  194. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
  195. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
  196. tpu_inference/layers/vllm/quantization/{common.py → configs.py} +42 -25
  197. tpu_inference/layers/vllm/quantization/fp8.py +119 -0
  198. tpu_inference/layers/vllm/quantization/mxfp4.py +137 -178
  199. tpu_inference/layers/vllm/quantization/unquantized.py +157 -233
  200. tpu_inference/lora/__init__.py +13 -0
  201. tpu_inference/lora/torch_lora_ops.py +8 -13
  202. tpu_inference/models/__init__.py +13 -0
  203. tpu_inference/models/common/__init__.py +13 -0
  204. tpu_inference/models/common/model_loader.py +112 -35
  205. tpu_inference/models/jax/__init__.py +13 -0
  206. tpu_inference/models/jax/deepseek_v3.py +267 -157
  207. tpu_inference/models/jax/gpt_oss.py +26 -10
  208. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  209. tpu_inference/models/jax/llama3.py +99 -36
  210. tpu_inference/models/jax/llama4.py +14 -0
  211. tpu_inference/models/jax/llama_eagle3.py +18 -5
  212. tpu_inference/models/jax/llama_guard_4.py +15 -1
  213. tpu_inference/models/jax/qwen2.py +17 -2
  214. tpu_inference/models/jax/qwen2_5_vl.py +179 -51
  215. tpu_inference/models/jax/qwen3.py +17 -2
  216. tpu_inference/models/jax/utils/__init__.py +13 -0
  217. tpu_inference/models/jax/utils/file_utils.py +14 -0
  218. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  219. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  220. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +92 -32
  221. tpu_inference/models/jax/utils/weight_utils.py +234 -155
  222. tpu_inference/models/vllm/__init__.py +13 -0
  223. tpu_inference/models/vllm/vllm_model_wrapper.py +32 -8
  224. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  225. tpu_inference/platforms/__init__.py +14 -0
  226. tpu_inference/platforms/tpu_platform.py +51 -72
  227. tpu_inference/runner/__init__.py +13 -0
  228. tpu_inference/runner/compilation_manager.py +180 -80
  229. tpu_inference/runner/kv_cache.py +54 -20
  230. tpu_inference/runner/kv_cache_manager.py +55 -33
  231. tpu_inference/runner/lora_utils.py +16 -1
  232. tpu_inference/runner/multimodal_manager.py +16 -2
  233. tpu_inference/runner/persistent_batch_manager.py +54 -2
  234. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  235. tpu_inference/runner/structured_decoding_manager.py +16 -3
  236. tpu_inference/runner/tpu_runner.py +124 -61
  237. tpu_inference/runner/utils.py +2 -2
  238. tpu_inference/spec_decode/__init__.py +13 -0
  239. tpu_inference/spec_decode/jax/__init__.py +13 -0
  240. tpu_inference/spec_decode/jax/eagle3.py +84 -22
  241. tpu_inference/tpu_info.py +14 -0
  242. tpu_inference/utils.py +72 -44
  243. tpu_inference/worker/__init__.py +13 -0
  244. tpu_inference/worker/tpu_worker.py +66 -52
  245. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +8 -9
  246. tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
  247. tpu_inference/layers/vllm/linear_common.py +0 -186
  248. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  249. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  250. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  251. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  252. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  253. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  254. tpu_inference-0.11.1.dev202511220812.dist-info/RECORD +0 -174
  255. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
  256. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
  257. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
tpu_inference/runner/tpu_runner.py CHANGED
@@ -1,6 +1,20 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  import copy
  import functools
- import os
+ import logging
  import random
  from contextlib import nullcontext
  from dataclasses import dataclass
@@ -10,17 +24,15 @@ import jax
  import jax.numpy as jnp
  import jaxtyping
  import numpy as np
- import torch
- import vllm.envs as envs
+ import vllm.envs as vllm_envs
  from flax import nnx
  from jax.experimental import mesh_utils
  from jax.sharding import NamedSharding, PartitionSpec
- from torchax.ops.mappings import j2t_dtype
  from vllm.config import VllmConfig
+ from vllm.distributed import get_pp_group
  from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                            has_kv_transfer_group)
  from vllm.forward_context import set_forward_context
- from vllm.sequence import IntermediateTensors
  from vllm.tasks import SupportedTask
  from vllm.utils.math_utils import cdiv
  from vllm.v1.core.sched.output import GrammarOutput
@@ -35,6 +47,7 @@ from vllm.v1.worker.kv_connector_model_runner_mixin import \
      KVConnectorModelRunnerMixin
  from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin

+ import tpu_inference.envs as envs
  from tpu_inference import utils as common_utils
  from tpu_inference.layers.common.attention_metadata import AttentionMetadata
  from tpu_inference.layers.common.sharding import (MESH_AXIS_NAMES,
@@ -48,6 +61,8 @@ from tpu_inference.layers.jax.sample.sampling_metadata import \
      TPUSupportedSamplingMetadata
  from tpu_inference.logger import init_logger
  from tpu_inference.models.common.model_loader import get_model
+ from tpu_inference.models.jax.jax_intermediate_tensor import \
+     JaxIntermediateTensors
  from tpu_inference.models.jax.utils.weight_utils import (
      shard_put, transfer_state_with_mappings)
  from tpu_inference.runner import utils as runner_utils
@@ -64,10 +79,12 @@ from tpu_inference.runner.structured_decoding_manager import \
      StructuredDecodingManager
  from tpu_inference.spec_decode.jax.eagle3 import Eagle3Proposer
  from tpu_inference.utils import (device_array, make_optimized_mesh,
-                                  time_function)
+                                  time_function, to_jax_dtype, to_torch_dtype)

  logger = init_logger(__name__)

+ logging.getLogger("torchax.tensor").setLevel(logging.ERROR)
+
  INVALID_TOKEN_ID = -1
  # Smallest output size
  MIN_NUM_SEQS = 8
@@ -78,17 +95,6 @@ DUMMY_METADATA = AttentionMetadata(
      request_distribution=[0, 0, 0],
  )

- TPU_STR_DTYPE_TO_TORCH_DTYPE = {
-     "half": torch.half,
-     "bfloat16": torch.bfloat16,
-     "float": torch.float,
-     "fp8": torch.float8_e4m3fn,
-     "fp8_e4m3": torch.float8_e4m3fn,
-     "fp8_e5m2": torch.float8_e5m2,
-     "int8": torch.int8,
-     "uint8": torch.uint8,
- }
-

  class AsyncTPUModelRunnerOutput(AsyncModelRunnerOutput):
      """Holds asynchronous model output specifically from a TPU runner.
@@ -243,6 +249,9 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
          self.maybe_forbid_compile = runner_utils.ForbidCompile(
          ) if envs.VLLM_XLA_CHECK_RECOMPILATION else nullcontext()
          self.dp_size = self.vllm_config.sharding_config.total_dp_size
+         self.rank = rank
+         self.is_first_rank = is_first_rank
+         self.is_last_rank = is_last_rank

          self._init_random()
          self._init_mesh()
@@ -253,36 +262,29 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):

          # Delegate functions to specific manager classes.
          self.compilation_manager = CompilationManager(self)
-         self.speculative_decoding_manager = SpeculativeDecodingManager(self)
-         self.structured_decoding_manager = StructuredDecodingManager(self)
+         if self.is_last_rank:
+             self.speculative_decoding_manager = SpeculativeDecodingManager(
+                 self)
+             self.structured_decoding_manager = StructuredDecodingManager(self)
          self.kv_cache_manager = KVCacheManager(self)
          self.mm_manager = MultiModalManager(self)
          self.persistent_batch_manager = PersistentBatchManager(
              self.requests, self.input_batch, self.encoder_cache,
-             self.uses_mrope, self.model_config)
+             self.uses_mrope, self.model_config, self.is_last_rank)
          self.lora_utils = LoraUtils(self)

-         cache_config = self.cache_config
-         if cache_config.cache_dtype == "auto":
-             model_dtype = self.dtype
-             if isinstance(model_dtype, str):
-                 self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
-             elif isinstance(getattr(model_dtype, 'dtype', None), jnp.dtype):
-                 self.kv_cache_dtype = j2t_dtype(model_dtype.dtype)
-             elif isinstance(model_dtype, torch.dtype):
-                 self.kv_cache_dtype = model_dtype
-             else:
-                 raise ValueError(
-                     "KV cache is unsupported for model_dtype of %s",
-                     model_dtype)
-         else:
-             self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[
-                 cache_config.cache_dtype]
+         cache_dtype = self.cache_config.cache_dtype
+         if cache_dtype == "auto":
+             cache_dtype = self.dtype
+         self.kv_cache_dtype = to_torch_dtype(cache_dtype)

          self._pre_async_results: AsyncPreResults | None = None
          self._substitute_placeholder_token_fn = _substitute_placeholder_token
          self.execute_model_state: ExecuteModelState | None = None

+         self.kv_caches: list[jax.Array] = []
+         self.layer_name_to_kvcache_index: dict[str, int] = {}
+
      def _init_random(self):
          if self.model_config.seed is None:
              self.model_config.seed = 0
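
Note: the deleted dtype ladder above is now a single call to to_torch_dtype, imported from tpu_inference.utils earlier in this diff. The helper's body is not shown here; below is a minimal sketch of such a normalizer, reusing the mapping from the deleted TPU_STR_DTYPE_TO_TORCH_DTYPE table (the function name is real, the implementation is assumed).

import jax.numpy as jnp
import torch

# Mapping carried over from the deleted TPU_STR_DTYPE_TO_TORCH_DTYPE dict.
_STR_TO_TORCH = {
    "half": torch.half,
    "bfloat16": torch.bfloat16,
    "float": torch.float,
    "fp8": torch.float8_e4m3fn,
    "fp8_e4m3": torch.float8_e4m3fn,
    "fp8_e5m2": torch.float8_e5m2,
    "int8": torch.int8,
    "uint8": torch.uint8,
}

def to_torch_dtype(dtype) -> torch.dtype:
    # Hypothetical sketch; the real helper lives in tpu_inference/utils.py
    # (+72 -44 in this diff) and may differ.
    if isinstance(dtype, torch.dtype):
        return dtype
    if isinstance(dtype, str):
        return _STR_TO_TORCH[dtype]
    name = jnp.dtype(dtype).name  # e.g. jnp.bfloat16 -> "bfloat16"
    return _STR_TO_TORCH.get(name) or getattr(torch, name)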
@@ -291,7 +293,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
          self.rng_key = jax.random.key(self.model_config.seed)

      def _init_mesh(self) -> None:
-         if os.getenv("NEW_MODEL_DESIGN", False):
+         if envs.NEW_MODEL_DESIGN:
              self.mesh = self._create_new_model_mesh()
          else:
              # NOTE(wenxindongwork): The new MoE kernel expects a 2D mesh, so we need
@@ -302,7 +304,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
          logger.info(f"Init mesh | mesh={self.mesh}")

      def _create_new_model_mesh(self) -> jax.sharding.Mesh:
-         num_slices = int(os.environ.get('NUM_SLICES', 1))
+         num_slices = envs.NUM_SLICES

          logger.info(f"Creating new model mesh | devices={len(self.devices)}, "
                      f"num_slices={num_slices}")
@@ -371,7 +373,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                                  devices=self.devices)

      def _init_phased_profiling(self) -> None:
-         self.phased_profiling_dir = os.getenv("PHASED_PROFILING_DIR", "")
+         self.phased_profiling_dir = envs.PHASED_PROFILING_DIR
          self.phase_based_profiler = None
          if self.phased_profiling_dir:
              self.phase_based_profiler = runner_utils.PhasedBasedProfiler(
@@ -413,7 +415,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
              min_token_size=max(16, self.dp_size),
              max_token_size=scheduler_config.max_num_batched_tokens *
              self.dp_size,
-             padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP)
+             padding_gap=vllm_envs.VLLM_TPU_BUCKET_PADDING_GAP)
          self.num_tokens_paddings_per_dp = [
              padding // self.dp_size for padding in self.num_tokens_paddings
          ]
@@ -509,10 +511,10 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
          multimodal_fns = multimodal_fns or {}
          self.precompile_vision_encoder_fn = multimodal_fns.get(
              "precompile_vision_encoder_fn", None)
-         self.get_multimodal_embeddings_fn = multimodal_fns.get(
-             "get_multimodal_embeddings_fn", None)
-         self.get_input_embeddings_fn = multimodal_fns.get(
-             "get_input_embeddings_fn", None)
+         self.embed_multimodal_fn = multimodal_fns.get("embed_multimodal_fn",
+                                                       None)
+         self.embed_input_ids_fn = multimodal_fns.get("embed_input_ids_fn",
+                                                      None)
          self.get_mrope_input_positions_fn = multimodal_fns.get(
              "get_mrope_input_positions_fn", None)

@@ -524,7 +526,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
              jax.random.key(self.model_config.seed)).params()
          self.is_multimodal_model = (
              self.model_config.is_multimodal_model
-             and self.get_multimodal_embeddings_fn is not None and hasattr(
+             and self.embed_multimodal_fn is not None and hasattr(
                  self.model_config.hf_config, "architectures"
              ) #TODO: Remove Llama Guard 4 specific condition once the LG4 Vision portion is implemented
              and len(self.model_config.hf_config.architectures) >= 1
@@ -540,10 +542,12 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
      def get_kv_cache_spec(self):
          return self.kv_cache_manager.get_kv_cache_spec()

-     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
+     def initialize_kv_cache(self,
+                             kv_cache_config: KVCacheConfig,
+                             topology_order_id: int = 0) -> None:
+         self.topology_order_id = topology_order_id
          self.kv_cache_config = kv_cache_config
          self.use_hybrid_kvcache = len(kv_cache_config.kv_cache_groups) > 1
-         self.kv_caches = []
          self.kv_cache_manager.initialize_kv_cache(kv_cache_config)
          if has_kv_transfer_group():
              get_kv_transfer_group().register_runner(self)
@@ -555,12 +559,12 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
      def execute_model(
          self,
          scheduler_output: "VllmSchedulerOutput",
-         intermediate_tensors: Optional[IntermediateTensors] = None,
-     ) -> ModelRunnerOutput | None:
+         intermediate_tensors: Optional[JaxIntermediateTensors] = None,
+     ) -> ModelRunnerOutput | JaxIntermediateTensors | None:
          if self.execute_model_state is not None:
              raise RuntimeError("State error: sample_tokens() must be called "
                                 "after execute_model() returns None.")
-         _, output = self._execute_model(scheduler_output)
+         _, output = self._execute_model(scheduler_output, intermediate_tensors)
          return output

      def sample_tokens(
@@ -686,7 +690,9 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
      def _execute_model(
          self,
          scheduler_output: "VllmSchedulerOutput",
-     ) -> tuple[AttentionMetadata, ModelRunnerOutput | None]:
+         intermediate_tensors: Optional[JaxIntermediateTensors] = None,
+     ) -> tuple[AttentionMetadata, JaxIntermediateTensors | ModelRunnerOutput
+                | None]:
          self.persistent_batch_manager.update_states(
              scheduler_output, self.get_mrope_input_positions_fn)
          if not scheduler_output.total_num_scheduled_tokens:
@@ -764,7 +770,6 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                  scheduler_output) as kv_connector_output:
              # NOTE(Wenlong): It takes both `input_ids` and `inputs_embeds`,
              # but one of them would be `None`
-
              (self.kv_caches, hidden_states,
               aux_hidden_states) = self.model_fn(
                   self.state,
@@ -775,8 +780,14 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                   input_positions,
                   tuple(self.layer_name_to_kvcache_index.items()),
                   lora_metadata,
+                  intermediate_tensors,
+                  self.is_first_rank,
+                  self.is_last_rank,
              )
-
+         if not get_pp_group().is_last_rank:
+             assert isinstance(hidden_states, JaxIntermediateTensors)
+             hidden_states.kv_connector_output = kv_connector_output
+             return attn_metadata, hidden_states
          hidden_states = self._select_from_array_fn(hidden_states,
                                                     logits_indices)
          logits = self.compute_logits_fn(
@@ -818,22 +829,35 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
          sharding = None
          if self.dp_size > 1:
              sharding = NamedSharding(self.mesh,
-                                      PartitionSpec(ShardingAxisName.ATTN_DATA))
+                                      PartitionSpec(ShardingAxisName.MLP_DATA))

          tpu_sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch(
              self.mesh, self.input_batch, padded_num_reqs, sharding=sharding)
+
+         # TODO(pooyam): Should we move this to `_prepare_inputs`?
+         if tpu_sampling_metadata.do_sampling:
+             self.rng_params_for_sampling, step_rng = jax.random.split(
+                 self.rng_params_for_sampling)
+         else:
+             step_rng = self.rng_params_for_sampling
+
          if spec_decode_metadata is None:
              next_tokens = sample(
-                 self.rng_params_for_sampling,
+                 step_rng,
                  self.mesh,
                  logits,
                  tpu_sampling_metadata,
              )
          else:
+             if tpu_sampling_metadata.do_sampling:
+                 bonus_rng, rejection_rng = jax.random.split(step_rng)
+             else:
+                 bonus_rng = step_rng
+                 rejection_rng = step_rng
              bonus_logits = self._select_from_array_fn(
                  logits, spec_decode_metadata.bonus_logits_indices)
              bonus_token_ids = sample(
-                 self.rng_params_for_sampling,
+                 bonus_rng,
                  self.mesh,
                  bonus_logits,
                  tpu_sampling_metadata,
@@ -847,7 +871,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
              target_logits=target_logits,
              bonus_token_ids=bonus_token_ids,
              sampling_metadata=tpu_sampling_metadata,
-             key=self.rng_params_for_sampling,
+             key=rejection_rng,
          )

          if tpu_sampling_metadata.logprobs:
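
Note: the switch from reusing self.rng_params_for_sampling every step to splitting off a fresh step_rng (and separate bonus_rng/rejection_rng) matters because JAX PRNG keys are stateless: the same key always produces the same draw. A self-contained illustration (toy shapes, not from this diff):

import jax
import jax.numpy as jnp

key = jax.random.key(0)
logits = jnp.zeros((1, 8))

a = jax.random.categorical(key, logits)
b = jax.random.categorical(key, logits)
assert bool((a == b).all())  # reusing one key repeats the same sample

# What the runner now does each step: keep one half, consume the other.
key, step_rng = jax.random.split(key)
c = jax.random.categorical(step_rng, logits)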
@@ -1332,7 +1356,14 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
          _request_distribution = []
          for dp_rank in range(dp_size):
              _num_reqs = num_req_per_dp_rank[dp_rank]
-             _request_distribution.append([0, 0, _num_reqs])
+             # The batch has been reordered by _reorder_batch so decode requests come first
+             # Count decode requests (those with num_scheduled_tokens == 1) in this DP rank
+             num_decode_in_dp_rank = 0
+             for req_id in req_ids_dp[dp_rank]:
+                 if scheduler_output.num_scheduled_tokens[req_id] == 1:
+                     num_decode_in_dp_rank += 1
+             _request_distribution.append(
+                 [num_decode_in_dp_rank, num_decode_in_dp_rank, _num_reqs])
          request_distribution = np.array(_request_distribution).ravel()

          use_spec_decode = len(
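
Note: a worked toy example of the new per-rank [num_decode, num_decode, num_reqs] entries; the request map and rank assignment below are invented for illustration.

# Two DP ranks; requests scheduled for exactly 1 token count as decode.
num_scheduled_tokens = {"a": 1, "b": 1, "c": 17}  # a, b decode; c prefill
req_ids_dp = [["a", "b"], ["c"]]

_request_distribution = []
for rank_reqs in req_ids_dp:
    num_decode = sum(1 for r in rank_reqs
                     if num_scheduled_tokens[r] == 1)
    _request_distribution.append([num_decode, num_decode, len(rank_reqs)])

print(_request_distribution)  # [[2, 2, 2], [0, 0, 1]]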
@@ -1361,7 +1392,8 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
              self.mesh,
              self.input_batch,
              padded_num_reqs,
-             sharding=data_parallel_attn_sharding,
+             sharding=NamedSharding(self.mesh,
+                                    PartitionSpec(ShardingAxisName.MLP_DATA)),
          )
          if self.uses_mrope:
              positions = mrope_positions
@@ -1391,7 +1423,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
              block_tables[
                  req_offset:req_offset + _num_reqs, :self.
                  max_num_blocks_per_req] = self.input_batch.block_table[
-                     0].get_cpu_tensor()[req_indices_dp[dp_rank]]
+                     kv_cache_gid].get_cpu_tensor()[req_indices_dp[dp_rank]]
          # Convert block_tables to 1D on cpu.
          block_tables = block_tables.reshape(-1)
          block_tables = device_array(
@@ -1651,7 +1683,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
      def _get_input_ids_embeds(self, input_ids: jax.Array,
                                mm_embeds: list[jax.Array]):
          if self.is_multimodal_model:
-             inputs_embeds = self.get_input_embeddings_fn(
+             inputs_embeds = self.embed_input_ids_fn(
                  self.state,
                  input_ids,
                  mm_embeds,
@@ -1706,3 +1738,34 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
              mappings=mappings,
              transpose_keys=transpose_keys,
              shard=shard)
+
+     def get_intermediate_tensor_spec(self, num_tokens: int):
+         jax_dtype = to_jax_dtype(self.dtype)
+         num_padded_tokens = runner_utils.get_padded_token_len(
+             self.num_tokens_paddings, num_tokens)
+         sharding = NamedSharding(self.mesh, PartitionSpec())
+         hidden_size = self.model_config.get_hidden_size()
+         spec = jax.ShapeDtypeStruct(shape=(num_padded_tokens, hidden_size),
+                                     dtype=jax_dtype,
+                                     sharding=sharding)
+         tensor_spec = {"hidden_states": spec, "residual": spec}
+         return tensor_spec
+
+     def get_uuid_for_jax_transfer(self,
+                                   scheduler_output: "VllmSchedulerOutput",
+                                   rank: int, step: int) -> int:
+         '''
+         Get a uuid for jax.transfer, here we use the hash of
+         scheduler_output + counter_step + sender's rank
+         '''
+         scheduler_output_str = ""
+         if not scheduler_output.num_scheduled_tokens:
+             scheduler_output_str = "empty_batch"
+         else:
+             scheduler_output_str = str(
+                 sorted(scheduler_output.num_scheduled_tokens.items()))
+         unique_str = f'{scheduler_output_str} {step} {rank}'
+         import hashlib
+         hasher = hashlib.sha1()
+         hasher.update(unique_str.encode('utf-8'))
+         return int.from_bytes(hasher.digest()[:8], 'big')
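
Note: because the uuid above is derived only from shared state (the scheduled-token map, the step counter, and the sender's rank), both ends of a pipeline-parallel jax transfer can compute matching ids independently. A standalone restatement of the same hashing:

import hashlib

def uuid_for_transfer(num_scheduled_tokens: dict, step: int, rank: int) -> int:
    s = ("empty_batch" if not num_scheduled_tokens else
         str(sorted(num_scheduled_tokens.items())))
    digest = hashlib.sha1(f"{s} {step} {rank}".encode("utf-8")).digest()
    return int.from_bytes(digest[:8], "big")

# Both ends of the transfer agree without any extra coordination.
assert (uuid_for_transfer({"req-0": 4}, step=7, rank=0) ==
        uuid_for_transfer({"req-0": 4}, step=7, rank=0))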
tpu_inference/runner/utils.py CHANGED
@@ -15,6 +15,7 @@ import jax
  from jax._src.interpreters import pxla
  from vllm.v1.core.sched.output import SchedulerOutput as VllmSchedulerOutput

+ from tpu_inference import envs
  from tpu_inference.logger import init_logger
  from tpu_inference.runner.input_batch import InputBatch

@@ -306,8 +307,7 @@ class PhasedBasedProfiler:
              InferencePhase.BALANCED: False
          }
          self.default_profiling_options = jax.profiler.ProfileOptions()
-         self.default_profiling_options.python_tracer_level = os.getenv(
-             "PYTHON_TRACER_LEVEL", 0)
+         self.default_profiling_options.python_tracer_level = envs.PYTHON_TRACER_LEVEL

          self.current_phase: str = ""

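Note: centralizing these reads in tpu_inference.envs also fixes two typing quirks visible in the removed code: os.getenv("NEW_MODEL_DESIGN", False) treated any non-empty value, even "0", as true, and os.getenv("PYTHON_TRACER_LEVEL", 0) returned a str whenever the variable was set. The envs module itself (+92 -8 in this diff) is not shown; a hypothetical sketch of the kind of typed accessors it could expose for the variables read above:

import os

# Assumed shape only; tpu_inference/envs.py may be structured differently.
def _env_bool(name: str, default: bool = False) -> bool:
    return os.getenv(name, "1" if default else "0").lower() in ("1", "true", "yes")

NEW_MODEL_DESIGN: bool = _env_bool("NEW_MODEL_DESIGN")
NUM_SLICES: int = int(os.getenv("NUM_SLICES", "1"))
PHASED_PROFILING_DIR: str = os.getenv("PHASED_PROFILING_DIR", "")
PYTHON_TRACER_LEVEL: int = int(os.getenv("PYTHON_TRACER_LEVEL", "0"))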
@@ -0,0 +1,13 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
@@ -0,0 +1,13 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
tpu_inference/spec_decode/jax/eagle3.py CHANGED
@@ -1,3 +1,16 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
  """Implements the Eagle3 proposer for speculative decoding on JAX/TPU."""
  import functools
  from dataclasses import replace
@@ -6,13 +19,19 @@ from typing import Any, Optional
  import jax
  import jax.numpy as jnp
  import numpy as np
+ from flax import nnx
+ from jax import lax
+ from jax.sharding import NamedSharding, PartitionSpec
  from vllm.config import VllmConfig

  from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+ from tpu_inference.logger import init_logger
  from tpu_inference.models.common.model_loader import get_model
  from tpu_inference.runner import utils as runner_utils
  from tpu_inference.utils import device_array

+ logger = init_logger(__name__)
+

  class Eagle3Proposer:
      """A proposer for speculative decoding using the Eagle3 method.
@@ -51,9 +70,22 @@ class Eagle3Proposer:
          """Loads the draft model."""
          self.model_fn, self.compute_logits_fn, self.combine_hidden_states_fn, _, self.state, _, _ = get_model(
              self.vllm_config, self.rng_key, self.mesh, is_draft_model=True)
-         if 'embed_tokens' in self.state.model:
-             del self.state.model['embed_tokens']
-         self.state.model.embed_tokens = target_model.model.embed
+
+         draft_embed_tokens = getattr(self.state.model, 'embed_tokens', None)
+         if draft_embed_tokens is None or ~jnp.any(
+                 draft_embed_tokens.embedding):
+             logger.info(
+                 "Draft model does not have embedding. Setting draft model's embed_tokens to target model's embed"
+             )
+             self.state.model.embed_tokens = target_model.model.embed
+         elif jnp.array_equal(draft_embed_tokens.embedding,
+                              target_model.model.embed.embedding):
+             logger.info(
+                 "Draft model's embed_tokens is identical to target model's embed. Sharing the embedding."
+             )
+             self.state.model.embed_tokens = target_model.model.embed
+         else:
+             logger.info("Draft model has its own embed_tokens.")

      @functools.partial(jax.jit, static_argnums=(0, ))
      def _prepare_input_ids(
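
Note: one detail worth spelling out in the branch above: jnp.any over the embedding table returns a boolean scalar, so the bitwise ~ acts as a logical NOT, and the first branch fires when every weight is zero, i.e. an embedding the draft checkpoint never populated. A minimal demonstration:

import jax.numpy as jnp

placeholder = jnp.zeros((4, 8))        # table never loaded from checkpoint
assert bool(~jnp.any(placeholder))     # all-zero -> reuse target embedding

loaded = placeholder.at[0, 0].set(0.5)
assert not bool(~jnp.any(loaded))      # any nonzero -> keep draft's own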
@@ -111,6 +143,17 @@ class Eagle3Proposer:
                                  max_num_blocks_per_req)
          new_block_tables = jnp.where(expanded_exceeds_mask, -1, block_tables)

+         positions = lax.with_sharding_constraint(
+             positions, NamedSharding(self.mesh, PartitionSpec(None, )))
+         clamped_positions = lax.with_sharding_constraint(
+             clamped_positions, NamedSharding(self.mesh, PartitionSpec(None, )))
+         new_seq_lens = lax.with_sharding_constraint(
+             new_seq_lens, NamedSharding(self.mesh, PartitionSpec(None, )))
+         query_start_loc = lax.with_sharding_constraint(
+             query_start_loc, NamedSharding(self.mesh, PartitionSpec()))
+         new_block_tables = lax.with_sharding_constraint(
+             new_block_tables, NamedSharding(self.mesh, PartitionSpec(None, )))
+
          return positions, clamped_positions, new_seq_lens, query_start_loc, new_block_tables

      @functools.partial(jax.jit, static_argnums=(0, ))
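
Note: these with_sharding_constraint calls pin the freshly derived arrays to an explicitly replicated layout instead of whatever sharding jit would otherwise infer. A minimal sketch of the same pattern on a one-axis mesh (the mesh construction here is illustrative, not from the diff):

import jax
import numpy as np
from jax import lax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), ("data", ))

@jax.jit
def pin_replicated(x):
    # PartitionSpec(None, ) leaves the leading axis unsharded, so the
    # array is replicated across every device in the mesh.
    return lax.with_sharding_constraint(
        x, NamedSharding(mesh, PartitionSpec(None, )))

y = pin_replicated(jax.numpy.arange(8))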
@@ -122,6 +165,7 @@ class Eagle3Proposer:
      @functools.partial(jax.jit, static_argnums=(0, ))
      def _prepare_hidden_states_and_input_ids(
          self,
+         state: nnx.State,
          aux_hidden_states: tuple[jax.Array, ...],
          query_start_loc: jax.Array,
          target_token_ids: jax.Array,
@@ -130,7 +174,7 @@ class Eagle3Proposer:
      ) -> tuple[jax.Array, jax.Array, jax.Array]:
          target_hidden_states = jnp.concatenate(aux_hidden_states, axis=-1)
          target_hidden_states = self.combine_hidden_states_fn(
-             self.state, target_hidden_states)
+             state, target_hidden_states)

          input_ids, last_token_indices = self._prepare_input_ids(
              query_start_loc, target_token_ids, next_token_ids, num_reqs)
@@ -177,8 +221,8 @@ class Eagle3Proposer:
                                     block_tables=device_array(
                                         self.mesh, block_tables))
          target_hidden_states, input_ids, last_token_indices = self._prepare_hidden_states_and_input_ids(
-             aux_hidden_states, attn_metadata.query_start_loc, input_ids,
-             next_token_ids, num_reqs)
+             self.state, aux_hidden_states, attn_metadata.query_start_loc,
+             input_ids, next_token_ids, num_reqs)
          return target_hidden_states, input_ids, last_token_indices, attn_metadata

      # Host copies from the metadata prepared by the runner.
@@ -242,12 +286,13 @@ class Eagle3Proposer:

          attn_metadata = replace(attn_metadata, block_tables=block_tables)
          return self._filter_token_and_prepare_initial_inputs(
-             token_indices, query_start_loc, seq_lens, input_ids,
+             self.state, token_indices, query_start_loc, seq_lens, input_ids,
              aux_hidden_states, attn_metadata, next_token_ids, num_reqs)

      @functools.partial(jax.jit, static_argnums=(0, ))
      def _filter_token_and_prepare_initial_inputs(
          self,
+         state: nnx.State,
          token_indices: jax.Array,
          query_start_loc: jax.Array,
          seq_lens: jax.Array,
@@ -275,35 +320,51 @@ class Eagle3Proposer:
          )

          target_hidden_states, input_ids, last_token_indices = self._prepare_hidden_states_and_input_ids(
-             [h[token_indices] for h in aux_hidden_states], query_start_loc,
-             target_token_ids, next_token_ids, num_reqs)
+             state, [h[token_indices] for h in aux_hidden_states],
+             query_start_loc, target_token_ids, next_token_ids, num_reqs)

          return target_hidden_states, input_ids, last_token_indices, attn_metadata

      @functools.partial(jax.jit, static_argnums=(0, ))
      def _select_draft_token_ids(
          self,
+         state: nnx.State,
          hidden_states: jax.Array,
          last_token_indices: jax.Array,
      ) -> jax.Array:
          sample_hidden_states = hidden_states[last_token_indices]
-         return self._get_draft_token_ids(sample_hidden_states)
+         sample_hidden_states = lax.with_sharding_constraint(
+             sample_hidden_states,
+             NamedSharding(self.mesh, PartitionSpec(None, None)))
+         return self._get_draft_token_ids(state, sample_hidden_states)

      @functools.partial(jax.jit, static_argnums=(0, ))
-     def _get_draft_token_ids(self, hidden_states: jax.Array) -> jax.Array:
+     def _get_draft_token_ids(self, state: nnx.State,
+                              hidden_states: jax.Array) -> jax.Array:
          lora_metadata = None
-         logits = self.compute_logits_fn(self.state, hidden_states,
-                                         lora_metadata)
-         return jnp.argmax(logits, axis=-1)
+         logits = self.compute_logits_fn(state, hidden_states, lora_metadata)
+         draft_token_ids = jnp.argmax(logits, axis=-1)
+         return lax.with_sharding_constraint(
+             draft_token_ids, NamedSharding(self.mesh, PartitionSpec()))

      @functools.partial(jax.jit, static_argnums=(0, ))
      def _select_inputs_for_loop_speculation(
-         self, positions: jax.Array, residual: jax.Array,
+         self, state: nnx.State, positions: jax.Array, residual: jax.Array,
          hidden_states: jax.Array,
          last_token_indices: jax.Array) -> tuple[jax.Array, jax.Array]:
-         return positions[last_token_indices], residual[
-             last_token_indices], self._select_draft_token_ids(
-                 hidden_states, last_token_indices)
+         positions = positions[last_token_indices]
+         residual = residual[last_token_indices]
+         draft_token_ids = self._select_draft_token_ids(state, hidden_states,
+                                                        last_token_indices)
+
+         positions = lax.with_sharding_constraint(
+             positions, NamedSharding(self.mesh, PartitionSpec(None, )))
+         residual = lax.with_sharding_constraint(
+             residual, NamedSharding(self.mesh, PartitionSpec(None, None)))
+         draft_token_ids = lax.with_sharding_constraint(
+             draft_token_ids, NamedSharding(self.mesh, PartitionSpec()))
+
+         return positions, residual, draft_token_ids

      def propose(
          self,
@@ -330,11 +391,11 @@ class Eagle3Proposer:

          if self.num_speculative_tokens == 1:
              return kv_caches, self._select_draft_token_ids(
-                 hidden_states, last_token_indices)
+                 self.state, hidden_states, last_token_indices)

          positions, hidden_states, draft_token_ids = self._select_inputs_for_loop_speculation(
-             attn_metadata.input_positions, residual[0], hidden_states,
-             last_token_indices)
+             self.state, attn_metadata.input_positions, residual[0],
+             hidden_states, last_token_indices)

          draft_token_ids_list = [draft_token_ids]

@@ -359,7 +420,8 @@ class Eagle3Proposer:
                  attn_metadata,
              )
              hidden_states = residual[0]
-             draft_token_ids = self._get_draft_token_ids(new_hidden_states)
+             draft_token_ids = self._get_draft_token_ids(
+                 self.state, new_hidden_states)
              draft_token_ids_list.append(draft_token_ids)

          # [batch_size, num_speculative_tokens]
tpu_inference/tpu_info.py CHANGED
@@ -1,3 +1,17 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  import glob
  import os