tpu-inference 0.11.1.dev202512030818__py3-none-any.whl → 0.13.0rc2.post7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (250)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +405 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +312 -0
  64. tests/layers/vllm/test_unquantized.py +651 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +21 -3
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +78 -1
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +1 -43
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +14 -9
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +38 -7
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +17 -0
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +95 -78
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +26 -19
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/quant_methods.py +15 -0
  154. tpu_inference/layers/common/quantization.py +270 -0
  155. tpu_inference/layers/common/sharding.py +28 -5
  156. tpu_inference/layers/jax/__init__.py +13 -0
  157. tpu_inference/layers/jax/attention/__init__.py +13 -0
  158. tpu_inference/layers/jax/attention/attention.py +19 -6
  159. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  160. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  161. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  162. tpu_inference/layers/jax/base.py +14 -0
  163. tpu_inference/layers/jax/constants.py +13 -0
  164. tpu_inference/layers/jax/layers.py +14 -0
  165. tpu_inference/layers/jax/misc.py +14 -0
  166. tpu_inference/layers/jax/moe/__init__.py +13 -0
  167. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  168. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  169. tpu_inference/layers/jax/moe/moe.py +43 -3
  170. tpu_inference/layers/jax/pp_utils.py +53 -0
  171. tpu_inference/layers/jax/rope.py +14 -0
  172. tpu_inference/layers/jax/rope_interface.py +14 -0
  173. tpu_inference/layers/jax/sample/__init__.py +13 -0
  174. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  175. tpu_inference/layers/jax/sample/sampling.py +15 -1
  176. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  177. tpu_inference/layers/jax/transformer_block.py +14 -0
  178. tpu_inference/layers/vllm/__init__.py +13 -0
  179. tpu_inference/layers/vllm/attention.py +4 -4
  180. tpu_inference/layers/vllm/fused_moe.py +210 -260
  181. tpu_inference/layers/vllm/linear_common.py +57 -22
  182. tpu_inference/layers/vllm/quantization/__init__.py +16 -0
  183. tpu_inference/layers/vllm/quantization/awq.py +15 -1
  184. tpu_inference/layers/vllm/quantization/common.py +33 -18
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
  188. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  189. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
  191. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  192. tpu_inference/layers/vllm/quantization/mxfp4.py +278 -209
  193. tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
  194. tpu_inference/layers/vllm/sharding.py +21 -4
  195. tpu_inference/lora/__init__.py +13 -0
  196. tpu_inference/lora/torch_lora_ops.py +8 -13
  197. tpu_inference/models/__init__.py +13 -0
  198. tpu_inference/models/common/__init__.py +13 -0
  199. tpu_inference/models/common/model_loader.py +74 -35
  200. tpu_inference/models/jax/__init__.py +13 -0
  201. tpu_inference/models/jax/deepseek_v3.py +267 -157
  202. tpu_inference/models/jax/gpt_oss.py +26 -10
  203. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  204. tpu_inference/models/jax/llama3.py +99 -36
  205. tpu_inference/models/jax/llama4.py +14 -0
  206. tpu_inference/models/jax/llama_eagle3.py +14 -0
  207. tpu_inference/models/jax/llama_guard_4.py +15 -1
  208. tpu_inference/models/jax/qwen2.py +17 -2
  209. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  210. tpu_inference/models/jax/qwen3.py +17 -2
  211. tpu_inference/models/jax/utils/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/file_utils.py +14 -0
  213. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  214. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  215. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +88 -25
  216. tpu_inference/models/jax/utils/weight_utils.py +39 -2
  217. tpu_inference/models/vllm/__init__.py +13 -0
  218. tpu_inference/models/vllm/vllm_model_wrapper.py +20 -3
  219. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  220. tpu_inference/platforms/__init__.py +14 -0
  221. tpu_inference/platforms/tpu_platform.py +47 -64
  222. tpu_inference/runner/__init__.py +13 -0
  223. tpu_inference/runner/compilation_manager.py +72 -37
  224. tpu_inference/runner/kv_cache.py +54 -20
  225. tpu_inference/runner/kv_cache_manager.py +45 -15
  226. tpu_inference/runner/lora_utils.py +14 -0
  227. tpu_inference/runner/multimodal_manager.py +15 -1
  228. tpu_inference/runner/persistent_batch_manager.py +14 -0
  229. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  230. tpu_inference/runner/structured_decoding_manager.py +14 -0
  231. tpu_inference/runner/tpu_runner.py +41 -16
  232. tpu_inference/spec_decode/__init__.py +13 -0
  233. tpu_inference/spec_decode/jax/__init__.py +13 -0
  234. tpu_inference/spec_decode/jax/eagle3.py +13 -0
  235. tpu_inference/tpu_info.py +14 -0
  236. tpu_inference/utils.py +42 -36
  237. tpu_inference/worker/__init__.py +13 -0
  238. tpu_inference/worker/tpu_worker.py +63 -50
  239. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/METADATA +11 -9
  240. tpu_inference-0.13.0rc2.post7.dist-info/RECORD +261 -0
  241. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  242. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  243. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  244. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  245. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  246. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  247. tpu_inference-0.11.1.dev202512030818.dist-info/RECORD +0 -174
  248. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/WHEEL +0 -0
  249. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/licenses/LICENSE +0 -0
  250. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/top_level.txt +0 -0
tpu_inference/runner/tpu_runner.py CHANGED
@@ -1,6 +1,20 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import copy
 import functools
-import os
+import logging
 import random
 from contextlib import nullcontext
 from dataclasses import dataclass
@@ -14,7 +28,6 @@ import vllm.envs as vllm_envs
 from flax import nnx
 from jax.experimental import mesh_utils
 from jax.sharding import NamedSharding, PartitionSpec
-from torchax.ops.mappings import t2j_dtype
 from vllm.config import VllmConfig
 from vllm.distributed import get_pp_group
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
@@ -66,10 +79,12 @@ from tpu_inference.runner.structured_decoding_manager import \
     StructuredDecodingManager
 from tpu_inference.spec_decode.jax.eagle3 import Eagle3Proposer
 from tpu_inference.utils import (device_array, make_optimized_mesh,
-                                 time_function, to_torch_dtype)
+                                 time_function, to_jax_dtype, to_torch_dtype)
 
 logger = init_logger(__name__)
 
+logging.getLogger("torchax.tensor").setLevel(logging.ERROR)
+
 INVALID_TOKEN_ID = -1
 # Smallest output size
 MIN_NUM_SEQS = 8
@@ -493,10 +508,10 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         multimodal_fns = multimodal_fns or {}
         self.precompile_vision_encoder_fn = multimodal_fns.get(
             "precompile_vision_encoder_fn", None)
-        self.get_multimodal_embeddings_fn = multimodal_fns.get(
-            "get_multimodal_embeddings_fn", None)
-        self.get_input_embeddings_fn = multimodal_fns.get(
-            "get_input_embeddings_fn", None)
+        self.embed_multimodal_fn = multimodal_fns.get("embed_multimodal_fn",
+                                                      None)
+        self.embed_input_ids_fn = multimodal_fns.get("embed_input_ids_fn",
+                                                     None)
         self.get_mrope_input_positions_fn = multimodal_fns.get(
             "get_mrope_input_positions_fn", None)
 
@@ -508,7 +523,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             jax.random.key(self.model_config.seed)).params()
         self.is_multimodal_model = (
             self.model_config.is_multimodal_model
-            and self.get_multimodal_embeddings_fn is not None and hasattr(
+            and self.embed_multimodal_fn is not None and hasattr(
                 self.model_config.hf_config, "architectures"
             )  #TODO: Remove Llama Guard 4 specific condition once the LG4 Vision portion is implemented
             and len(self.model_config.hf_config.architectures) >= 1
@@ -524,7 +539,10 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
     def get_kv_cache_spec(self):
         return self.kv_cache_manager.get_kv_cache_spec()
 
-    def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
+    def initialize_kv_cache(self,
+                            kv_cache_config: KVCacheConfig,
+                            topology_order_id: int = 0) -> None:
+        self.topology_order_id = topology_order_id
         self.kv_cache_config = kv_cache_config
         self.use_hybrid_kvcache = len(kv_cache_config.kv_cache_groups) > 1
         self.kv_caches = []
@@ -809,7 +827,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         sharding = None
         if self.dp_size > 1:
             sharding = NamedSharding(self.mesh,
-                                     PartitionSpec(ShardingAxisName.ATTN_DATA))
+                                     PartitionSpec(ShardingAxisName.MLP_DATA))
 
         tpu_sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch(
             self.mesh, self.input_batch, padded_num_reqs, sharding=sharding)
@@ -1336,7 +1354,14 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
         _request_distribution = []
         for dp_rank in range(dp_size):
             _num_reqs = num_req_per_dp_rank[dp_rank]
-            _request_distribution.append([0, 0, _num_reqs])
+            # The batch has been reordered by _reorder_batch so decode requests come first
+            # Count decode requests (those with num_scheduled_tokens == 1) in this DP rank
+            num_decode_in_dp_rank = 0
+            for req_id in req_ids_dp[dp_rank]:
+                if scheduler_output.num_scheduled_tokens[req_id] == 1:
+                    num_decode_in_dp_rank += 1
+            _request_distribution.append(
+                [num_decode_in_dp_rank, num_decode_in_dp_rank, _num_reqs])
         request_distribution = np.array(_request_distribution).ravel()
 
         use_spec_decode = len(
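
The new distribution encodes [num_decode, num_decode, num_reqs] per DP rank instead of the previous [0, 0, num_reqs]. A minimal standalone sketch of the counting logic above, using hypothetical toy inputs rather than package state:

import numpy as np

# A request counts as a decode when exactly one token is scheduled for it
# this step; the batch is already reordered so decode requests come first.
num_scheduled_tokens = {"a": 1, "b": 7, "c": 1, "d": 1}  # req_id -> tokens
req_ids_dp = [["a", "b"], ["c", "d"]]                    # req_ids per DP rank

distribution = []
for req_ids in req_ids_dp:
    num_decode = sum(1 for r in req_ids if num_scheduled_tokens[r] == 1)
    distribution.append([num_decode, num_decode, len(req_ids)])

print(np.array(distribution).ravel())  # [1 1 2 2 2 2]
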
@@ -1365,7 +1390,8 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             self.mesh,
             self.input_batch,
             padded_num_reqs,
-            sharding=data_parallel_attn_sharding,
+            sharding=NamedSharding(self.mesh,
+                                   PartitionSpec(ShardingAxisName.MLP_DATA)),
         )
         if self.uses_mrope:
             positions = mrope_positions
@@ -1395,7 +1421,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             block_tables[
                 req_offset:req_offset + _num_reqs, :self.
                 max_num_blocks_per_req] = self.input_batch.block_table[
-                    0].get_cpu_tensor()[req_indices_dp[dp_rank]]
+                    kv_cache_gid].get_cpu_tensor()[req_indices_dp[dp_rank]]
             # Convert block_tables to 1D on cpu.
             block_tables = block_tables.reshape(-1)
             block_tables = device_array(
@@ -1655,7 +1681,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
     def _get_input_ids_embeds(self, input_ids: jax.Array,
                               mm_embeds: list[jax.Array]):
         if self.is_multimodal_model:
-            inputs_embeds = self.get_input_embeddings_fn(
+            inputs_embeds = self.embed_input_ids_fn(
                 self.state,
                 input_ids,
                 mm_embeds,
@@ -1712,8 +1738,7 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                                          shard=shard)
 
     def get_intermediate_tensor_spec(self, num_tokens: int):
-        impl = os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower()
-        jax_dtype = t2j_dtype(self.dtype) if impl == "vllm" else self.dtype
+        jax_dtype = to_jax_dtype(self.dtype)
         num_padded_tokens = runner_utils.get_padded_token_len(
             self.num_tokens_paddings, num_tokens)
         sharding = NamedSharding(self.mesh, PartitionSpec())
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
tpu_inference/spec_decode/jax/eagle3.py CHANGED
@@ -1,3 +1,16 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 """Implements the Eagle3 proposer for speculative decoding on JAX/TPU."""
 import functools
 from dataclasses import replace
tpu_inference/tpu_info.py CHANGED
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import glob
 import os
 
tpu_inference/utils.py CHANGED
@@ -3,7 +3,7 @@ import time
 from collections import defaultdict
 from collections.abc import Sequence
 from functools import wraps
-from typing import Any, Callable, List, Tuple
+from typing import Any, Callable, List, Tuple, Union
 
 import jax
 import jax.numpy as jnp
@@ -28,9 +28,9 @@ TPU_SECOND_LAST_MINOR = 8
 
 # Map vllm dtype string that doesn't exactly match jax dtype string name.
 _VLLM_DTYPE_STR_TO_JAX_DTYPE = {
-    "fp8": jnp.float8_e4m3fn,
-    "fp8_e4m3": jnp.float8_e4m3fn,
-    "fp8_e5m2": jnp.float8_e5m2,
+    "fp8": jnp.float8_e4m3fn.dtype,
+    "fp8_e4m3": jnp.float8_e4m3fn.dtype,
+    "fp8_e5m2": jnp.float8_e5m2.dtype,
 }
 
 
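Side note on the mapping change above: jnp.float8_e4m3fn is a scalar type (a class), while its .dtype attribute is the np.dtype instance that dtype lookups and comparisons usually expect. A quick illustration, not package code:

import jax.numpy as jnp
import numpy as np

print(jnp.float8_e4m3fn)                              # the ml_dtypes scalar class
print(jnp.float8_e4m3fn.dtype)                        # the actual dtype object
assert isinstance(jnp.float8_e4m3fn.dtype, np.dtype)  # mapped values are real dtypes
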
@@ -60,6 +60,10 @@ _megacore = False
 logger = init_logger(__name__)
 
 
+def align_to(unpadded_dim, pad_multiple):
+    return (unpadded_dim + pad_multiple - 1) // pad_multiple * pad_multiple
+
+
 def enable_megacore() -> None:
     global _megacore
     _megacore = True
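
The align_to helper added above rounds a dimension up to the next multiple of a padding granularity via ceiling division. A usage sketch mirroring its arithmetic:

# Not package code; same formula as the helper above.
def align_to(unpadded_dim, pad_multiple):
    return (unpadded_dim + pad_multiple - 1) // pad_multiple * pad_multiple

assert align_to(7, 8) == 8     # pad up to one full tile of 8
assert align_to(16, 8) == 16   # already aligned, unchanged
assert align_to(17, 8) == 24   # next multiple of 8
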
@@ -186,7 +190,8 @@ def get_padded_num_heads(num_heads: int, sharding_size: int) -> int:
 
 
 def get_dtype_packing(dtype):
-    bits = dtypes.bit_width(dtype)
+    bits = (dtypes.bit_width(dtype)
+            if hasattr(dtypes, "bit_width") else dtypes.itemsize_bits(dtype))
     return 32 // bits
 
 
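get_dtype_packing reports how many values of a dtype fit in one 32-bit word; the change adds a fallback for JAX versions where dtypes.bit_width is unavailable. The arithmetic itself, as a standalone sketch:

# Not package code: 32 // bits_per_element values pack into a 32-bit word.
for name, bits in [("int4", 4), ("int8", 8), ("float8_e4m3fn", 8), ("bfloat16", 16)]:
    print(name, "->", 32 // bits)  # int4 -> 8, int8 -> 4, float8 -> 4, bfloat16 -> 2
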
@@ -271,40 +276,11 @@ def device_array(mesh: Mesh, *args, sharding=None, **kwargs) -> jax.Array:
 
 def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]:
     """
-    A wrapper function of vllm.utils.get_hash_fn_by_name to support builtin
+    A wrapper function of vllm.utils.hashing.get_hash_fn_by_name to support builtin
     """
     if hash_fn_name == "builtin":
         return hash
-    return utils.get_hash_fn_by_name(hash_fn_name)
-
-
-def quantize_kv(key: jax.Array, value: jax.Array,
-                kv_cache_quantized_dtype: jnp.dtype, k_scale: float,
-                v_scale: float) -> Tuple[jax.Array, jax.Array]:
-    """
-    Quantize the key and value tensors.
-
-    Args:
-        key: The key tensor to quantize.
-        value: The value tensor to quantize.
-        kv_cache_quantized_dtype: The dtype to quantize the key and value tensors to.
-        q_scale: The scale to quantize the key and value tensors by.
-        k_scale: The scale to quantize the key tensor by.
-        v_scale: The scale to quantize the value tensor by.
-
-    Returns:
-        Tuple[jax.Array, jax.Array]: The quantized key and value tensors.
-    """
-    dtype_info = jnp.finfo(kv_cache_quantized_dtype)
-    minval, maxval = float(dtype_info.min), float(dtype_info.max)
-    key = key.astype(jnp.float32) / k_scale
-    key = jnp.clip(key, minval, maxval)
-    key = key.astype(kv_cache_quantized_dtype)
-    value = value.astype(jnp.float32) / v_scale
-    value = jnp.clip(value, minval, maxval)
-    value = value.astype(kv_cache_quantized_dtype)
-
-    return key, value
+    return utils.hashing.get_hash_fn_by_name(hash_fn_name)
 
 
 def get_jax_dtype_from_str_dtype(str_dtype: str) -> jnp.dtype:
@@ -321,6 +297,36 @@ def get_jax_dtype_from_str_dtype(str_dtype: str) -> jnp.dtype:
     return to_jax_dtype(str_dtype)
 
 
+def get_mesh_shape_product(
+    mesh: Mesh,
+    axes: Union[str, list[str], None],
+) -> int:
+    """
+    Get the product of mesh dimensions for one or more axes.
+
+    Examples:
+        # Single axis (defaults to 1 if not present)
+        get_mesh_shape_product(mesh, "model")
+
+        # Multiple axes - computes product of their sizes
+        get_mesh_shape_product(mesh, ["model", "attn_dp"])
+
+        # None means no sharding on this dimension
+        get_mesh_shape_product(mesh, None)  # returns 1
+    """
+    if axes is None:
+        return 1
+
+    if isinstance(axes, str):
+        axes = [axes]
+
+    product = 1
+    for axis in axes:
+        product *= mesh.shape.get(axis, 1)
+
+    return product
+
+
 def time_function(func):
     """
     A decorator to measure the execution time of a function.
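
The docstring examples above can be made concrete: mesh.shape behaves like a mapping from axis name to size, and missing axes default to 1. A standalone sketch using a stand-in dict instead of a real jax Mesh:

from collections import OrderedDict

mesh_shape = OrderedDict(data=2, model=4)  # stand-in for Mesh.shape

def mesh_shape_product(shape, axes):
    # Same logic as the new helper, operating on a plain mapping.
    if axes is None:
        return 1
    if isinstance(axes, str):
        axes = [axes]
    product = 1
    for axis in axes:
        product *= shape.get(axis, 1)
    return product

assert mesh_shape_product(mesh_shape, "model") == 4
assert mesh_shape_product(mesh_shape, ["model", "attn_dp"]) == 4  # attn_dp absent -> 1
assert mesh_shape_product(mesh_shape, None) == 1
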
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
tpu_inference/worker/tpu_worker.py CHANGED
@@ -6,7 +6,6 @@ from dataclasses import dataclass, field
 from typing import Callable, Dict, Optional, Tuple
 
 import jax
-import jax.numpy as jnp
 import jaxlib
 import jaxtyping
 import vllm.envs as vllm_envs
@@ -19,30 +18,25 @@ from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
 from vllm.lora.request import LoRARequest
 from vllm.tasks import SupportedTask
 from vllm.v1 import utils as vllm_utils
-from vllm.v1.core.kv_cache_utils import get_num_blocks, get_uniform_page_size
+from vllm.v1.core.kv_cache_utils import (get_kv_cache_groups, get_num_blocks,
+                                         get_uniform_page_size)
 from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
 
 from tpu_inference import envs, utils
 from tpu_inference.distributed import jax_parallel_state
-from tpu_inference.distributed.utils import (get_host_ip, get_kv_transfer_port,
-                                             get_node_id)
+from tpu_inference.distributed.utils import (get_device_topology_order_id,
+                                             get_host_ip, get_kv_transfer_port)
 from tpu_inference.layers.common.sharding import ShardingConfigManager
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.jax_intermediate_tensor import \
     JaxIntermediateTensors
-from tpu_inference.runner.kv_cache import get_rpa_page_size_bytes
+from tpu_inference.runner.kv_cache import get_attention_page_size_bytes
 from tpu_inference.runner.tpu_runner import TPUModelRunner
 
 logger = init_logger(__name__)
 
-_DTYPE: dict[str, jnp.dtype] = {
-    "bfloat16": jnp.bfloat16,
-    "float": jnp.float32,
-    "float32": jnp.float32,
-}
-
 
 @dataclass
 class PPConfig:
@@ -77,21 +71,6 @@ class TPUWorker:
         ip: str = "localhost",
         prev_worker_ip: str = "localhost",
     ):
-        # If we use vLLM's model implementation in PyTorch, we should set it
-        # with torch version of the dtype.
-        impl = envs.MODEL_IMPL_TYPE
-        if impl != "vllm":  # vllm-pytorch implementation does not need this conversion
-
-            # NOTE(wenlong): because sometimes mm needs to use torch for preprocessing
-            if not isinstance(vllm_config.model_config.dtype, str):
-                logger.warning(
-                    "The model dtype is not properly set for JAX backend. "
-                    "Overwriting it to jnp.bfloat16")
-                vllm_config.model_config.dtype = jnp.bfloat16
-            else:
-                vllm_config.model_config.dtype = _DTYPE.get(
-                    vllm_config.model_config.dtype, jnp.bfloat16)
-
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
         self.parallel_config = vllm_config.parallel_config
@@ -250,14 +229,33 @@ class TPUWorker:
             need_pp=self.parallel_config.pipeline_parallel_size > 1)
 
         ensure_kv_transfer_initialized(self.vllm_config)
-        self.model_runner = TPUModelRunner(
-            self.vllm_config, self.devices, self.rank, self.rank == 0,
-            self.rank == self.pp_config.pp_world_size - 1)
+
+        is_first_rank = True
+        is_last_rank = True
+        self.topology_order_id = self.rank
+        if self.parallel_config.pipeline_parallel_size > 1:
+            is_first_rank = self.rank == 0
+            is_last_rank = self.rank == self.pp_config.pp_world_size - 1
+        else:
+            # topology_order_id is used to determine the KV cache
+            # mapping between P/D workers
+            if multihost_backend == "ray":
+                self.topology_order_id = get_device_topology_order_id(
+                    jax.local_devices(), jax.devices())
+
+        self.model_runner = TPUModelRunner(self.vllm_config, self.devices,
+                                           self.rank, is_first_rank,
+                                           is_last_rank)
         logger.info(f"Init worker | "
                     f"rank={self.rank} | "
-                    f"node_id={get_node_id()} | "
+                    f"is_first_rank={is_first_rank} | "
+                    f"is_last_rank={is_last_rank} | "
+                    f"topology_order_id={self.topology_order_id} | "
                     f"is_driver_worker={self.is_driver_worker} | "
-                    f"hbm={utils.hbm_usage_gb(self.devices)}GiB")
+                    f"hbm={utils.hbm_usage_gb(self.devices)}GiB |"
+                    f"self.devices={self.devices} | "
+                    f"total devices={jax.devices()} | "
+                    f"local_devices={jax.local_devices()}")
        vllm_utils.report_usage_stats(self.vllm_config)
 
     def initialize_pp_transfer_connect(self):
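
The diff only shows the call site of get_device_topology_order_id; its derivation lives in tpu_inference/distributed/utils.py (changed above). As an explicitly hypothetical sketch, not the package's implementation, one stable ordering is the rank of a worker's first local device within the global device list:

# Assumption sketch only; the real helper may order devices differently.
def topology_order_id_sketch(local_devices, all_devices):
    first_local = min(d.id for d in local_devices)    # this host's lowest device id
    global_order = sorted(d.id for d in all_devices)  # globally ordered device ids
    return global_order.index(first_local)
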
@@ -395,46 +393,56 @@ class TPUWorker:
         # responsible for this translation. When vLLM can be modified, this
         # method should be changed to return `dict[str, AbstractKVCacheSpec]`,
         # and the vLLM side should be updated to handle the translation.
-        kv_cache_specs = self.model_runner.get_kv_cache_spec()
+        kv_cache_spec = self.model_runner.get_kv_cache_spec()
 
-        if len(kv_cache_specs) == 0:
-            return kv_cache_specs
+        if len(kv_cache_spec) == 0:
+            return kv_cache_spec
 
         # TODO(kyuyeunk): Instead of checking page_size_bytes here, introduce
         # feature that allows overriding page_size_bytes of KVCacheSpec.
         vllm_page_size_bytes = get_uniform_page_size(
-            list(kv_cache_specs.values()))
-        rpa_page_size_bytes = get_rpa_page_size_bytes(self.model_runner.mesh,
-                                                      kv_cache_specs)
+            list(kv_cache_spec.values()))
+        attention_page_size_bytes = get_attention_page_size_bytes(
+            self.model_runner.mesh, kv_cache_spec)
 
-        if vllm_page_size_bytes != rpa_page_size_bytes:
+        if vllm_page_size_bytes != attention_page_size_bytes:
             logger.info(
-                f"KV cache page size calculated by vLLM "
-                f"({vllm_page_size_bytes} Bytes) does not match with actual "
-                f"page size used by RPA kernel ({rpa_page_size_bytes} Bytes). "
-                f"Recalculating number of KV blocks using actual page size.")
-
+                f"Page size calculated by vLLM ({vllm_page_size_bytes} Bytes) "
+                f"does not match with actual page size used by the kernel "
+                f"({attention_page_size_bytes} Bytes). Recalculating number of "
+                f"KV blocks using actual page size.")
+
+            kv_cache_groups = get_kv_cache_groups(self.vllm_config,
+                                                  kv_cache_spec)
+            group_size = max(
+                len(group.layer_names) for group in kv_cache_groups)
             available_memory = self.determine_available_memory()
-            num_blocks = get_num_blocks(self.vllm_config, len(kv_cache_specs),
-                                        available_memory, rpa_page_size_bytes)
-
+            num_blocks = get_num_blocks(self.vllm_config, group_size,
+                                        available_memory,
+                                        attention_page_size_bytes)
             cache_config = self.vllm_config.cache_config
             cache_config.num_gpu_blocks_override = num_blocks
 
-        return kv_cache_specs
+        return kv_cache_spec
 
     def initialize_from_config(
         self,
         kv_cache_config: KVCacheConfig,
     ) -> None:
         """Allocate GPU KV cache with the specified kv_cache_config."""
-        self.model_runner.initialize_kv_cache(kv_cache_config)
+        # Precompile functions with large vocab_size tensors before allocating KV cache to avoid OOM
+        if not (envs.SKIP_JAX_PRECOMPILE or
+                (hasattr(self.model_runner.model_config, "enforce_eager")
+                 and self.model_runner.model_config.enforce_eager)):
+            self.model_runner.compilation_manager._precompile_sampling()
+            self.model_runner.compilation_manager._precompile_gather_logprobs()
+        self.model_runner.initialize_kv_cache(kv_cache_config,
+                                              self.topology_order_id)
 
     def get_node_kv_ip_port(self) -> tuple[int, str, int]:
-        node_id = get_node_id()
         ip = get_host_ip()
         port = get_kv_transfer_port()
-        return (int(node_id), ip, int(port))
+        return (int(self.topology_order_id), ip, int(port))
 
     def check_health(self) -> None:
         # worker will always be healthy as long as it's running.
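
Rough arithmetic behind the recalculation above, with hypothetical numbers: the block count is recomputed from the page size the kernel actually uses, roughly available_memory // (page_size * group_size).

# Hypothetical numbers, not from the package.
available_memory = 8 * 1024**3           # 8 GiB left for the KV cache
attention_page_size_bytes = 4 * 1024**2  # page size actually used by the kernel
group_size = 1                           # one KV cache group

num_blocks = available_memory // (attention_page_size_bytes * group_size)
print(num_blocks)  # 2048 blocks, versus 4096 if a 2 MiB page had been assumed
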
@@ -456,3 +464,8 @@ class TPUWorker:
 
     def shutdown(self) -> None:
         return
+
+    # Ray executor do not need handshake metadata
+    # as we pass the kv_parameters through proxy server
+    def get_kv_connector_handshake_metadata(self) -> None:
+        pass
{tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: tpu_inference
-Version: 0.11.1.dev202512030818
+Version: 0.13.0rc2.post7
 Author: tpu_inference Contributors
 Classifier: Development Status :: 3 - Alpha
 Classifier: Intended Audience :: Developers
@@ -14,7 +14,7 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.10
 Description-Content-Type: text/markdown
 License-File: LICENSE
-Requires-Dist: tpu-info==0.4.0
+Requires-Dist: tpu-info==0.7.1
 Requires-Dist: yapf==0.43.0
 Requires-Dist: pytest
 Requires-Dist: pytest-mock
@@ -25,13 +25,17 @@ Requires-Dist: jax[tpu]==0.8.0
 Requires-Dist: jaxlib==0.8.0
 Requires-Dist: jaxtyping
 Requires-Dist: flax==0.11.1
-Requires-Dist: torchax==0.0.7
+Requires-Dist: torchax==0.0.10
 Requires-Dist: qwix==0.1.1
 Requires-Dist: torchvision==0.24.0
 Requires-Dist: pathwaysutils
 Requires-Dist: parameterized
 Requires-Dist: numba==0.62.1
 Requires-Dist: runai-model-streamer[gcs,s3]==0.15.0
+Requires-Dist: jax==0.8.1
+Requires-Dist: jaxlib==0.8.1
+Requires-Dist: jaxtyping==0.3.2
+Requires-Dist: libtpu==0.0.31
 Dynamic: author
 Dynamic: classifier
 Dynamic: description
@@ -53,14 +57,12 @@ Dynamic: requires-python
 
 ---
 
-_Upcoming Events_ 🔥
-
-- Join us at the [PyTorch Conference, October 22-23](https://events.linuxfoundation.org/pytorch-conference/) in San Francisco!
-- Join us at [Ray Summit, November 3-5](https://www.anyscale.com/ray-summit/2025) in San Francisco!
-- Join us at [JAX DevLab on November 18th](https://rsvp.withgoogle.com/events/devlab-fall-2025) in Sunnyvale!
-
 _Latest News_ 🔥
 
+- [Pytorch Conference](https://pytorchconference.sched.com/event/27QCh/sponsored-session-everything-everywhere-all-at-once-vllm-hardware-optionality-with-spotify-and-google-brittany-rockwell-google-shireen-kheradpey-spotify) Learn how Spotify uses vLLM with both GPUs and TPUs to drive down costs and improve user experience.
+- Check back soon for a recording of our session at [Ray Summit, November 3-5](https://www.anyscale.com/ray-summit/2025) in San Francisco!
+- Check back soon for a recording of our session at [JAX DevLab on November 18th](https://rsvp.withgoogle.com/events/devlab-fall-2025) in Sunnyvale!
+
 - [2025/10] [vLLM TPU: A New Unified Backend Supporting PyTorch and JAX on TPU](https://blog.vllm.ai/2025/10/16/vllm-tpu.html)
 
 <details>