tpu-inference 0.11.1.dev202511270815__py3-none-any.whl → 0.13.0rc2.post7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic. Click here for more details.

Files changed (251) hide show
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +405 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +312 -0
  64. tests/layers/vllm/test_unquantized.py +651 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +21 -3
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +110 -12
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +2 -45
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +15 -10
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +92 -8
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +22 -1
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +167 -97
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +26 -19
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/quant_methods.py +15 -0
  154. tpu_inference/layers/common/quantization.py +270 -0
  155. tpu_inference/layers/common/sharding.py +31 -9
  156. tpu_inference/layers/jax/__init__.py +13 -0
  157. tpu_inference/layers/jax/attention/__init__.py +13 -0
  158. tpu_inference/layers/jax/attention/attention.py +19 -6
  159. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  160. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  161. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  162. tpu_inference/layers/jax/base.py +14 -0
  163. tpu_inference/layers/jax/constants.py +13 -0
  164. tpu_inference/layers/jax/layers.py +14 -0
  165. tpu_inference/layers/jax/misc.py +14 -0
  166. tpu_inference/layers/jax/moe/__init__.py +13 -0
  167. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  168. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  169. tpu_inference/layers/jax/moe/moe.py +43 -3
  170. tpu_inference/layers/jax/pp_utils.py +53 -0
  171. tpu_inference/layers/jax/rope.py +14 -0
  172. tpu_inference/layers/jax/rope_interface.py +14 -0
  173. tpu_inference/layers/jax/sample/__init__.py +13 -0
  174. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  175. tpu_inference/layers/jax/sample/sampling.py +15 -1
  176. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  177. tpu_inference/layers/jax/transformer_block.py +14 -0
  178. tpu_inference/layers/vllm/__init__.py +13 -0
  179. tpu_inference/layers/vllm/attention.py +4 -4
  180. tpu_inference/layers/vllm/fused_moe.py +210 -260
  181. tpu_inference/layers/vllm/linear_common.py +57 -22
  182. tpu_inference/layers/vllm/quantization/__init__.py +16 -0
  183. tpu_inference/layers/vllm/quantization/awq.py +15 -1
  184. tpu_inference/layers/vllm/quantization/common.py +33 -18
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
  188. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  189. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
  191. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  192. tpu_inference/layers/vllm/quantization/mxfp4.py +280 -210
  193. tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
  194. tpu_inference/layers/vllm/sharding.py +21 -4
  195. tpu_inference/lora/__init__.py +13 -0
  196. tpu_inference/lora/torch_lora_ops.py +8 -13
  197. tpu_inference/models/__init__.py +13 -0
  198. tpu_inference/models/common/__init__.py +13 -0
  199. tpu_inference/models/common/model_loader.py +77 -36
  200. tpu_inference/models/jax/__init__.py +13 -0
  201. tpu_inference/models/jax/deepseek_v3.py +267 -157
  202. tpu_inference/models/jax/gpt_oss.py +26 -10
  203. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  204. tpu_inference/models/jax/llama3.py +99 -36
  205. tpu_inference/models/jax/llama4.py +14 -0
  206. tpu_inference/models/jax/llama_eagle3.py +14 -0
  207. tpu_inference/models/jax/llama_guard_4.py +15 -1
  208. tpu_inference/models/jax/qwen2.py +17 -2
  209. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  210. tpu_inference/models/jax/qwen3.py +17 -2
  211. tpu_inference/models/jax/utils/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/file_utils.py +14 -0
  213. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  214. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  215. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +91 -31
  216. tpu_inference/models/jax/utils/weight_utils.py +39 -2
  217. tpu_inference/models/vllm/__init__.py +13 -0
  218. tpu_inference/models/vllm/vllm_model_wrapper.py +20 -4
  219. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  220. tpu_inference/platforms/__init__.py +14 -0
  221. tpu_inference/platforms/tpu_platform.py +47 -71
  222. tpu_inference/runner/__init__.py +13 -0
  223. tpu_inference/runner/compilation_manager.py +158 -63
  224. tpu_inference/runner/kv_cache.py +54 -20
  225. tpu_inference/runner/kv_cache_manager.py +53 -30
  226. tpu_inference/runner/lora_utils.py +14 -0
  227. tpu_inference/runner/multimodal_manager.py +15 -1
  228. tpu_inference/runner/persistent_batch_manager.py +54 -2
  229. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  230. tpu_inference/runner/structured_decoding_manager.py +14 -0
  231. tpu_inference/runner/tpu_runner.py +105 -57
  232. tpu_inference/runner/utils.py +2 -2
  233. tpu_inference/spec_decode/__init__.py +13 -0
  234. tpu_inference/spec_decode/jax/__init__.py +13 -0
  235. tpu_inference/spec_decode/jax/eagle3.py +65 -19
  236. tpu_inference/tpu_info.py +14 -0
  237. tpu_inference/utils.py +72 -44
  238. tpu_inference/worker/__init__.py +13 -0
  239. tpu_inference/worker/tpu_worker.py +65 -52
  240. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/METADATA +11 -9
  241. tpu_inference-0.13.0rc2.post7.dist-info/RECORD +261 -0
  242. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  243. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  244. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  245. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  246. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  247. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  248. tpu_inference-0.11.1.dev202511270815.dist-info/RECORD +0 -174
  249. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/WHEEL +0 -0
  250. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/licenses/LICENSE +0 -0
  251. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,17 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
