tpu-inference 0.11.1.dev202512030818__py3-none-any.whl → 0.13.0rc2.post7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic. Click here for more details.

Files changed (250)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +405 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +312 -0
  64. tests/layers/vllm/test_unquantized.py +651 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +21 -3
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +78 -1
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +1 -43
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +14 -9
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +38 -7
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +17 -0
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +95 -78
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +26 -19
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/quant_methods.py +15 -0
  154. tpu_inference/layers/common/quantization.py +270 -0
  155. tpu_inference/layers/common/sharding.py +28 -5
  156. tpu_inference/layers/jax/__init__.py +13 -0
  157. tpu_inference/layers/jax/attention/__init__.py +13 -0
  158. tpu_inference/layers/jax/attention/attention.py +19 -6
  159. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  160. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  161. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  162. tpu_inference/layers/jax/base.py +14 -0
  163. tpu_inference/layers/jax/constants.py +13 -0
  164. tpu_inference/layers/jax/layers.py +14 -0
  165. tpu_inference/layers/jax/misc.py +14 -0
  166. tpu_inference/layers/jax/moe/__init__.py +13 -0
  167. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  168. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  169. tpu_inference/layers/jax/moe/moe.py +43 -3
  170. tpu_inference/layers/jax/pp_utils.py +53 -0
  171. tpu_inference/layers/jax/rope.py +14 -0
  172. tpu_inference/layers/jax/rope_interface.py +14 -0
  173. tpu_inference/layers/jax/sample/__init__.py +13 -0
  174. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  175. tpu_inference/layers/jax/sample/sampling.py +15 -1
  176. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  177. tpu_inference/layers/jax/transformer_block.py +14 -0
  178. tpu_inference/layers/vllm/__init__.py +13 -0
  179. tpu_inference/layers/vllm/attention.py +4 -4
  180. tpu_inference/layers/vllm/fused_moe.py +210 -260
  181. tpu_inference/layers/vllm/linear_common.py +57 -22
  182. tpu_inference/layers/vllm/quantization/__init__.py +16 -0
  183. tpu_inference/layers/vllm/quantization/awq.py +15 -1
  184. tpu_inference/layers/vllm/quantization/common.py +33 -18
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
  188. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  189. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
  191. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  192. tpu_inference/layers/vllm/quantization/mxfp4.py +278 -209
  193. tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
  194. tpu_inference/layers/vllm/sharding.py +21 -4
  195. tpu_inference/lora/__init__.py +13 -0
  196. tpu_inference/lora/torch_lora_ops.py +8 -13
  197. tpu_inference/models/__init__.py +13 -0
  198. tpu_inference/models/common/__init__.py +13 -0
  199. tpu_inference/models/common/model_loader.py +74 -35
  200. tpu_inference/models/jax/__init__.py +13 -0
  201. tpu_inference/models/jax/deepseek_v3.py +267 -157
  202. tpu_inference/models/jax/gpt_oss.py +26 -10
  203. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  204. tpu_inference/models/jax/llama3.py +99 -36
  205. tpu_inference/models/jax/llama4.py +14 -0
  206. tpu_inference/models/jax/llama_eagle3.py +14 -0
  207. tpu_inference/models/jax/llama_guard_4.py +15 -1
  208. tpu_inference/models/jax/qwen2.py +17 -2
  209. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  210. tpu_inference/models/jax/qwen3.py +17 -2
  211. tpu_inference/models/jax/utils/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/file_utils.py +14 -0
  213. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  214. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  215. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +88 -25
  216. tpu_inference/models/jax/utils/weight_utils.py +39 -2
  217. tpu_inference/models/vllm/__init__.py +13 -0
  218. tpu_inference/models/vllm/vllm_model_wrapper.py +20 -3
  219. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  220. tpu_inference/platforms/__init__.py +14 -0
  221. tpu_inference/platforms/tpu_platform.py +47 -64
  222. tpu_inference/runner/__init__.py +13 -0
  223. tpu_inference/runner/compilation_manager.py +72 -37
  224. tpu_inference/runner/kv_cache.py +54 -20
  225. tpu_inference/runner/kv_cache_manager.py +45 -15
  226. tpu_inference/runner/lora_utils.py +14 -0
  227. tpu_inference/runner/multimodal_manager.py +15 -1
  228. tpu_inference/runner/persistent_batch_manager.py +14 -0
  229. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  230. tpu_inference/runner/structured_decoding_manager.py +14 -0
  231. tpu_inference/runner/tpu_runner.py +41 -16
  232. tpu_inference/spec_decode/__init__.py +13 -0
  233. tpu_inference/spec_decode/jax/__init__.py +13 -0
  234. tpu_inference/spec_decode/jax/eagle3.py +13 -0
  235. tpu_inference/tpu_info.py +14 -0
  236. tpu_inference/utils.py +42 -36
  237. tpu_inference/worker/__init__.py +13 -0
  238. tpu_inference/worker/tpu_worker.py +63 -50
  239. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/METADATA +11 -9
  240. tpu_inference-0.13.0rc2.post7.dist-info/RECORD +261 -0
  241. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  242. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  243. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  244. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  245. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  246. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  247. tpu_inference-0.11.1.dev202512030818.dist-info/RECORD +0 -174
  248. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/WHEEL +0 -0
  249. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/licenses/LICENSE +0 -0
  250. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,17 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
