tpu-inference 0.12.0.dev20251213__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (248)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +14 -0
  31. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  32. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_test.py +14 -0
  35. tests/layers/__init__.py +13 -0
  36. tests/layers/common/__init__.py +13 -0
  37. tests/layers/common/test_attention_interface.py +156 -0
  38. tests/layers/common/test_quantization.py +149 -0
  39. tests/layers/jax/__init__.py +13 -0
  40. tests/layers/jax/attention/__init__.py +13 -0
  41. tests/layers/jax/attention/test_common_attention.py +103 -0
  42. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  43. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  44. tests/layers/jax/moe/__init__.py +13 -0
  45. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  46. tests/layers/jax/sample/__init__.py +13 -0
  47. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  48. tests/layers/jax/sample/test_sampling.py +115 -0
  49. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  50. tests/layers/jax/test_layers.py +155 -0
  51. tests/{test_quantization.py → layers/jax/test_qwix.py} +180 -50
  52. tests/layers/jax/test_rope.py +93 -0
  53. tests/layers/jax/test_sharding.py +159 -0
  54. tests/layers/jax/test_transformer_block.py +152 -0
  55. tests/layers/vllm/__init__.py +13 -0
  56. tests/layers/vllm/test_attention.py +363 -0
  57. tests/layers/vllm/test_awq.py +406 -0
  58. tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
  59. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
  61. tests/layers/vllm/test_fp8.py +17 -0
  62. tests/layers/vllm/test_mxfp4.py +320 -0
  63. tests/layers/vllm/test_unquantized.py +662 -0
  64. tests/layers/vllm/utils.py +87 -0
  65. tests/lora/__init__.py +13 -0
  66. tests/lora/conftest.py +14 -0
  67. tests/lora/test_bgmv.py +14 -0
  68. tests/lora/test_layers.py +25 -8
  69. tests/lora/test_lora.py +15 -1
  70. tests/lora/test_lora_perf.py +14 -0
  71. tests/models/__init__.py +13 -0
  72. tests/models/common/__init__.py +13 -0
  73. tests/models/common/test_model_loader.py +455 -0
  74. tests/models/jax/__init__.py +13 -0
  75. tests/models/jax/test_deepseek_v3.py +401 -0
  76. tests/models/jax/test_llama3.py +184 -0
  77. tests/models/jax/test_llama4.py +298 -0
  78. tests/models/jax/test_llama_eagle3.py +197 -0
  79. tests/models/jax/test_llama_guard_4.py +242 -0
  80. tests/models/jax/test_qwen2.py +172 -0
  81. tests/models/jax/test_qwen2_5_vl.py +605 -0
  82. tests/models/jax/test_qwen3.py +169 -0
  83. tests/models/jax/test_weight_loading.py +180 -0
  84. tests/models/jax/utils/__init__.py +13 -0
  85. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  86. tests/platforms/__init__.py +13 -0
  87. tests/platforms/test_tpu_platform.py +54 -0
  88. tests/runner/__init__.py +13 -0
  89. tests/runner/test_block_table.py +395 -0
  90. tests/runner/test_input_batch.py +226 -0
  91. tests/runner/test_kv_cache.py +220 -0
  92. tests/runner/test_kv_cache_manager.py +498 -0
  93. tests/runner/test_multimodal_manager.py +429 -0
  94. tests/runner/test_persistent_batch_manager.py +84 -0
  95. tests/runner/test_speculative_decoding_manager.py +368 -0
  96. tests/runner/test_structured_decoding_manager.py +220 -0
  97. tests/runner/test_tpu_runner.py +261 -0
  98. tests/runner/test_tpu_runner_dp.py +1099 -0
  99. tests/runner/test_tpu_runner_mesh.py +200 -0
  100. tests/runner/test_utils.py +411 -0
  101. tests/spec_decode/__init__.py +13 -0
  102. tests/spec_decode/test_eagle3.py +311 -0
  103. tests/test_base.py +14 -0
  104. tests/test_tpu_info.py +14 -0
  105. tests/test_utils.py +1 -43
  106. tests/worker/__init__.py +13 -0
  107. tests/worker/tpu_worker_test.py +414 -0
  108. tpu_inference/__init__.py +14 -0
  109. tpu_inference/core/__init__.py +13 -0
  110. tpu_inference/core/sched/__init__.py +13 -0
  111. tpu_inference/core/sched/dp_scheduler.py +372 -56
  112. tpu_inference/distributed/__init__.py +13 -0
  113. tpu_inference/distributed/jax_parallel_state.py +14 -0
  114. tpu_inference/distributed/tpu_connector.py +14 -9
  115. tpu_inference/distributed/utils.py +56 -4
  116. tpu_inference/executors/__init__.py +13 -0
  117. tpu_inference/executors/ray_distributed_executor.py +20 -3
  118. tpu_inference/experimental/__init__.py +13 -0
  119. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  120. tpu_inference/kernels/__init__.py +13 -0
  121. tpu_inference/kernels/collectives/__init__.py +13 -0
  122. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  123. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  124. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  125. tpu_inference/kernels/fused_moe/v1/kernel.py +171 -163
  126. tpu_inference/kernels/megablox/__init__.py +13 -0
  127. tpu_inference/kernels/megablox/common.py +54 -0
  128. tpu_inference/kernels/megablox/gmm.py +646 -0
  129. tpu_inference/kernels/mla/__init__.py +13 -0
  130. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  131. tpu_inference/kernels/mla/v1/kernel.py +20 -26
  132. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  133. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  134. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  135. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  136. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +112 -69
  137. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +85 -65
  138. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  139. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +374 -194
  140. tpu_inference/kernels/ragged_paged_attention/v3/util.py +13 -0
  141. tpu_inference/layers/__init__.py +13 -0
  142. tpu_inference/layers/common/__init__.py +13 -0
  143. tpu_inference/layers/common/attention_interface.py +26 -19
  144. tpu_inference/layers/common/attention_metadata.py +14 -0
  145. tpu_inference/layers/common/fused_moe_gmm.py +506 -0
  146. tpu_inference/layers/common/quant_methods.py +15 -0
  147. tpu_inference/layers/common/quantization.py +282 -0
  148. tpu_inference/layers/common/sharding.py +22 -3
  149. tpu_inference/layers/common/utils.py +94 -0
  150. tpu_inference/layers/jax/__init__.py +13 -0
  151. tpu_inference/layers/jax/attention/__init__.py +13 -0
  152. tpu_inference/layers/jax/attention/attention.py +19 -6
  153. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +52 -27
  154. tpu_inference/layers/jax/attention/gpt_oss_attention.py +19 -6
  155. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  156. tpu_inference/layers/jax/base.py +14 -0
  157. tpu_inference/layers/jax/constants.py +13 -0
  158. tpu_inference/layers/jax/layers.py +14 -0
  159. tpu_inference/layers/jax/misc.py +14 -0
  160. tpu_inference/layers/jax/moe/__init__.py +13 -0
  161. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  162. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  163. tpu_inference/layers/jax/moe/moe.py +43 -3
  164. tpu_inference/layers/jax/pp_utils.py +53 -0
  165. tpu_inference/layers/jax/rope.py +14 -0
  166. tpu_inference/layers/jax/rope_interface.py +14 -0
  167. tpu_inference/layers/jax/sample/__init__.py +13 -0
  168. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  169. tpu_inference/layers/jax/sample/sampling.py +15 -1
  170. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  171. tpu_inference/layers/jax/transformer_block.py +14 -0
  172. tpu_inference/layers/vllm/__init__.py +13 -0
  173. tpu_inference/layers/vllm/attention.py +4 -4
  174. tpu_inference/layers/vllm/fused_moe.py +100 -455
  175. tpu_inference/layers/vllm/linear.py +64 -0
  176. tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
  177. tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
  178. tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
  179. tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
  180. tpu_inference/layers/vllm/quantization/__init__.py +19 -3
  181. tpu_inference/layers/vllm/quantization/awq.py +96 -82
  182. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  183. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +19 -5
  184. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +119 -132
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
  188. tpu_inference/layers/vllm/quantization/{common.py → configs.py} +38 -26
  189. tpu_inference/layers/vllm/quantization/fp8.py +119 -0
  190. tpu_inference/layers/vllm/quantization/mxfp4.py +133 -220
  191. tpu_inference/layers/vllm/quantization/unquantized.py +154 -253
  192. tpu_inference/lora/__init__.py +13 -0
  193. tpu_inference/lora/torch_lora_ops.py +8 -13
  194. tpu_inference/models/__init__.py +13 -0
  195. tpu_inference/models/common/__init__.py +13 -0
  196. tpu_inference/models/common/model_loader.py +37 -16
  197. tpu_inference/models/jax/__init__.py +13 -0
  198. tpu_inference/models/jax/deepseek_v3.py +113 -124
  199. tpu_inference/models/jax/gpt_oss.py +23 -7
  200. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  201. tpu_inference/models/jax/llama3.py +99 -36
  202. tpu_inference/models/jax/llama4.py +14 -0
  203. tpu_inference/models/jax/llama_eagle3.py +14 -0
  204. tpu_inference/models/jax/llama_guard_4.py +15 -1
  205. tpu_inference/models/jax/qwen2.py +17 -2
  206. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  207. tpu_inference/models/jax/qwen3.py +17 -2
  208. tpu_inference/models/jax/utils/__init__.py +13 -0
  209. tpu_inference/models/jax/utils/file_utils.py +14 -0
  210. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  211. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +85 -24
  213. tpu_inference/models/jax/utils/weight_utils.py +32 -1
  214. tpu_inference/models/vllm/__init__.py +13 -0
  215. tpu_inference/models/vllm/vllm_model_wrapper.py +22 -4
  216. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  217. tpu_inference/platforms/__init__.py +14 -0
  218. tpu_inference/platforms/tpu_platform.py +27 -29
  219. tpu_inference/runner/__init__.py +13 -0
  220. tpu_inference/runner/compilation_manager.py +69 -35
  221. tpu_inference/runner/kv_cache.py +14 -0
  222. tpu_inference/runner/kv_cache_manager.py +15 -2
  223. tpu_inference/runner/lora_utils.py +16 -1
  224. tpu_inference/runner/multimodal_manager.py +16 -2
  225. tpu_inference/runner/persistent_batch_manager.py +14 -0
  226. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  227. tpu_inference/runner/structured_decoding_manager.py +14 -0
  228. tpu_inference/runner/tpu_runner.py +30 -10
  229. tpu_inference/spec_decode/__init__.py +13 -0
  230. tpu_inference/spec_decode/jax/__init__.py +13 -0
  231. tpu_inference/spec_decode/jax/eagle3.py +13 -0
  232. tpu_inference/tpu_info.py +14 -0
  233. tpu_inference/utils.py +31 -30
  234. tpu_inference/worker/__init__.py +13 -0
  235. tpu_inference/worker/tpu_worker.py +23 -7
  236. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +1 -1
  237. tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
  238. tpu_inference/layers/vllm/linear_common.py +0 -208
  239. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  240. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  241. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  242. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  243. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  244. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  245. tpu_inference-0.12.0.dev20251213.dist-info/RECORD +0 -175
  246. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
  247. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
  248. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,17 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  import os

  from vllm.utils.network_utils import get_ip
@@ -54,7 +68,45 @@ def get_side_channel_port() -> str:
      return port


- def get_node_id() -> int:
-     # TODO(xiang): Is it possible to get this from a pre-defiend env?
-     id = os.getenv("TPU_NODE_ID", 0)
-     return int(id)
+ def get_device_topology_order_id(local_devices, global_devices) -> int:
+     """
+     Calculates the topology order ID for the local device set within the global topology.
+
+     This function determines the rank of the current host/process based on the
+     coordinate of its TPU devices relative to all devices in the topology.
+
+     Args:
+         local_devices: A list of TpuDevice objects available to the current process.
+         global_devices: A list of all TpuDevice objects in the global topology.
+
+     Returns:
+         The topology order ID (rank) of the local devices.
+     """
+     if not local_devices:
+         raise ValueError("local_devices cannot be empty")
+     if not global_devices:
+         raise ValueError("global_devices cannot be empty")
+
+     # 1. Find the 'anchor' (minimum coordinate) for the local devices.
+     # This represents the physical top-left corner of the local machine.
+     local_anchor = min(d.coords for d in local_devices)
+
+     # 2. Group global devices by process to find the anchor for EVERY process.
+     process_anchors = {}
+     for d in global_devices:
+         pid = d.process_index
+         # Update the minimum coordinate found for this process so far
+         if pid not in process_anchors or d.coords < process_anchors[pid]:
+             process_anchors[pid] = d.coords
+
+     # 3. Sort the unique anchors to establish the canonical topology order.
+     # Tuples (x, y, z) sort lexicographically (x first, then y, then z).
+     sorted_anchors = sorted(process_anchors.values())
+
+     # 4. Return the index (rank) of the local anchor in the sorted list.
+     try:
+         return sorted_anchors.index(local_anchor)
+     except ValueError:
+         raise ValueError(
+             f"Local devices: {local_devices} do not exist in the global device: {global_devices} list."
+         )
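The new get_device_topology_order_id() replaces the TPU_NODE_ID environment variable with a rank derived purely from device coordinates: each process is reduced to its minimum ("anchor") coordinate, the anchors are sorted lexicographically, and the local anchor's position in that order is the rank. A small self-contained sketch of the same idea, using a hypothetical stand-in for the TpuDevice objects the real code receives from the JAX TPU runtime:

from dataclasses import dataclass

@dataclass
class FakeDevice:           # hypothetical stand-in for a TpuDevice
    coords: tuple           # (x, y, z) position in the physical topology
    process_index: int      # host/process that owns this device

# Two hosts on a 2x2x1 slice: process 1 owns the x=0 column, process 0 owns x=1.
global_devices = [
    FakeDevice((0, 0, 0), 1), FakeDevice((0, 1, 0), 1),
    FakeDevice((1, 0, 0), 0), FakeDevice((1, 1, 0), 0),
]
local_devices = [d for d in global_devices if d.process_index == 0]

# Anchor = minimum coordinate per process; rank = index of the local anchor
# in the lexicographically sorted list of all anchors.
local_anchor = min(d.coords for d in local_devices)
anchors = {}
for d in global_devices:
    if d.process_index not in anchors or d.coords < anchors[d.process_index]:
        anchors[d.process_index] = d.coords
rank = sorted(anchors.values()).index(local_anchor)
print(rank)  # 1 -- process 0's anchor (1, 0, 0) sorts after process 1's (0, 0, 0)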
@@ -0,0 +1,13 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
@@ -1,3 +1,17 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  import os
  from array import array
  from typing import Any, Dict, List, Optional
@@ -6,7 +20,7 @@ import ray
  import vllm.envs as envs
  from ray.util.placement_group import PlacementGroup
  from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
- from vllm.multimodal.inputs import MultiModalKwargs
+ from vllm.multimodal.inputs import MultiModalKwargsItem
  from vllm.platforms import current_platform
  from vllm.ray.ray_env import get_env_vars_to_copy
  from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE
@@ -39,7 +53,7 @@ logger = init_logger(__name__)


  def _encode_hook(obj: Any) -> Any:
-     """Custom msgspec enc hook that supports array types and MultiModalKwargs.
+     """Custom msgspec enc hook that supports array types and MultiModalKwargsItem.

      See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder
      """
@@ -48,7 +62,7 @@ def _encode_hook(obj: Any) -> Any:
              f"vLLM array type should use '{VLLM_TOKEN_ID_ARRAY_TYPE}' type. "
              f"Given array has a type code of {obj.typecode}.")
          return obj.tobytes()
-     if isinstance(obj, MultiModalKwargs):
+     if isinstance(obj, MultiModalKwargsItem):
          return dict(obj)


@@ -145,6 +159,9 @@ class RayDistributedExecutor(RayDistributedExecutorV1):
                  device_str: node['Resources'][device_str]
              } for node in ray_nodes]
          else:
+             assert pp_size == len(
+                 ray_nodes
+             ), f"Cannot use PP across hosts, please set --pipeline-parallel-size to 1 or {len(ray_nodes)}"
              num_devices_per_pp_rank = self.vllm_config.sharding_config.total_devices
              placement_group_specs = [{
                  device_str: num_devices_per_pp_rank
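The new assertion pins pipeline parallelism to whole hosts: each PP rank is backed by one placement-group bundle that claims all of a node's accelerators, so the PP degree must match the Ray node count. A toy illustration of that invariant (the values and the "TPU" resource key are made up; the real bundle construction lives in the surrounding executor code):

num_ray_nodes = 2          # hosts visible to Ray
devices_per_host = 4       # accelerators each bundle must claim
pp_size = 2                # --pipeline-parallel-size

# Same check as the diff above, expressed over the toy values.
assert pp_size == num_ray_nodes, (
    f"Cannot use PP across hosts, please set --pipeline-parallel-size "
    f"to 1 or {num_ray_nodes}")

placement_group_specs = [{"TPU": devices_per_host} for _ in range(pp_size)]
print(placement_group_specs)  # [{'TPU': 4}, {'TPU': 4}]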
@@ -0,0 +1,13 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
@@ -1,3 +1,17 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
  # TODO: Update documentation

  from typing import List, Optional, Tuple
@@ -0,0 +1,13 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
@@ -0,0 +1,13 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
@@ -0,0 +1,13 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
@@ -0,0 +1,13 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
@@ -0,0 +1,13 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
@@ -1,3 +1,16 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
  """TPU-Friendly Fused Mixture of Experts (MoE) kernel."""

  import functools
@@ -1376,171 +1389,166 @@ def fused_ep_moe(
      hbm_block_spec = pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM)
      renorm_str = "-renorm_k" if renormalize_topk_logits else ""
      scope_name = f"fused-moe-k_{top_k}{renorm_str}-bt_{bt}_{btc}-bf_{bf}_{bfc}-bd1_{bd1}_{bd1c}-bd2_{bd2}_{bd2c}"
-     fused_moe = jax.named_scope(scope_name)(
-         pl.pallas_call(
-             functools.partial(
-                 _fused_ep_moe_kernel,
-                 top_k=top_k,
-                 renormalize_topk_logits=renormalize_topk_logits,
-                 ep_axis_name=ep_axis_name,
-                 act_fn=act_fn,
-                 subc_quant_wsz=subc_quant_wsz,
-                 bt=bt,
-                 bf=bf,
-                 bd1=bd1,
-                 bd2=bd2,
-                 btc=btc,
-                 bfc=bfc,
-                 bd1c=bd1c,
-                 bd2c=bd2c,
-             ),
-             out_shape=jax.ShapeDtypeStruct((local_num_tokens, hidden_size),
-                                            t_dtype),
-             grid_spec=pltpu.PrefetchScalarGridSpec(
-                 num_scalar_prefetch=0,
-                 in_specs=[
-                     hbm_block_spec, # tokens_hbm
-                     hbm_block_spec, # w1_hbm
-                     hbm_block_spec, # w2_hbm
-                     None
-                     if w1_scale is None else hbm_block_spec, # w1_scale_hbm
-                     None
-                     if w2_scale is None else hbm_block_spec, # w2_scale_hbm
-                     None if b1 is None else hbm_block_spec, # b1_hbm
-                     None if b2 is None else hbm_block_spec, # b2_hbm
-                     hbm_block_spec, # gating_output_hbm
-                     hbm_block_spec, # a2a_g_hbm
-                 ],
-                 out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM),
-                 scratch_shapes=([
-                     # t2e_routing_x2_smem
-                     pltpu.SMEM((2, bt, padded_top_k), jnp.int32),
-                     # d2e_count_x2_smem
-                     pltpu.SMEM((2, num_devices, 1, padded_num_experts),
-                                jnp.int32),
-                     # expert_offsets_x2_smem
-                     pltpu.SMEM((2, 2, padded_num_experts), jnp.int32),
-                     # expert_starts_x2_smem
-                     pltpu.SMEM((2, 1, padded_num_experts), jnp.int32),
-                     # expert_sizes_x2_smem
-                     pltpu.SMEM((2, 1, padded_num_experts), jnp.int32),
-                     # a2a_s_sends_x2_smem
-                     pltpu.SMEM((2, ), jnp.int32),
-                     # a2a_s_x2_vmem
-                     pltpu.VMEM(
-                         (
-                             2,
-                             bt * num_devices,
-                             t_packing,
-                             hidden_size // t_packing,
-                         ),
-                         t_dtype,
+     fused_moe = pl.pallas_call(
+         functools.partial(
+             _fused_ep_moe_kernel,
+             top_k=top_k,
+             renormalize_topk_logits=renormalize_topk_logits,
+             ep_axis_name=ep_axis_name,
+             act_fn=act_fn,
+             subc_quant_wsz=subc_quant_wsz,
+             bt=bt,
+             bf=bf,
+             bd1=bd1,
+             bd2=bd2,
+             btc=btc,
+             bfc=bfc,
+             bd1c=bd1c,
+             bd2c=bd2c,
+         ),
+         out_shape=jax.ShapeDtypeStruct((local_num_tokens, hidden_size),
+                                        t_dtype),
+         grid_spec=pltpu.PrefetchScalarGridSpec(
+             num_scalar_prefetch=0,
+             in_specs=[
+                 hbm_block_spec, # tokens_hbm
+                 hbm_block_spec, # w1_hbm
+                 hbm_block_spec, # w2_hbm
+                 None if w1_scale is None else hbm_block_spec, # w1_scale_hbm
+                 None if w2_scale is None else hbm_block_spec, # w2_scale_hbm
+                 None if b1 is None else hbm_block_spec, # b1_hbm
+                 None if b2 is None else hbm_block_spec, # b2_hbm
+                 hbm_block_spec, # gating_output_hbm
+                 hbm_block_spec, # a2a_g_hbm
+             ],
+             out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM),
+             scratch_shapes=([
+                 # t2e_routing_x2_smem
+                 pltpu.SMEM((2, bt, padded_top_k), jnp.int32),
+                 # d2e_count_x2_smem
+                 pltpu.SMEM((2, num_devices, 1, padded_num_experts), jnp.int32),
+                 # expert_offsets_x2_smem
+                 pltpu.SMEM((2, 2, padded_num_experts), jnp.int32),
+                 # expert_starts_x2_smem
+                 pltpu.SMEM((2, 1, padded_num_experts), jnp.int32),
+                 # expert_sizes_x2_smem
+                 pltpu.SMEM((2, 1, padded_num_experts), jnp.int32),
+                 # a2a_s_sends_x2_smem
+                 pltpu.SMEM((2, ), jnp.int32),
+                 # a2a_s_x2_vmem
+                 pltpu.VMEM(
+                     (
+                         2,
+                         bt * num_devices,
+                         t_packing,
+                         hidden_size // t_packing,
                      ),
-                     # a2a_s_acc_x2_vmem
-                     pltpu.VMEM(
-                         (
-                             2,
-                             bt * num_devices,
-                             t_packing,
-                             hidden_size // t_packing,
-                         ),
-                         t_dtype,
+                     t_dtype,
+                 ),
+                 # a2a_s_acc_x2_vmem
+                 pltpu.VMEM(
+                     (
+                         2,
+                         bt * num_devices,
+                         t_packing,
+                         hidden_size // t_packing,
                      ),
-                     # a2a_g_acc_vmem
-                     pltpu.VMEM(
-                         (top_k, bt, t_packing, hidden_size // t_packing),
-                         t_dtype),
-                     # b_gating_x2_vmem
-                     pltpu.VMEM((2, bt, padded_num_experts), t_dtype),
-                     # b_output_x2_vmem
-                     pltpu.VMEM((2, bt, hidden_size), t_dtype),
-                     # b_w1_x2_vmem
-                     pltpu.VMEM((2, t_packing, bd1 // t_packing, bf), w1.dtype),
-                     # b_w3_x2_vmem
-                     pltpu.VMEM((2, t_packing, bd1 // t_packing, bf), w1.dtype),
-                     # b_w2_x2_vmem
-                     pltpu.VMEM((2, t_packing, bf, bd2 // t_packing), w2.dtype),
-                     # b_w1_scale_x2_vmem
-                     (None if w1_scale is None else pltpu.VMEM(
-                         (
-                             2,
-                             t_packing,
-                             bd1 // t_packing // subc_quant_wsz,
-                             1,
-                             bf,
-                         ),
-                         jnp.float32,
-                     )),
-                     # b_w3_scale_x2_vmem
-                     (None if w1_scale is None else pltpu.VMEM(
-                         (
-                             2,
-                             t_packing,
-                             bd1 // t_packing // subc_quant_wsz,
-                             1,
-                             bf,
-                         ),
-                         jnp.float32,
-                     )),
-                     # b_w2_scale_x2_vmem
-                     (None if w2_scale is None else pltpu.VMEM(
-                         (
-                             2,
-                             t_packing,
-                             bf // subc_quant_wsz,
-                             1,
-                             bd2 // t_packing,
-                         ),
-                         jnp.float32,
-                     )),
-                     # b_b1_x2_vmem
-                     (None if b1 is None else pltpu.VMEM(
-                         (
-                             2,
-                             1,
-                             bf,
-                         ),
-                         jnp.float32,
-                     )),
-                     # b_b3_x2_vmem
-                     (None if b1 is None else pltpu.VMEM(
-                         (
-                             2,
-                             1,
-                             bf,
-                         ),
-                         jnp.float32,
-                     )),
-                     # b_b2_x2_vmem
-                     (None if b2 is None else pltpu.VMEM(
-                         (
-                             2,
-                             t_packing,
-                             1,
-                             bd2 // t_packing,
-                         ),
-                         jnp.float32,
-                     )),
-                     # b_acc_vmem
-                     pltpu.VMEM((bt * num_devices, 1, bf * 2), jnp.float32),
-                     # local_sems
-                     pltpu.SemaphoreType.DMA((2, 5)),
-                     # send_sems
-                     pltpu.SemaphoreType.DMA((2, )),
-                     # recv_sems
-                     pltpu.SemaphoreType.DMA((2, )),
-                     # a2a_gather_sem
-                     pltpu.SemaphoreType.DMA,
-                     # a2a_acc_sem
-                     pltpu.SemaphoreType.DMA,
-                 ]),
-             ),
-             compiler_params=pltpu.CompilerParams(
-                 collective_id=0,
-                 vmem_limit_bytes=100 * 1024 * 1024,
-             ),
-             name=scope_name,
-         ))
+                     t_dtype,
+                 ),
+                 # a2a_g_acc_vmem
+                 pltpu.VMEM((top_k, bt, t_packing, hidden_size // t_packing),
+                            t_dtype),
+                 # b_gating_x2_vmem
+                 pltpu.VMEM((2, bt, padded_num_experts), t_dtype),
+                 # b_output_x2_vmem
+                 pltpu.VMEM((2, bt, hidden_size), t_dtype),
+                 # b_w1_x2_vmem
+                 pltpu.VMEM((2, t_packing, bd1 // t_packing, bf), w1.dtype),
+                 # b_w3_x2_vmem
+                 pltpu.VMEM((2, t_packing, bd1 // t_packing, bf), w1.dtype),
+                 # b_w2_x2_vmem
+                 pltpu.VMEM((2, t_packing, bf, bd2 // t_packing), w2.dtype),
+                 # b_w1_scale_x2_vmem
+                 (None if w1_scale is None else pltpu.VMEM(
+                     (
+                         2,
+                         t_packing,
+                         bd1 // t_packing // subc_quant_wsz,
+                         1,
+                         bf,
+                     ),
+                     jnp.float32,
+                 )),
+                 # b_w3_scale_x2_vmem
+                 (None if w1_scale is None else pltpu.VMEM(
+                     (
+                         2,
+                         t_packing,
+                         bd1 // t_packing // subc_quant_wsz,
+                         1,
+                         bf,
+                     ),
+                     jnp.float32,
+                 )),
+                 # b_w2_scale_x2_vmem
+                 (None if w2_scale is None else pltpu.VMEM(
+                     (
+                         2,
+                         t_packing,
+                         bf // subc_quant_wsz,
+                         1,
+                         bd2 // t_packing,
+                     ),
+                     jnp.float32,
+                 )),
+                 # b_b1_x2_vmem
+                 (None if b1 is None else pltpu.VMEM(
+                     (
+                         2,
+                         1,
+                         bf,
+                     ),
+                     jnp.float32,
+                 )),
+                 # b_b3_x2_vmem
+                 (None if b1 is None else pltpu.VMEM(
+                     (
+                         2,
+                         1,
+                         bf,
+                     ),
+                     jnp.float32,
+                 )),
+                 # b_b2_x2_vmem
+                 (None if b2 is None else pltpu.VMEM(
+                     (
+                         2,
+                         t_packing,
+                         1,
+                         bd2 // t_packing,
+                     ),
+                     jnp.float32,
+                 )),
+                 # b_acc_vmem
+                 pltpu.VMEM((bt * num_devices, 1, bf * 2), jnp.float32),
+                 # local_sems
+                 pltpu.SemaphoreType.DMA((2, 5)),
+                 # send_sems
+                 pltpu.SemaphoreType.DMA((2, )),
+                 # recv_sems
+                 pltpu.SemaphoreType.DMA((2, )),
+                 # a2a_gather_sem
+                 pltpu.SemaphoreType.DMA,
+                 # a2a_acc_sem
+                 pltpu.SemaphoreType.DMA,
+             ]),
+         ),
+         compiler_params=pltpu.CompilerParams(
+             collective_id=0,
+             vmem_limit_bytes=100 * 1024 * 1024,
+         ),
+         name=scope_name,
+     )

      @jax.jit
      @jax.shard_map(
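Functionally the rewrite above is a simplification: pl.pallas_call already accepts a name argument, so the extra jax.named_scope(scope_name) wrapper around the call is dropped and the body is dedented one level. A toy Pallas kernel showing the name= pattern in isolation (interpret mode is assumed so the sketch also runs without a TPU; it is not the MoE kernel itself):

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_one_kernel(x_ref, o_ref):
    # Trivial kernel body: read the input block, add one, write the output.
    o_ref[...] = x_ref[...] + 1.0


x = jnp.arange(8, dtype=jnp.float32)
add_one = pl.pallas_call(
    add_one_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    name="add-one",    # label for traces/profiles, no named_scope wrapper needed
    interpret=True,    # run in interpret mode for portability off-TPU
)
print(add_one(x))      # [1. 2. 3. 4. 5. 6. 7. 8.]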
@@ -0,0 +1,13 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.