1
15
  import os
2
16
 
3
17
  from vllm.utils.network_utils import get_ip
@@ -54,7 +68,45 @@ def get_side_channel_port() -> str:
54
68
  return port
55
69
 
56
70
 
57
- def get_node_id() -> int:
58
- # TODO(xiang): Is it possible to get this from a pre-defiend env?
59
- id = os.getenv("TPU_NODE_ID", 0)
60
- return int(id)
71
+ def get_device_topology_order_id(local_devices, global_devices) -> int:
72
+ """
73
+ Calculates the topology order ID for the local device set within the global topology.
74
+
75
+ This function determines the rank of the current host/process based on the
76
+ coordinate of its TPU devices relative to all devices in the topology.
77
+
78
+ Args:
79
+ local_devices: A list of TpuDevice objects available to the current process.
80
+ global_devices: A list of all TpuDevice objects in the global topology.
81
+
82
+ Returns:
83
+ The topology order ID (rank) of the local devices.
84
+ """
85
+ if not local_devices:
86
+ raise ValueError("local_devices cannot be empty")
87
+ if not global_devices:
88
+ raise ValueError("global_devices cannot be empty")
89
+
90
+ # 1. Find the 'anchor' (minimum coordinate) for the local devices.
91
+ # This represents the physical top-left corner of the local machine.
92
+ local_anchor = min(d.coords for d in local_devices)
93
+
94
+ # 2. Group global devices by process to find the anchor for EVERY process.
95
+ process_anchors = {}
96
+ for d in global_devices:
97
+ pid = d.process_index
98
+ # Update the minimum coordinate found for this process so far
99
+ if pid not in process_anchors or d.coords < process_anchors[pid]:
100
+ process_anchors[pid] = d.coords
101
+
102
+ # 3. Sort the unique anchors to establish the canonical topology order.
103
+ # Tuples (x, y, z) sort lexicographically (x first, then y, then z).
104
+ sorted_anchors = sorted(process_anchors.values())
105
+
106
+ # 4. Return the index (rank) of the local anchor in the sorted list.
107
+ try:
108
+ return sorted_anchors.index(local_anchor)
109
+ except ValueError:
110
+ raise ValueError(
111
+ f"Local devices: {local_devices} do not exist in the global device: {global_devices} list."
112
+ )
tpu_inference/envs.py CHANGED
@@ -15,13 +15,88 @@ if TYPE_CHECKING:
15
15
  PREFILL_SLICES: str = ""