1
15
  import os
2
16
 
3
17
  from vllm.utils.network_utils import get_ip
@@ -54,7 +68,45 @@ def get_side_channel_port() -> str:
54
68
  return port
55
69
 
56
70
 
57
- def get_node_id() -> int:
58
- # TODO(xiang): Is it possible to get this from a pre-defiend env?
59
- id = os.getenv("TPU_NODE_ID", 0)
60
- return int(id)
71
+ def get_device_topology_order_id(local_devices, global_devices) -> int:
72
+ """
73
+ Calculates the topology order ID for the local device set within the global topology.
74
+
75
+ This function determines the rank of the current host/process based on the
76
+ coordinate of its TPU devices relative to all devices in the topology.
77
+
78
+ Args:
79
+ local_devices: A list of TpuDevice objects available to the current process.
80
+ global_devices: A list of all TpuDevice objects in the global topology.
81
+
82
+ Returns:
83
+ The topology order ID (rank) of the local devices.
84
+ """
85
+ if not local_devices:
86
+ raise ValueError("local_devices cannot be empty")
87
+ if not global_devices:
88
+ raise ValueError("global_devices cannot be empty")
89
+
90
+ # 1. Find the 'anchor' (minimum coordinate) for the local devices.
91
+ # This represents the physical top-left corner of the local machine.
92
+ local_anchor = min(d.coords for d in local_devices)
93
+
94
+ # 2. Group global devices by process to find the anchor for EVERY process.
95
+ process_anchors = {}
96
+ for d in global_devices:
97
+ pid = d.process_index
98
+ # Update the minimum coordinate found for this process so far
99
+ if pid not in process_anchors or d.coords < process_anchors[pid]:
100
+ process_anchors[pid] = d.coords
101
+
102
+ # 3. Sort the unique anchors to establish the canonical topology order.
103
+ # Tuples (x, y, z) sort lexicographically (x first, then y, then z).
104
+ sorted_anchors = sorted(process_anchors.values())
105
+
106
+ # 4. Return the index (rank) of the local anchor in the sorted list.
107
+ try:
108
+ return sorted_anchors.index(local_anchor)
109
+ except ValueError:
110
+ raise ValueError(
111
+ f"Local devices: {local_devices} do not exist in the global device: {global_devices} list."
112
+ )
tpu_inference/envs.py CHANGED
@@ -16,7 +16,7 @@ if TYPE_CHECKING:
16
16
  DECODE_SLICES: str = ""
17
17
  SKIP_JAX_PRECOMPILE: bool = False
18
18
  VLLM_XLA_CHECK_RECOMPILATION: bool = False
19
- MODEL_IMPL_TYPE: str = "flax_nnx"
19
+ MODEL_IMPL_TYPE: str = "auto"
20
20
  NEW_MODEL_DESIGN: bool = False
21
21
  PHASED_PROFILING_DIR: str = ""
22
22
  PYTHON_TRACER_LEVEL: int = 1
@@ -24,6 +24,7 @@ if TYPE_CHECKING:
24
24
  NUM_SLICES: int = 1
25
25
  RAY_USAGE_STATS_ENABLED: str = "0"
26
26
  VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: str = "shm"
