tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (257)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +317 -34
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +406 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +320 -0
  64. tests/layers/vllm/test_unquantized.py +662 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +26 -6
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +110 -12
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +2 -45
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +15 -10
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +92 -8
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +25 -4
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +807 -230
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +218 -137
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +25 -12
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/fused_moe_gmm.py +506 -0
  154. tpu_inference/layers/common/quant_methods.py +15 -0
  155. tpu_inference/layers/common/quantization.py +282 -0
  156. tpu_inference/layers/common/sharding.py +32 -9
  157. tpu_inference/layers/common/utils.py +94 -0
  158. tpu_inference/layers/jax/__init__.py +13 -0
  159. tpu_inference/layers/jax/attention/__init__.py +13 -0
  160. tpu_inference/layers/jax/attention/attention.py +19 -6
  161. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  162. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  163. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  164. tpu_inference/layers/jax/base.py +14 -0
  165. tpu_inference/layers/jax/constants.py +13 -0
  166. tpu_inference/layers/jax/layers.py +14 -0
  167. tpu_inference/layers/jax/misc.py +14 -0
  168. tpu_inference/layers/jax/moe/__init__.py +13 -0
  169. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  170. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  171. tpu_inference/layers/jax/moe/moe.py +43 -3
  172. tpu_inference/layers/jax/pp_utils.py +53 -0
  173. tpu_inference/layers/jax/rope.py +14 -0
  174. tpu_inference/layers/jax/rope_interface.py +14 -0
  175. tpu_inference/layers/jax/sample/__init__.py +13 -0
  176. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  177. tpu_inference/layers/jax/sample/sampling.py +15 -1
  178. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  179. tpu_inference/layers/jax/transformer_block.py +14 -0
  180. tpu_inference/layers/vllm/__init__.py +13 -0
  181. tpu_inference/layers/vllm/attention.py +4 -4
  182. tpu_inference/layers/vllm/fused_moe.py +101 -494
  183. tpu_inference/layers/vllm/linear.py +64 -0
  184. tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
  185. tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
  186. tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
  187. tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
  188. tpu_inference/layers/vllm/quantization/__init__.py +19 -3
  189. tpu_inference/layers/vllm/quantization/awq.py +96 -82
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  191. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +23 -8
  192. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +172 -176
  193. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  194. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
  195. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
  196. tpu_inference/layers/vllm/quantization/{common.py → configs.py} +42 -25
  197. tpu_inference/layers/vllm/quantization/fp8.py +119 -0
  198. tpu_inference/layers/vllm/quantization/mxfp4.py +137 -178
  199. tpu_inference/layers/vllm/quantization/unquantized.py +157 -233
  200. tpu_inference/lora/__init__.py +13 -0
  201. tpu_inference/lora/torch_lora_ops.py +8 -13
  202. tpu_inference/models/__init__.py +13 -0
  203. tpu_inference/models/common/__init__.py +13 -0
  204. tpu_inference/models/common/model_loader.py +112 -35
  205. tpu_inference/models/jax/__init__.py +13 -0
  206. tpu_inference/models/jax/deepseek_v3.py +267 -157
  207. tpu_inference/models/jax/gpt_oss.py +26 -10
  208. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  209. tpu_inference/models/jax/llama3.py +99 -36
  210. tpu_inference/models/jax/llama4.py +14 -0
  211. tpu_inference/models/jax/llama_eagle3.py +18 -5
  212. tpu_inference/models/jax/llama_guard_4.py +15 -1
  213. tpu_inference/models/jax/qwen2.py +17 -2
  214. tpu_inference/models/jax/qwen2_5_vl.py +179 -51
  215. tpu_inference/models/jax/qwen3.py +17 -2
  216. tpu_inference/models/jax/utils/__init__.py +13 -0
  217. tpu_inference/models/jax/utils/file_utils.py +14 -0
  218. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  219. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  220. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +92 -32
  221. tpu_inference/models/jax/utils/weight_utils.py +234 -155
  222. tpu_inference/models/vllm/__init__.py +13 -0
  223. tpu_inference/models/vllm/vllm_model_wrapper.py +32 -8
  224. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  225. tpu_inference/platforms/__init__.py +14 -0
  226. tpu_inference/platforms/tpu_platform.py +51 -72
  227. tpu_inference/runner/__init__.py +13 -0
  228. tpu_inference/runner/compilation_manager.py +180 -80
  229. tpu_inference/runner/kv_cache.py +54 -20
  230. tpu_inference/runner/kv_cache_manager.py +55 -33
  231. tpu_inference/runner/lora_utils.py +16 -1
  232. tpu_inference/runner/multimodal_manager.py +16 -2
  233. tpu_inference/runner/persistent_batch_manager.py +54 -2
  234. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  235. tpu_inference/runner/structured_decoding_manager.py +16 -3
  236. tpu_inference/runner/tpu_runner.py +124 -61
  237. tpu_inference/runner/utils.py +2 -2
  238. tpu_inference/spec_decode/__init__.py +13 -0
  239. tpu_inference/spec_decode/jax/__init__.py +13 -0
  240. tpu_inference/spec_decode/jax/eagle3.py +84 -22
  241. tpu_inference/tpu_info.py +14 -0
  242. tpu_inference/utils.py +72 -44
  243. tpu_inference/worker/__init__.py +13 -0
  244. tpu_inference/worker/tpu_worker.py +66 -52
  245. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +8 -9
  246. tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
  247. tpu_inference/layers/vllm/linear_common.py +0 -186
  248. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  249. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  250. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  251. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  252. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  253. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  254. tpu_inference-0.11.1.dev202511220812.dist-info/RECORD +0 -174
  255. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
  256. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
  257. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
tpu_inference/utils.py CHANGED
@@ -3,16 +3,19 @@ import time
  from collections import defaultdict
  from collections.abc import Sequence
  from functools import wraps
- from typing import Any, Callable, List, Tuple
+ from typing import Any, Callable, List, Tuple, Union

  import jax
  import jax.numpy as jnp
  import numpy as np
+ import torch
  from jax._src import dtypes
  from jax._src import mesh as mesh_lib
  from jax._src import xla_bridge as xb
  from jax._src.lib import xla_client as xc
+ from jax._src.numpy.scalar_types import _ScalarMeta
  from jax.sharding import Mesh, NamedSharding, PartitionSpec
+ from torchax.ops.mappings import j2t_dtype, t2j_dtype
  from vllm import envs as vllm_envs
  from vllm import utils

@@ -23,21 +26,44 @@ GBYTES = 1024 * 1024 * 1024
  TPU_HEAD_SIZE_ALIGNMENT = 128
  TPU_SECOND_LAST_MINOR = 8

- # This is used to translate from a string name for a dtype
- # to formal jax.numpy DType. One use case for this is
- # converting the `--kv_cache_dtype` flag to a dtype.
- TPU_STR_DTYPE_TO_JAX_DTYPE = {
-     "bfloat16": jnp.bfloat16,
-     "fp8": jnp.float8_e4m3fn,
-     "fp8_e4m3": jnp.float8_e4m3,
-     "fp8_e5m2": jnp.float8_e5m2,
-     "int8": jnp.int8,
+ # Map vllm dtype string that doesn't exactly match jax dtype string name.
+ _VLLM_DTYPE_STR_TO_JAX_DTYPE = {
+     "fp8": jnp.float8_e4m3fn.dtype,
+     "fp8_e4m3": jnp.float8_e4m3fn.dtype,
+     "fp8_e5m2": jnp.float8_e5m2.dtype,
  }

+
+ def to_jax_dtype(dtype: str | jnp.dtype | torch.dtype) -> jnp.dtype:
+     if isinstance(dtype, str):
+         if dict_dtype := _VLLM_DTYPE_STR_TO_JAX_DTYPE.get(dtype, None):
+             return dict_dtype
+         return jnp.dtype(dtype)
+     elif isinstance(dtype, torch.dtype):
+         return t2j_dtype(dtype)
+     elif isinstance(dtype, jnp.dtype):
+         return dtype
+     elif isinstance(dtype, _ScalarMeta):
+         return dtype.dtype
+     else:
+         raise ValueError(f"Argument is unsupported data type {type(dtype)}")
+
+
+ def to_torch_dtype(dtype: str | jnp.dtype | torch.dtype) -> torch.dtype:
+     # Use jax dtype as an intermediate dtype which we'll be used to convert it
+     # into torch dtype.
+     dtype = to_jax_dtype(dtype)
+     return j2t_dtype(dtype)
+
+
  _megacore = False
  logger = init_logger(__name__)


+ def align_to(unpadded_dim, pad_multiple):
+     return (unpadded_dim + pad_multiple - 1) // pad_multiple * pad_multiple
+
+
  def enable_megacore() -> None:
      global _megacore
      _megacore = True
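A minimal usage sketch of the helpers added in this hunk (`to_jax_dtype`, `to_torch_dtype`, `align_to`), based only on the code shown above; the expected values in the comments follow from the alias table and the round-up arithmetic, and the torch mapping is assumed to go through torchax as the imports indicate.

```python
# Illustrative sketch, not part of the diff.
import jax.numpy as jnp
import torch

from tpu_inference.utils import align_to, to_jax_dtype, to_torch_dtype

# Strings hit the vLLM alias table first ("fp8" -> float8_e4m3fn),
# then fall back to jnp.dtype() for exact names such as "bfloat16".
print(to_jax_dtype("fp8"))           # float8_e4m3fn
print(to_jax_dtype("bfloat16"))      # bfloat16
print(to_jax_dtype(torch.bfloat16))  # bfloat16, via torchax t2j_dtype
print(to_torch_dtype("bfloat16"))    # torch.bfloat16, via j2t_dtype

# align_to rounds a dimension up to the next multiple, e.g. padding a head
# dim to TPU_HEAD_SIZE_ALIGNMENT (128).
print(align_to(100, 128))  # 128
print(align_to(256, 128))  # 256
```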
@@ -164,7 +190,8 @@ def get_padded_num_heads(num_heads: int, sharding_size: int) -> int:


  def get_dtype_packing(dtype):
-     bits = dtypes.bit_width(dtype)
+     bits = (dtypes.bit_width(dtype)
+             if hasattr(dtypes, "bit_width") else dtypes.itemsize_bits(dtype))
      return 32 // bits


@@ -249,40 +276,11 @@ def device_array(mesh: Mesh, *args, sharding=None, **kwargs) -> jax.Array:

  def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]:
      """
-     A wrapper function of vllm.utils.get_hash_fn_by_name to support builtin
+     A wrapper function of vllm.utils.hashing.get_hash_fn_by_name to support builtin
      """
      if hash_fn_name == "builtin":
          return hash
-     return utils.get_hash_fn_by_name(hash_fn_name)
-
-
- def quantize_kv(key: jax.Array, value: jax.Array,
-                 kv_cache_quantized_dtype: jnp.dtype, k_scale: float,
-                 v_scale: float) -> Tuple[jax.Array, jax.Array]:
-     """
-     Quantize the key and value tensors.
-
-     Args:
-         key: The key tensor to quantize.
-         value: The value tensor to quantize.
-         kv_cache_quantized_dtype: The dtype to quantize the key and value tensors to.
-         q_scale: The scale to quantize the key and value tensors by.
-         k_scale: The scale to quantize the key tensor by.
-         v_scale: The scale to quantize the value tensor by.
-
-     Returns:
-         Tuple[jax.Array, jax.Array]: The quantized key and value tensors.
-     """
-     dtype_info = jnp.finfo(kv_cache_quantized_dtype)
-     minval, maxval = float(dtype_info.min), float(dtype_info.max)
-     key = key.astype(jnp.float32) / k_scale
-     key = jnp.clip(key, minval, maxval)
-     key = key.astype(kv_cache_quantized_dtype)
-     value = value.astype(jnp.float32) / v_scale
-     value = jnp.clip(value, minval, maxval)
-     value = value.astype(kv_cache_quantized_dtype)
-
-     return key, value
+     return utils.hashing.get_hash_fn_by_name(hash_fn_name)


  def get_jax_dtype_from_str_dtype(str_dtype: str) -> jnp.dtype:
@@ -295,8 +293,38 @@ def get_jax_dtype_from_str_dtype(str_dtype: str) -> jnp.dtype:
      Returns:
          jnp.dtype: The JAX dtype.
      """
-     str_dtype = str_dtype.lower().strip()
-     return TPU_STR_DTYPE_TO_JAX_DTYPE.get(str_dtype)
+     # TODO(kyuyeunk): Replace all reference of this function into TpuDtype.
+     return to_jax_dtype(str_dtype)
+
+
+ def get_mesh_shape_product(
+     mesh: Mesh,
+     axes: Union[str, list[str], None],
+ ) -> int:
+     """
+     Get the product of mesh dimensions for one or more axes.
+
+     Examples:
+         # Single axis (defaults to 1 if not present)
+         get_mesh_shape_product(mesh, "model")
+
+         # Multiple axes - computes product of their sizes
+         get_mesh_shape_product(mesh, ["model", "attn_dp"])
+
+         # None means no sharding on this dimension
+         get_mesh_shape_product(mesh, None)  # returns 1
+     """
+     if axes is None:
+         return 1
+
+     if isinstance(axes, str):
+         axes = [axes]
+
+     product = 1
+     for axis in axes:
+         product *= mesh.shape.get(axis, 1)
+
+     return product


  def time_function(func):
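A quick usage sketch for the `get_mesh_shape_product` helper introduced above; the 2x4 device mesh and the axis names are illustrative assumptions (it needs 8 visible devices), not part of the diff.

```python
# Illustrative sketch, not part of the diff.
import jax
import numpy as np
from jax.sharding import Mesh

from tpu_inference.utils import get_mesh_shape_product

devices = np.asarray(jax.devices()).reshape(2, 4)  # assumes 8 devices
mesh = Mesh(devices, axis_names=("data", "model"))

print(get_mesh_shape_product(mesh, "model"))            # 4
print(get_mesh_shape_product(mesh, ["data", "model"]))  # 8
print(get_mesh_shape_product(mesh, "attn_dp"))          # 1 (axis not in mesh)
print(get_mesh_shape_product(mesh, None))               # 1 (no sharding)
```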
@@ -0,0 +1,13 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
tpu_inference/worker/tpu_worker.py CHANGED
@@ -6,7 +6,6 @@ from dataclasses import dataclass, field
  from typing import Callable, Dict, Optional, Tuple

  import jax
- import jax.numpy as jnp
  import jaxlib
  import jaxtyping
  import vllm.envs as vllm_envs
@@ -19,30 +18,25 @@ from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
  from vllm.lora.request import LoRARequest
  from vllm.tasks import SupportedTask
  from vllm.v1 import utils as vllm_utils
- from vllm.v1.core.kv_cache_utils import get_num_blocks, get_uniform_page_size
+ from vllm.v1.core.kv_cache_utils import (get_kv_cache_groups, get_num_blocks,
+                                          get_uniform_page_size)
  from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
  from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
  from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput

  from tpu_inference import envs, utils
  from tpu_inference.distributed import jax_parallel_state
- from tpu_inference.distributed.utils import (get_host_ip, get_kv_transfer_port,
-                                              get_node_id)
+ from tpu_inference.distributed.utils import (get_device_topology_order_id,
+                                              get_host_ip, get_kv_transfer_port)
  from tpu_inference.layers.common.sharding import ShardingConfigManager
  from tpu_inference.logger import init_logger
  from tpu_inference.models.jax.jax_intermediate_tensor import \
      JaxIntermediateTensors
- from tpu_inference.runner.kv_cache import get_rpa_page_size_bytes
+ from tpu_inference.runner.kv_cache import get_attention_page_size_bytes
  from tpu_inference.runner.tpu_runner import TPUModelRunner

  logger = init_logger(__name__)

- _DTYPE: dict[str, jnp.dtype] = {
-     "bfloat16": jnp.bfloat16,
-     "float": jnp.float32,
-     "float32": jnp.float32,
- }
-

  @dataclass
  class PPConfig:
@@ -77,21 +71,6 @@ class TPUWorker:
          ip: str = "localhost",
          prev_worker_ip: str = "localhost",
      ):
-         # If we use vLLM's model implementation in PyTorch, we should set it
-         # with torch version of the dtype.
-         impl = envs.MODEL_IMPL_TYPE
-         if impl != "vllm":  # vllm-pytorch implementation does not need this conversion
-
-             # NOTE(wenlong): because sometimes mm needs to use torch for preprocessing
-             if not isinstance(vllm_config.model_config.dtype, str):
-                 logger.warning(
-                     "The model dtype is not properly set for JAX backend. "
-                     "Overwriting it to jnp.bfloat16")
-                 vllm_config.model_config.dtype = jnp.bfloat16
-             else:
-                 vllm_config.model_config.dtype = _DTYPE.get(
-                     vllm_config.model_config.dtype, jnp.bfloat16)
-
          self.vllm_config = vllm_config
          self.model_config = vllm_config.model_config
          self.parallel_config = vllm_config.parallel_config
@@ -108,7 +87,7 @@ class TPUWorker:

          if self.model_config.trust_remote_code:
              # note: lazy import to avoid importing torch before initializing
-             from vllm.utils import init_cached_hf_modules
+             from vllm.utils.import_utils import init_cached_hf_modules

              init_cached_hf_modules()

@@ -250,14 +229,33 @@ class TPUWorker:
              need_pp=self.parallel_config.pipeline_parallel_size > 1)

          ensure_kv_transfer_initialized(self.vllm_config)
-         self.model_runner = TPUModelRunner(
-             self.vllm_config, self.devices, self.rank, self.rank == 0,
-             self.rank == self.pp_config.pp_world_size - 1)
+
+         is_first_rank = True
+         is_last_rank = True
+         self.topology_order_id = self.rank
+         if self.parallel_config.pipeline_parallel_size > 1:
+             is_first_rank = self.rank == 0
+             is_last_rank = self.rank == self.pp_config.pp_world_size - 1
+         else:
+             # topology_order_id is used to determine the KV cache
+             # mapping between P/D workers
+             if multihost_backend == "ray":
+                 self.topology_order_id = get_device_topology_order_id(
+                     jax.local_devices(), jax.devices())
+
+         self.model_runner = TPUModelRunner(self.vllm_config, self.devices,
+                                            self.rank, is_first_rank,
+                                            is_last_rank)
          logger.info(f"Init worker | "
                      f"rank={self.rank} | "
-                     f"node_id={get_node_id()} | "
+                     f"is_first_rank={is_first_rank} | "
+                     f"is_last_rank={is_last_rank} | "
+                     f"topology_order_id={self.topology_order_id} | "
                      f"is_driver_worker={self.is_driver_worker} | "
-                     f"hbm={utils.hbm_usage_gb(self.devices)}GiB")
+                     f"hbm={utils.hbm_usage_gb(self.devices)}GiB |"
+                     f"self.devices={self.devices} | "
+                     f"total devices={jax.devices()} | "
+                     f"local_devices={jax.local_devices()}")
          vllm_utils.report_usage_stats(self.vllm_config)

      def initialize_pp_transfer_connect(self):
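The diff does not show `get_device_topology_order_id` itself; the following is only a hypothetical illustration of the idea the comment above describes (giving each host a stable id based on where its local devices sit in the global device order, so P/D workers can be matched), not the implementation in `tpu_inference.distributed.utils`.

```python
# Hypothetical sketch only; NOT the package's actual implementation.
from typing import Sequence

import jax


def topology_order_id_sketch(local_devices: Sequence[jax.Device],
                             global_devices: Sequence[jax.Device]) -> int:
    # Position of this host's first local device in the global device list.
    global_ids = [d.id for d in global_devices]
    first_local = min(global_ids.index(d.id) for d in local_devices)
    # One id per host: divide by the number of devices each host contributes.
    return first_local // len(local_devices)
```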
@@ -357,7 +355,7 @@ class TPUWorker:
          if is_start:
              options = jax.profiler.ProfileOptions()
              # default: https://docs.jax.dev/en/latest/profiling.html#general-options
-             options.python_tracer_level = os.getenv("PYTHON_TRACER_LEVEL", 0)
+             options.python_tracer_level = envs.PYTHON_TRACER_LEVEL
              options.host_tracer_level = os.getenv("HOST_TRACER_LEVEL", 1)
              jax.profiler.start_trace(self.profile_dir,
                                       profiler_options=options)
@@ -395,45 +393,56 @@ class TPUWorker:
          # responsible for this translation. When vLLM can be modified, this
          # method should be changed to return `dict[str, AbstractKVCacheSpec]`,
          # and the vLLM side should be updated to handle the translation.
-         kv_cache_specs = self.model_runner.get_kv_cache_spec()
+         kv_cache_spec = self.model_runner.get_kv_cache_spec()

-         if len(kv_cache_specs) == 0:
-             return kv_cache_specs
+         if len(kv_cache_spec) == 0:
+             return kv_cache_spec

          # TODO(kyuyeunk): Instead of checking page_size_bytes here, introduce
          # feature that allows overriding page_size_bytes of KVCacheSpec.
-         vllm_page_size_bytes = get_uniform_page_size(kv_cache_specs)
-         rpa_page_size_bytes = get_rpa_page_size_bytes(self.model_runner.mesh,
-                                                       kv_cache_specs)
+         vllm_page_size_bytes = get_uniform_page_size(
+             list(kv_cache_spec.values()))
+         attention_page_size_bytes = get_attention_page_size_bytes(
+             self.model_runner.mesh, kv_cache_spec)

-         if vllm_page_size_bytes != rpa_page_size_bytes:
+         if vllm_page_size_bytes != attention_page_size_bytes:
              logger.info(
-                 f"KV cache page size calculated by vLLM "
-                 f"({vllm_page_size_bytes} Bytes) does not match with actual "
-                 f"page size used by RPA kernel ({rpa_page_size_bytes} Bytes). "
-                 f"Recalculating number of KV blocks using actual page size.")
-
+                 f"Page size calculated by vLLM ({vllm_page_size_bytes} Bytes) "
+                 f"does not match with actual page size used by the kernel "
+                 f"({attention_page_size_bytes} Bytes). Recalculating number of "
+                 f"KV blocks using actual page size.")
+
+             kv_cache_groups = get_kv_cache_groups(self.vllm_config,
+                                                   kv_cache_spec)
+             group_size = max(
+                 len(group.layer_names) for group in kv_cache_groups)
              available_memory = self.determine_available_memory()
-             num_blocks = get_num_blocks(self.vllm_config, len(kv_cache_specs),
-                                         available_memory, rpa_page_size_bytes)
-
+             num_blocks = get_num_blocks(self.vllm_config, group_size,
+                                         available_memory,
+                                         attention_page_size_bytes)
              cache_config = self.vllm_config.cache_config
              cache_config.num_gpu_blocks_override = num_blocks

-         return kv_cache_specs
+         return kv_cache_spec

      def initialize_from_config(
          self,
          kv_cache_config: KVCacheConfig,
      ) -> None:
          """Allocate GPU KV cache with the specified kv_cache_config."""
-         self.model_runner.initialize_kv_cache(kv_cache_config)
+         # Precompile functions with large vocab_size tensors before allocating KV cache to avoid OOM
+         if not (envs.SKIP_JAX_PRECOMPILE or
+                 (hasattr(self.model_runner.model_config, "enforce_eager")
+                  and self.model_runner.model_config.enforce_eager)):
+             self.model_runner.compilation_manager._precompile_sampling()
+             self.model_runner.compilation_manager._precompile_gather_logprobs()
+         self.model_runner.initialize_kv_cache(kv_cache_config,
+                                               self.topology_order_id)

      def get_node_kv_ip_port(self) -> tuple[int, str, int]:
-         node_id = get_node_id()
          ip = get_host_ip()
          port = get_kv_transfer_port()
-         return (int(node_id), ip, int(port))
+         return (int(self.topology_order_id), ip, int(port))

      def check_health(self) -> None:
          # worker will always be healthy as long as it's running.
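The arithmetic behind the recalculation above, as a simplified sketch: when the kernel's page size differs from vLLM's estimate, the block count is recomputed from the available memory and the largest KV cache group. This mirrors the intent of `get_num_blocks`, not its exact signature or bookkeeping in vLLM.

```python
# Simplified sketch only; not vLLM's get_num_blocks.
def recalc_num_blocks_sketch(available_memory_bytes: int,
                             page_size_bytes: int,
                             group_size: int) -> int:
    # Each block holds one page per layer in the largest KV cache group.
    block_size_bytes = group_size * page_size_bytes
    return available_memory_bytes // block_size_bytes


# e.g. 16 GiB free, 2 MiB kernel page size, 32 attention layers in the group:
print(recalc_num_blocks_sketch(16 * 1024**3, 2 * 1024**2, 32))  # 256
```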
@@ -455,3 +464,8 @@ class TPUWorker:

      def shutdown(self) -> None:
          return
+
+     # Ray executor do not need handshake metadata
+     # as we pass the kv_parameters through proxy server
+     def get_kv_connector_handshake_metadata(self) -> None:
+         pass
{tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: tpu_inference
- Version: 0.11.1.dev202511220812
+ Version: 0.13.2.dev20251230
  Author: tpu_inference Contributors
  Classifier: Development Status :: 3 - Alpha
  Classifier: Intended Audience :: Developers
@@ -14,7 +14,7 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
  Requires-Python: >=3.10
  Description-Content-Type: text/markdown
  License-File: LICENSE
- Requires-Dist: tpu-info==0.4.0
+ Requires-Dist: tpu-info==0.7.1
  Requires-Dist: yapf==0.43.0
  Requires-Dist: pytest
  Requires-Dist: pytest-mock
@@ -25,12 +25,13 @@ Requires-Dist: jax[tpu]==0.8.0
  Requires-Dist: jaxlib==0.8.0
  Requires-Dist: jaxtyping
  Requires-Dist: flax==0.11.1
- Requires-Dist: torchax==0.0.7
+ Requires-Dist: torchax==0.0.10
  Requires-Dist: qwix==0.1.1
  Requires-Dist: torchvision==0.24.0
  Requires-Dist: pathwaysutils
  Requires-Dist: parameterized
  Requires-Dist: numba==0.62.1
+ Requires-Dist: runai-model-streamer[gcs,s3]==0.15.0
  Dynamic: author
  Dynamic: classifier
  Dynamic: description
@@ -52,14 +53,12 @@ Dynamic: requires-python

  ---

- _Upcoming Events_ 🔥
-
- - Join us at the [PyTorch Conference, October 22-23](https://events.linuxfoundation.org/pytorch-conference/) in San Francisco!
- - Join us at [Ray Summit, November 3-5](https://www.anyscale.com/ray-summit/2025) in San Francisco!
- - Join us at [JAX DevLab on November 18th](https://rsvp.withgoogle.com/events/devlab-fall-2025) in Sunnyvale!
-
  _Latest News_ 🔥

+ - [Pytorch Conference](https://pytorchconference.sched.com/event/27QCh/sponsored-session-everything-everywhere-all-at-once-vllm-hardware-optionality-with-spotify-and-google-brittany-rockwell-google-shireen-kheradpey-spotify) Learn how Spotify uses vLLM with both GPUs and TPUs to drive down costs and improve user experience.
+ - Check back soon for a recording of our session at [Ray Summit, November 3-5](https://www.anyscale.com/ray-summit/2025) in San Francisco!
+ - Check back soon for a recording of our session at [JAX DevLab on November 18th](https://rsvp.withgoogle.com/events/devlab-fall-2025) in Sunnyvale!
+
  - [2025/10] [vLLM TPU: A New Unified Backend Supporting PyTorch and JAX on TPU](https://blog.vllm.ai/2025/10/16/vllm-tpu.html)

  <details>