16
16
  DECODE_SLICES: str = ""
17
17
  SKIP_JAX_PRECOMPILE: bool = False
18
- MODEL_IMPL_TYPE: str = "flax_nnx"
18
+ VLLM_XLA_CHECK_RECOMPILATION: bool = False
19
+ MODEL_IMPL_TYPE: str = "auto"
19
20
  NEW_MODEL_DESIGN: bool = False
20
21
  PHASED_PROFILING_DIR: str = ""
21
22
  PYTHON_TRACER_LEVEL: int = 1
22
23
  USE_MOE_EP_KERNEL: bool = False
24
+ NUM_SLICES: int = 1
23
25
  RAY_USAGE_STATS_ENABLED: str = "0"
24
26
  VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: str = "shm"
27
+ ENABLE_QUANTIZED_MATMUL_KERNEL: bool = False
28
+
29
+
30
+ def env_with_choices(
31
+ env_name: str,
32
+ default: str | None,
33
+ choices: list[str] | Callable[[], list[str]],
34
+ case_sensitive: bool = True,
35
+ ) -> Callable[[], str | None]:
36
+ """
37
+ Create a lambda that validates environment variable against allowed choices
38
+
39
+ Args:
40
+ env_name: Name of the environment variable
41
+ default: Default value if not set (can be None)
42
+ choices: List of valid string options or callable that returns list
43
+ case_sensitive: Whether validation should be case sensitive
44
+
45
+ Returns:
46
+ Lambda function for environment_variables dict
47
+ """
48
+
49
+ def _get_validated_env() -> str | None:
50
+ value = os.getenv(env_name)
51
+ if value is None:
52
+ return default
53
+
54
+ # Resolve choices if it's a callable (for lazy loading)
55
+ actual_choices = choices() if callable(choices) else choices
56
+
57
+ if not case_sensitive:
58
+ check_value = value.lower()
59
+ check_choices = [choice.lower() for choice in actual_choices]
60
+ else:
61
+ check_value = value
62
+ check_choices = actual_choices
63
+
64
+ if check_value not in check_choices:
65
+ raise ValueError(f"Invalid value '{value}' for {env_name}. "
66
+ f"Valid options: {actual_choices}.")
67
+
68
+ return value
69
+
70
+ return _get_validated_env
71
+
72
+
73
+ def env_bool(env_name: str, default: bool = False) -> Callable[[], bool]:
74
+ """
75
+ Accepts both numeric strings ("0", "1") and boolean strings
76
+ ("true", "false", "True", "False").
77
+
78
+ Args:
79
+ env_name: Name of the environment variable
80
+ default: Default boolean value if not set
81
+ """
82
+
83
+ def _get_bool_env() -> bool:
84
+ value = os.getenv(env_name)
85
+ if value is None or value == "":
86
+ return default
87
+
88
+ value_lower = value.lower()
89
+ if value_lower in ("true", "1"):
90
+ return True
91
+ elif value_lower in ("false", "0"):
92
+ return False
93
+ else:
94
+ raise ValueError(
95
+ f"Invalid boolean value '{value}' for {env_name}. "
96
+ f"Valid options: '0', '1', 'true', 'false', 'True', 'False'.")
97
+
98
+ return _get_bool_env
99
+
25
100
 