27
+ ENABLE_QUANTIZED_MATMUL_KERNEL: bool = False
27
28
 
28
29
 
29
30
  def env_with_choices(
@@ -69,6 +70,34 @@ def env_with_choices(
69
70
  return _get_validated_env
70
71
 
71
72
 
73
+ def env_bool(env_name: str, default: bool = False) -> Callable[[], bool]:
74
+ """
75
+ Accepts both numeric strings ("0", "1") and boolean strings
76
+ ("true", "false", "True", "False").
77
+
78
+ Args:
79
+ env_name: Name of the environment variable
80
+ default: Default boolean value if not set
81
+ """
82
+
83
+ def _get_bool_env() -> bool:
84
+ value = os.getenv(env_name)
85
+ if value is None or value == "":
86
+ return default
87
+
88
+ value_lower = value.lower()
89
+ if value_lower in ("true", "1"):
90
+ return True
91
+ elif value_lower in ("false", "0"):
92
+ return False
93
+ else:
94
+ raise ValueError(
95
+ f"Invalid boolean value '{value}' for {env_name}. "
96
+ f"Valid options: '0', '1', 'true', 'false', 'True', 'False'.")
97
+
98
+ return _get_bool_env
99
+
100
+
72
101
  environment_variables: dict[str, Callable[[], Any]] = {
73
102
  # JAX platform selection (e.g., "tpu", "cpu", "proxy")
74
103
  "JAX_PLATFORMS":
@@ -93,17 +122,17 @@ environment_variables: dict[str, Callable[[], Any]] = {
93
122
  lambda: os.getenv("DECODE_SLICES", ""),
94
123
  # Skip JAX precompilation step during initialization
95
124
  "SKIP_JAX_PRECOMPILE":
96
- lambda: bool(int(os.getenv("SKIP_JAX_PRECOMPILE") or "0")),
125
+ env_bool("SKIP_JAX_PRECOMPILE", default=False),
97
126
  # Check for XLA recompilation during execution
98
127
  "VLLM_XLA_CHECK_RECOMPILATION":
99
- lambda: bool(int(os.getenv("VLLM_XLA_CHECK_RECOMPILATION") or "0")),
128
+ env_bool("VLLM_XLA_CHECK_RECOMPILATION", default=False),
100
129
  # Model implementation type (e.g., "flax_nnx")
101
130
  "MODEL_IMPL_TYPE":
102
- env_with_choices("MODEL_IMPL_TYPE", "flax_nnx",
103
- ["vllm", "flax_nnx", "jetpack"]),
131
+ env_with_choices("MODEL_IMPL_TYPE", "auto",
132
+ ["auto", "vllm", "flax_nnx", "jetpack"]),
104
133
  # Enable new experimental model design
105
134
  "NEW_MODEL_DESIGN":
106
- lambda: bool(int(os.getenv("NEW_MODEL_DESIGN") or "0")),
135
+ env_bool("NEW_MODEL_DESIGN", default=False),
107
136
  # Directory to store phased profiling output
108
137
  "PHASED_PROFILING_DIR":
109
138
  lambda: os.getenv("PHASED_PROFILING_DIR", ""),
@@ -112,7 +141,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
112
141
  lambda: int(os.getenv("PYTHON_TRACER_LEVEL") or "1"),
113
142
  # Use custom expert-parallel kernel for MoE (Mixture of Experts)
114
143
  "USE_MOE_EP_KERNEL":
115
- lambda: bool(int(os.getenv("USE_MOE_EP_KERNEL") or "0")),
144
+ env_bool("USE_MOE_EP_KERNEL", default=False),
116
145
  # Number of TPU slices for multi-slice mesh
117
146
  "NUM_SLICES":
118
147
  lambda: int(os.getenv("NUM_SLICES") or "1"),
@@ -122,6 +151,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
122
151
  # Ray compiled DAG channel type for TPU
123
152
  "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE":
124
153
  env_with_choices("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "shm", ["shm"]),
154
+ "ENABLE_QUANTIZED_MATMUL_KERNEL":
155
+ lambda: bool(int(os.getenv("ENABLE_QUANTIZED_MATMUL_KERNEL") or "0")),
125
156
  }
126
157
 
127
158
 
@@ -0,0 +1,13 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -1,3 +1,17 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
1
15
  import os
2
16
  from array import array
3
17
  from typing import Any, Dict, List, Optional
@@ -145,6 +159,9 @@ class RayDistributedExecutor(RayDistributedExecutorV1):
145
159
  device_str: node['Resources'][device_str]
146
160
  } for node in ray_nodes]
147
161
  else:
162
+ assert pp_size == len(
163
+ ray_nodes
164
+ ), f"Cannot use PP across hosts, please set --pipeline-parallel-size to 1 or {len(ray_nodes)}"
148
165
  num_devices_per_pp_rank = self.vllm_config.sharding_config.total_devices
149
166
  placement_group_specs = [{
150
167
  device_str: num_devices_per_pp_rank
@@ -0,0 +1,13 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -1,3 +1,17 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
1
15
  # TODO: Update documentation
2
16
 
3
17
  from typing import List, Optional, Tuple
@@ -0,0 +1,13 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -0,0 +1,13 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -540,12 +540,16 @@ def get_vmem_estimate_bytes(
540
540
  """Returns the total vmem bytes used by the kernel."""
541
541
  m_per_device = m // tp_size
542
542
  n_per_device = n // tp_size
543
- y_vmem_bytes = n_per_device * k * dtypes.bit_width(y_dtype) // 8
543
+ y_vmem_bytes = (n_per_device * k * (dtypes.bit_width(y_dtype) if hasattr(
544
+ dtypes, "bit_width") else dtypes.itemsize_bits(y_dtype)) // 8)
544
545
  total_bytes = (
545
- 2 * m_per_device * k * dtypes.bit_width(x_dtype) //
546
- 8 # x_vmem_scratch_ref
546
+ 2 * m_per_device * k *
547
+ (dtypes.bit_width(x_dtype) if hasattr(dtypes, "bit_width") else
548
+ dtypes.itemsize_bits(x_dtype)) // 8 # x_vmem_scratch_ref
547
549
  + y_vmem_bytes # y_vmem_scratch_ref
548
- + 2 * m * bn * dtypes.bit_width(out_dtype) // 8 # o_vmem_scratch_ref
550
+ + 2 * m * bn *
551
+ (dtypes.bit_width(out_dtype) if hasattr(dtypes, "bit_width") else
552
+ dtypes.itemsize_bits(out_dtype)) // 8 # o_vmem_scratch_ref
549
553
  + acc_bytes # acc_vmem_scratch_ref, jnp.float32
550
554
  )
551
555
  return total_bytes
@@ -639,8 +643,10 @@ def all_gather_matmul(
639
643
  # NOTE(chengjiyao): acc buffer is not used in the grid_k == 1 case.
640
644
  if grid_k == 1:
641
645
  acc_shape = (8, 128)
642
- acc_bytes = acc_shape[0] * acc_shape[1] * dtypes.bit_width(
643
- jnp.float32) // 8
646
+ acc_bytes = (
647
+ acc_shape[0] *
648
+ acc_shape[1] * (dtypes.bit_width(jnp.float32) if hasattr(
649
+ dtypes, "bit_width") else dtypes.itemsize_bits(jnp.float32)) // 8)
644
650
  y_vmem_shape = (n_per_device, k) if rhs_transpose else (k, n_per_device)
645
651
  estimated_vmem_bytes = get_vmem_estimate_bytes(
646
652
  m,
@@ -1,6 +1,8 @@
1
1
  # SPDX-License-Identifier: Apache-2.0
2
2
  """All-gather matmul kernel's tuned block sizes."""
3
3
 
4
+ import re
5
+
4
6
  import jax
5
7
 
6
8
  # key:
@@ -32,8 +34,11 @@ def get_tpu_version() -> int:
32
34
  return -1
33
35
  if kind.endswith(' lite'):
34
36
  kind = kind[:-len(' lite')]
35
- assert kind[:-1] == 'TPU v', kind
36
- return int(kind[-1])
37
+
38
+ # v6: "TPU v6"
39
+ # v7: "TPU7x"
40
+ assert kind[:3] == 'TPU', kind
41
+ return int(re.search(r'\d+', kind).group())
37
42
 
38
43
 
39
44
  def get_key(
@@ -0,0 +1,13 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -0,0 +1,13 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -0,0 +1,13 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.