26
101
  environment_variables: dict[str, Callable[[], Any]] = {
27
102
  # JAX platform selection (e.g., "tpu", "cpu", "proxy")
@@ -38,7 +113,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
38
113
  lambda: os.getenv("TPU_WORKER_ID", None),
39
114
  # Backend for multi-host communication on TPU
40
115
  "TPU_MULTIHOST_BACKEND":
41
- lambda: os.getenv("TPU_MULTIHOST_BACKEND", "").lower(),
116
+ env_with_choices("TPU_MULTIHOST_BACKEND", "", ["ray"]),
42
117
  # Slice configuration for disaggregated prefill workers
43
118
  "PREFILL_SLICES":
44
119
  lambda: os.getenv("PREFILL_SLICES", ""),
@@ -47,28 +122,37 @@ environment_variables: dict[str, Callable[[], Any]] = {
47
122
  lambda: os.getenv("DECODE_SLICES", ""),
48
123
  # Skip JAX precompilation step during initialization
49
124
  "SKIP_JAX_PRECOMPILE":
50
- lambda: bool(int(os.getenv("SKIP_JAX_PRECOMPILE", "0"))),
125
+ env_bool("SKIP_JAX_PRECOMPILE", default=False),
126
+ # Check for XLA recompilation during execution
127
+ "VLLM_XLA_CHECK_RECOMPILATION":
128
+ env_bool("VLLM_XLA_CHECK_RECOMPILATION", default=False),
51
129
  # Model implementation type (e.g., "flax_nnx")
52
130
  "MODEL_IMPL_TYPE":
53
- lambda: os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower(),
131
+ env_with_choices("MODEL_IMPL_TYPE", "auto",
132
+ ["auto", "vllm", "flax_nnx", "jetpack"]),
54
133
  # Enable new experimental model design
55
134
  "NEW_MODEL_DESIGN":
56
- lambda: bool(int(os.getenv("NEW_MODEL_DESIGN", "0"))),
135
+ env_bool("NEW_MODEL_DESIGN", default=False),
57
136
  # Directory to store phased profiling output
58
137
  "PHASED_PROFILING_DIR":
59
138
  lambda: os.getenv("PHASED_PROFILING_DIR", ""),
60
139
  # Python tracer level for profiling
61
140
  "PYTHON_TRACER_LEVEL":
62
- lambda: int(os.getenv("PYTHON_TRACER_LEVEL", "1")),
141
+ lambda: int(os.getenv("PYTHON_TRACER_LEVEL") or "1"),
63
142
  # Use custom expert-parallel kernel for MoE (Mixture of Experts)
64
143
  "USE_MOE_EP_KERNEL":
65
- lambda: bool(int(os.getenv("USE_MOE_EP_KERNEL", "0"))),
144
+ env_bool("USE_MOE_EP_KERNEL", default=False),
145
+ # Number of TPU slices for multi-slice mesh
146
+ "NUM_SLICES":
147
+ lambda: int(os.getenv("NUM_SLICES") or "1"),
66
148
  # Enable/disable Ray usage statistics collection
67
149
  "RAY_USAGE_STATS_ENABLED":
68
150
  lambda: os.getenv("RAY_USAGE_STATS_ENABLED", "0"),
69
151
  # Ray compiled DAG channel type for TPU
70
152
  "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE":
71
- lambda: os.getenv("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "shm"),
153
+ env_with_choices("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "shm", ["shm"]),
154
+ "ENABLE_QUANTIZED_MATMUL_KERNEL":
155
+ lambda: bool(int(os.getenv("ENABLE_QUANTIZED_MATMUL_KERNEL") or "0")),
72
156
  }
73
157
 
74
158
 
@@ -0,0 +1,13 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -1,3 +1,17 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
1
15
  import os
2
16
  from array import array
3
17
  from typing import Any, Dict, List, Optional
@@ -136,11 +150,18 @@ class RayDistributedExecutor(RayDistributedExecutorV1):
136
150
 
137
151
  pp_size = self.parallel_config.pipeline_parallel_size
138
152
  placement_group_specs: List[Dict[str, float]] = []
153
+
154
+ ray_nodes = ray.nodes()
155
+ logger.info(f"RayDistributedExecutor | ray_nodes={ray_nodes}")
156
+
139
157
  if pp_size == 1:
140
158
  placement_group_specs = [{
141
159
  device_str: node['Resources'][device_str]
142
- } for node in ray.nodes()]
160
+ } for node in ray_nodes]
143
161
  else:
162
+ assert pp_size == len(
163
+ ray_nodes
164
+ ), f"Cannot use PP across hosts, please set --pipeline-parallel-size to 1 or {len(ray_nodes)}"
144
165
  num_devices_per_pp_rank = self.vllm_config.sharding_config.total_devices
145
166
  placement_group_specs = [{
146
167
  device_str: num_devices_per_pp_rank
@@ -0,0 +1,13 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -1,3 +1,17 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
1
15
  # TODO: Update documentation
2
16
 
3
17
  from typing import List, Optional, Tuple
@@ -0,0 +1,13 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -0,0 +1,13 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -540,12 +540,16 @@ def get_vmem_estimate_bytes(
540
540
  """Returns the total vmem bytes used by the kernel."""
541
541
  m_per_device = m // tp_size
542
542
  n_per_device = n // tp_size
543
- y_vmem_bytes = n_per_device * k * dtypes.bit_width(y_dtype) // 8
543
+ y_vmem_bytes = (n_per_device * k * (dtypes.bit_width(y_dtype) if hasattr(
544
+ dtypes, "bit_width") else dtypes.itemsize_bits(y_dtype)) // 8)
544
545
  total_bytes = (
545
- 2 * m_per_device * k * dtypes.bit_width(x_dtype) //
546
- 8 # x_vmem_scratch_ref
546
+ 2 * m_per_device * k *
547
+ (dtypes.bit_width(x_dtype) if hasattr(dtypes, "bit_width") else
548
+ dtypes.itemsize_bits(x_dtype)) // 8 # x_vmem_scratch_ref
547
549
  + y_vmem_bytes # y_vmem_scratch_ref
548
- + 2 * m * bn * dtypes.bit_width(out_dtype) // 8 # o_vmem_scratch_ref
550
+ + 2 * m * bn *
551
+ (dtypes.bit_width(out_dtype) if hasattr(dtypes, "bit_width") else
552
+ dtypes.itemsize_bits(out_dtype)) // 8 # o_vmem_scratch_ref
549
553
  + acc_bytes # acc_vmem_scratch_ref, jnp.float32
550
554
  )
551
555
  return total_bytes
@@ -639,8 +643,10 @@ def all_gather_matmul(
639
643
  # NOTE(chengjiyao): acc buffer is not used in the grid_k == 1 case.
640
644
  if grid_k == 1:
641
645
  acc_shape = (8, 128)
642
- acc_bytes = acc_shape[0] * acc_shape[1] * dtypes.bit_width(
643
- jnp.float32) // 8
646
+ acc_bytes = (
647
+ acc_shape[0] *
648
+ acc_shape[1] * (dtypes.bit_width(jnp.float32) if hasattr(
649
+ dtypes, "bit_width") else dtypes.itemsize_bits(jnp.float32)) // 8)
644
650
  y_vmem_shape = (n_per_device, k) if rhs_transpose else (k, n_per_device)
645
651
  estimated_vmem_bytes = get_vmem_estimate_bytes(
646
652
  m,
@@ -1,6 +1,8 @@
1
1
  # SPDX-License-Identifier: Apache-2.0
2
2
  """All-gather matmul kernel's tuned block sizes."""
3
3
 
4
+ import re
5
+
4
6
  import jax
5
7
 
6
8
  # key:
@@ -32,8 +34,11 @@ def get_tpu_version() -> int:
32
34
  return -1
33
35
  if kind.endswith(' lite'):
34
36
  kind = kind[:-len(' lite')]
35
- assert kind[:-1] == 'TPU v', kind
36
- return int(kind[-1])
37
+
38
+ # v6: "TPU v6"
39
+ # v7: "TPU7x"
40
+ assert kind[:3] == 'TPU', kind
41
+ return int(re.search(r'\d+', kind).group())
37
42
 
38
43
 
39
44
  def get_key(
@@ -0,0 +1,13 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -0,0 +1,13 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -0,0 +1,13 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.