tpu-inference 0.11.1.dev202511270815-py3-none-any.whl → 0.13.0rc2.post7-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (251)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +405 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +312 -0
  64. tests/layers/vllm/test_unquantized.py +651 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +21 -3
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +110 -12
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +2 -45
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +15 -10
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +92 -8
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +22 -1
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +167 -97
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +26 -19
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/quant_methods.py +15 -0
  154. tpu_inference/layers/common/quantization.py +270 -0
  155. tpu_inference/layers/common/sharding.py +31 -9
  156. tpu_inference/layers/jax/__init__.py +13 -0
  157. tpu_inference/layers/jax/attention/__init__.py +13 -0
  158. tpu_inference/layers/jax/attention/attention.py +19 -6
  159. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  160. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  161. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  162. tpu_inference/layers/jax/base.py +14 -0
  163. tpu_inference/layers/jax/constants.py +13 -0
  164. tpu_inference/layers/jax/layers.py +14 -0
  165. tpu_inference/layers/jax/misc.py +14 -0
  166. tpu_inference/layers/jax/moe/__init__.py +13 -0
  167. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  168. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  169. tpu_inference/layers/jax/moe/moe.py +43 -3
  170. tpu_inference/layers/jax/pp_utils.py +53 -0
  171. tpu_inference/layers/jax/rope.py +14 -0
  172. tpu_inference/layers/jax/rope_interface.py +14 -0
  173. tpu_inference/layers/jax/sample/__init__.py +13 -0
  174. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  175. tpu_inference/layers/jax/sample/sampling.py +15 -1
  176. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  177. tpu_inference/layers/jax/transformer_block.py +14 -0
  178. tpu_inference/layers/vllm/__init__.py +13 -0
  179. tpu_inference/layers/vllm/attention.py +4 -4
  180. tpu_inference/layers/vllm/fused_moe.py +210 -260
  181. tpu_inference/layers/vllm/linear_common.py +57 -22
  182. tpu_inference/layers/vllm/quantization/__init__.py +16 -0
  183. tpu_inference/layers/vllm/quantization/awq.py +15 -1
  184. tpu_inference/layers/vllm/quantization/common.py +33 -18
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
  188. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  189. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
  191. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  192. tpu_inference/layers/vllm/quantization/mxfp4.py +280 -210
  193. tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
  194. tpu_inference/layers/vllm/sharding.py +21 -4
  195. tpu_inference/lora/__init__.py +13 -0
  196. tpu_inference/lora/torch_lora_ops.py +8 -13
  197. tpu_inference/models/__init__.py +13 -0
  198. tpu_inference/models/common/__init__.py +13 -0
  199. tpu_inference/models/common/model_loader.py +77 -36
  200. tpu_inference/models/jax/__init__.py +13 -0
  201. tpu_inference/models/jax/deepseek_v3.py +267 -157
  202. tpu_inference/models/jax/gpt_oss.py +26 -10
  203. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  204. tpu_inference/models/jax/llama3.py +99 -36
  205. tpu_inference/models/jax/llama4.py +14 -0
  206. tpu_inference/models/jax/llama_eagle3.py +14 -0
  207. tpu_inference/models/jax/llama_guard_4.py +15 -1
  208. tpu_inference/models/jax/qwen2.py +17 -2
  209. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  210. tpu_inference/models/jax/qwen3.py +17 -2
  211. tpu_inference/models/jax/utils/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/file_utils.py +14 -0
  213. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  214. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  215. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +91 -31
  216. tpu_inference/models/jax/utils/weight_utils.py +39 -2
  217. tpu_inference/models/vllm/__init__.py +13 -0
  218. tpu_inference/models/vllm/vllm_model_wrapper.py +20 -4
  219. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  220. tpu_inference/platforms/__init__.py +14 -0
  221. tpu_inference/platforms/tpu_platform.py +47 -71
  222. tpu_inference/runner/__init__.py +13 -0
  223. tpu_inference/runner/compilation_manager.py +158 -63
  224. tpu_inference/runner/kv_cache.py +54 -20
  225. tpu_inference/runner/kv_cache_manager.py +53 -30
  226. tpu_inference/runner/lora_utils.py +14 -0
  227. tpu_inference/runner/multimodal_manager.py +15 -1
  228. tpu_inference/runner/persistent_batch_manager.py +54 -2
  229. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  230. tpu_inference/runner/structured_decoding_manager.py +14 -0
  231. tpu_inference/runner/tpu_runner.py +105 -57
  232. tpu_inference/runner/utils.py +2 -2
  233. tpu_inference/spec_decode/__init__.py +13 -0
  234. tpu_inference/spec_decode/jax/__init__.py +13 -0
  235. tpu_inference/spec_decode/jax/eagle3.py +65 -19
  236. tpu_inference/tpu_info.py +14 -0
  237. tpu_inference/utils.py +72 -44
  238. tpu_inference/worker/__init__.py +13 -0
  239. tpu_inference/worker/tpu_worker.py +65 -52
  240. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/METADATA +11 -9
  241. tpu_inference-0.13.0rc2.post7.dist-info/RECORD +261 -0
  242. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  243. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  244. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  245. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  246. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  247. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  248. tpu_inference-0.11.1.dev202511270815.dist-info/RECORD +0 -174
  249. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/WHEEL +0 -0
  250. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/licenses/LICENSE +0 -0
  251. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,16 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 """TPU-Friendly Fused Mixture of Experts (MoE) kernel."""

 import functools
@@ -19,7 +32,8 @@ def align_to(x, a):


 def get_dtype_packing(dtype):
-    bits = dtypes.bit_width(dtype)
+    bits = (dtypes.bit_width(dtype)
+            if hasattr(dtypes, "bit_width") else dtypes.itemsize_bits(dtype))
     return 32 // bits

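
Note on the change above: get_dtype_packing computes how many values of a dtype fit into one 32-bit word, and the new hasattr guard keeps it working across JAX versions that expose either dtypes.bit_width or dtypes.itemsize_bits. A minimal standalone sketch of the same rule, using the byte width instead (valid for byte-aligned dtypes only; sub-byte dtypes need the bit-width APIs the kernel prefers):

    import numpy as np
    import jax.numpy as jnp

    def packing_of(dtype) -> int:
        # 32-bit word divided by the dtype width in bits.
        return 32 // (np.dtype(dtype).itemsize * 8)

    assert packing_of(jnp.float32) == 1   # 1 value per 32-bit word
    assert packing_of(jnp.bfloat16) == 2  # 2 values per word
    assert packing_of(jnp.int8) == 4      # 4 values per word
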
@@ -65,18 +79,19 @@ def ref_moe(
     top_k: int,
     *,
     renormalize_topk_logits: bool = False,
-    activation="silu",
+    act_fn: str = "silu",
     subc_quant_wsz: int | None = None,
     w1_scale:
     (
         jax.Array | None
-    ) = None,  # (num_experts, 2, cdiv(hidden_size, subc_quant_wsz), intermediate_size)
+    ) = None,  # F32(num_experts, 2, hidden_size // subc_quant_wsz, 1, intermediate_size)
     w2_scale:
     (
         jax.Array | None
-    ) = None,  # (num_experts, cdiv(intermediate_size, subc_quant_wsz), hidden_size)
-    b1: jax.Array | None = None,  # (num_experts, 2, intermediate_size)
-    b2: jax.Array | None = None,  # (num_experts, hidden_size)
+    ) = None,  # F32(num_experts, intermediate_size // subc_quant_wsz, 1, hidden_size)
+    b1: jax.Array
+    | None = None,  # F32(num_experts, 2, 1, intermediate_size)
+    b2: jax.Array | None = None,  # F32(num_experts, 1, hidden_size)
 ):
     n_tokens = tokens.shape[0]  # num_tokens

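
The comments above document the new scale and bias layouts: each scale tensor now carries an explicit singleton axis and exactly hidden_size // subc_quant_wsz (or intermediate_size // subc_quant_wsz) sub-channel blocks, replacing the old cdiv-based shapes. A sketch of arrays matching the documented shapes, with illustrative sizes that are not taken from the diff:

    import jax.numpy as jnp

    num_experts, hidden_size, intermediate_size = 8, 1024, 512
    subc_quant_wsz = 256  # sub-channel quantization window

    w1_scale = jnp.ones(
        (num_experts, 2, hidden_size // subc_quant_wsz, 1, intermediate_size),
        jnp.float32)
    w2_scale = jnp.ones(
        (num_experts, intermediate_size // subc_quant_wsz, 1, hidden_size),
        jnp.float32)
    b1 = jnp.zeros((num_experts, 2, 1, intermediate_size), jnp.float32)
    b2 = jnp.zeros((num_experts, 1, hidden_size), jnp.float32)
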
@@ -97,7 +112,7 @@ def ref_moe(

     # Process each token individually
     for i in range(n_tokens):
-        curr_token = jnp.expand_dims(tokens[i], axis=0)  # [1, d_model]
+        curr_token = jnp.expand_dims(tokens[i], axis=0)  # [1, hidden_size]
         assigned_expert_ids = top_k_indices[
             i]  # [top_k] - indices of selected experts for token i
         tok_expert_act = []
@@ -108,19 +123,19 @@ def ref_moe(
             expert_w1 = w1[expert_id, 0].astype(jnp.float32)
             expert_w3 = w1[expert_id, 1].astype(jnp.float32)
             if w1_scale is not None:
-                expert_w1 *= jnp.repeat(w1_scale[expert_id, 0],
+                expert_w1 *= jnp.repeat(w1_scale[expert_id, 0, :, 0],
                                         subc_quant_wsz,
                                         axis=0)[:hidden_size]
-                expert_w3 *= jnp.repeat(w1_scale[expert_id, 1],
+                expert_w3 *= jnp.repeat(w1_scale[expert_id, 1, :, 0],
                                         subc_quant_wsz,
                                         axis=0)[:hidden_size]
             expert_weight_1 = jnp.concat(
                 [expert_w1, expert_w3],
-                axis=-1)  # [d_model, 2 * intermediate_size]
+                axis=-1)  # [hidden_size, 2 * intermediate_size]
             expert_weight_2 = w2[expert_id].astype(
-                jnp.float32)  # [intermediate_size, d_model]
+                jnp.float32)  # [intermediate_size, hidden_size]
             if w2_scale is not None:
-                expert_weight_2 *= jnp.repeat(w2_scale[expert_id],
+                expert_weight_2 *= jnp.repeat(w2_scale[expert_id, :, 0],
                                               subc_quant_wsz,
                                               axis=0)[:intermediate_size]

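
The indexing change above (for example w1_scale[expert_id, 0, :, 0]) just drops the new singleton axis before expansion; the jnp.repeat dequantization pattern itself is unchanged. In isolation:

    import jax.numpy as jnp

    hidden_size, intermediate_size, subc_quant_wsz = 512, 256, 256
    # One scale per subc_quant_wsz-sized block of the contracting dimension.
    scale = jnp.full((hidden_size // subc_quant_wsz, intermediate_size), 0.5)
    # Expand block scales back to one scale per row of the weight matrix.
    full = jnp.repeat(scale, subc_quant_wsz, axis=0)[:hidden_size]
    assert full.shape == (hidden_size, intermediate_size)
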
@@ -132,32 +147,33 @@ def ref_moe(
                 gmm_1_out, 2,
                 axis=-1)  # [1, intermediate_size], [1, intermediate_size]
             if b1 is not None:
-                gmm1_w1_proj += b1[expert_id:expert_id + 1, 0]
-                gmm1_w3_proj += b1[expert_id:expert_id + 1, 1]
+                gmm1_w1_proj += b1[expert_id:expert_id + 1, 0, 0]
+                gmm1_w3_proj += b1[expert_id:expert_id + 1, 1, 0]

             # Apply gated activation: activation(gate) * up
-            act = activation_fn(gmm1_w1_proj, gmm1_w3_proj, activation)
+            act = activation_fn(gmm1_w1_proj, gmm1_w3_proj, act_fn)

             # Second linear layer (down projection)
-            gmm_2_out = act @ expert_weight_2  # [1, d_model]
+            gmm_2_out = act @ expert_weight_2  # [1, hidden_size]
             if b2 is not None:
-                gmm_2_out += b2[expert_id:expert_id + 1]
+                gmm_2_out += b2[expert_id:expert_id + 1, 0]
             tok_expert_act.append(gmm_2_out)

         # Combine outputs from all selected experts
         experts_act = jnp.concatenate(tok_expert_act,
-                                      axis=0)  # [top_k, d_model]
+                                      axis=0)  # [top_k, hidden_size]

         # Weighted sum using top-k gating weights
         top_k_weights = top_k_logits[i]  # [top_k]
         top_k_weights = jnp.expand_dims(top_k_weights, axis=1)  # [top_k, 1]
         weighted_output = jnp.sum(experts_act * top_k_weights,
                                   axis=0,
-                                  keepdims=True)  # [1, d_model]
+                                  keepdims=True)  # [1, hidden_size]

         t_outputs.append(weighted_output.astype(tokens.dtype))

-    return jnp.concatenate(t_outputs, axis=0)  # [num_tokens, d_model]
+    return jnp.concatenate(t_outputs,
+                           axis=0)  # [actual_num_tokens, hidden_size]


 def _fused_ep_moe_kernel(
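
For reference, the per-token combine that the renamed comments describe reduces the stacked expert outputs with the gating weights:

    import jax.numpy as jnp

    top_k, hidden_size = 2, 16
    experts_act = jnp.ones((top_k, hidden_size))  # per-expert outputs
    top_k_weights = jnp.array([0.75, 0.25])       # gating weights
    combined = jnp.sum(experts_act * top_k_weights[:, None],
                       axis=0, keepdims=True)     # [1, hidden_size]
    assert combined.shape == (1, hidden_size)
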
@@ -177,7 +193,7 @@ def _fused_ep_moe_kernel(
     # Output
     output_hbm,  # (local_num_tokens, hidden_size)
     # Scratch
-    t2e_routing_x2_smem,  # <bt_sem_id> (2, bt, padded_num_experts)
+    t2e_routing_x2_smem,  # <bt_sem_id> (2, bt, padded_top_k)
     d2e_count_x2_smem,  # <bt_sem_id> (2, num_devices, 1, padded_num_experts)
     expert_offsets_x2_smem,  # <bt_sem_id> (2, 2, padded_num_experts): for a2a_s and a2a_g
     expert_starts_x2_smem,  # <bt_sem_id> (2, 1, padded_num_experts)
@@ -227,6 +243,11 @@ def _fused_ep_moe_kernel(
     local_num_tokens = tokens_hbm.shape[0]
     local_num_experts, intermediate_size, hidden_size = w2_hbm.shape
     right_id = (my_id + 1) % num_devices
+    num_experts = a2a_g_hbm.shape[0]
+    padded_num_experts = d2e_count_x2_smem.shape[-1]
+    padded_top_k = t2e_routing_x2_smem.shape[-1]
+    assert padded_num_experts == align_to(num_experts, 128)
+    assert padded_top_k == align_to(top_k, 128)

     t_dtype = tokens_hbm.dtype
     t_packing = get_dtype_packing(t_dtype)
@@ -300,35 +321,40 @@ def _fused_ep_moe_kernel(
     def get_top_k(input, top_k, renormalize_topk_logits):
         assert len(input.shape) == 2, input.shape
         input = input.astype(jnp.float32)
+        padded_k_shape = (input.shape[0], padded_top_k)
         top_k_logits_lst = []
         top_k_indices_lst = []
         t2e = jnp.zeros(input.shape, dtype=jnp.int32)
-        t2e_routing = jnp.zeros(input.shape, dtype=jnp.int32)
+        t2e_routing = jnp.zeros(padded_k_shape, dtype=jnp.int32)
         iota = jax.lax.broadcasted_iota(jnp.int32, input.shape, 1)
-        top_k_logits_sum = jnp.zeros((input.shape[0], 128), jnp.float32)
+        padded_k_iota = jax.lax.broadcasted_iota(jnp.int32, padded_k_shape, 1)
+        top_k_logits_sum = jnp.zeros(padded_k_shape, jnp.float32)

         for k_id in range(top_k):
             # TODO(jevinjiang): return both top_k values and indices in Mosaic
             top_k_logits = jnp.broadcast_to(
-                jnp.max(input, axis=1, keepdims=True),
-                (input.shape[0], 128)).astype(input.dtype)
+                jnp.max(input[:, :num_experts], axis=1, keepdims=True),
+                padded_k_shape,
+            ).astype(input.dtype)
+            top_k_logits_lst.append(top_k_logits)
             if renormalize_topk_logits:
                 top_k_logits_sum += top_k_logits
-            top_k_logits_lst.append(top_k_logits)
             # TODO(jevinjiang): support bf16 argmax in Mosaic
             top_k_indices = jnp.broadcast_to(
-                jnp.argmax(input, axis=1, keepdims=True), input.shape)
+                jnp.argmax(input[:, :num_experts], axis=1, keepdims=True),
+                padded_k_shape,
+            )
             top_k_indices_lst.append(top_k_indices)
-            t2e_routing = jnp.where(iota == k_id, top_k_indices, t2e_routing)
-            mask = iota == top_k_indices
+            t2e_routing = jnp.where(padded_k_iota == k_id, top_k_indices,
+                                    t2e_routing)
+            mask = iota == broadcast_minor(top_k_indices, input.shape)
             t2e += mask.astype(jnp.int32)
             if k_id != top_k - 1:
                 input = jnp.where(mask, -jnp.inf, input)

         if renormalize_topk_logits:
             for k_id in range(top_k):
-                top_k_logits_lst[
-                    k_id] = top_k_logits_lst[k_id] / top_k_logits_sum
+                top_k_logits_lst[k_id] /= top_k_logits_sum

         expert_sizes = jnp.sum(t2e, axis=0, keepdims=True)
         expert_starts = jnp.zeros_like(expert_sizes)
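
The rewritten get_top_k keeps the kernel's iterative max/argmax/mask scheme but pads the minor (lane) dimension to padded_top_k = align_to(top_k, 128) instead of a hard-coded 128, and keeps padded logits out of the reductions via input[:, :num_experts]. A plain-JAX sketch of the unpadded algorithm:

    import jax
    import jax.numpy as jnp

    def iterative_top_k(logits, top_k):
        vals, idxs = [], []
        for _ in range(top_k):
            idxs.append(jnp.argmax(logits, axis=1))
            vals.append(jnp.max(logits, axis=1))
            # Mask the winner out so the next round finds the runner-up.
            mask = jax.nn.one_hot(idxs[-1], logits.shape[1], dtype=jnp.bool_)
            logits = jnp.where(mask, -jnp.inf, logits)
        return jnp.stack(vals, axis=1), jnp.stack(idxs, axis=1)

    v, i = iterative_top_k(jnp.array([[0.1, 0.9, 0.3, 0.5]]), top_k=2)
    # v == [[0.9, 0.5]], i == [[1, 3]]
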
@@ -1071,27 +1097,38 @@ def _fused_ep_moe_kernel(

     all_reduce_metadata(bt_sem_id, t2e_routing, expert_starts,
                         expert_sizes)
+    sync_barrier()

+    # Start a2a scatter for first active expert.
     start_a2a_scatter(bt_id=bt_id, e_sem_id=e_sem_id, local_e_id=0)

     def run_per_expert(local_e_id, e_sem_id):
         sync_barrier()
+
+        # Prefetch weights for CURRENT active expert.
+        # TODO(jevinjiang): It is hard to prefetch weights in previous iteration
+        # because the expert_ffn keeps overwriting the buffers. Triple buffering
+        # could resolve this but it takes more VMEM scratch. Need further
+        # experiment on this.
+        start_fetch_bw1(local_e_id, bw1_sem_id=0, bf_id=0, bd1_id=0)
+        start_fetch_bw3(local_e_id, bw3_sem_id=0, bf_id=0, bd3_id=0)
+
+        # Next ids.
         next_e_sem_id = lax.select(e_sem_id == 0, 1, 0)
         next_local_e_id = local_e_id + 1

+        # Start a2a scatter for NEXT active expert.
         @pl.when(next_local_e_id < local_num_experts)
         def _():
             start_a2a_scatter(bt_id, next_e_sem_id, next_local_e_id)

-        # Prefetch weights for active expert.
-        start_fetch_bw1(local_e_id, bw1_sem_id=0, bf_id=0, bd1_id=0)
-        start_fetch_bw3(local_e_id, bw3_sem_id=0, bf_id=0, bd3_id=0)
-
-        # Wait for a2a scatter and perform FFN for active expert.
+        # Wait a2a scatter for CURRENT active expert.
         wait_a2a_scatter_recv(bt_id, e_sem_id, local_e_id)
+
+        # Perform FFN for CURRENT active expert.
         expert_ffn(bt_id, e_sem_id, local_e_id)

-        # Wait for a2a gather to send back tokens for active expert.
+        # Start a2a gather to send back tokens for CURRENT active expert.
         start_a2a_gather(bt_id, e_sem_id, local_e_id)

         # A must-wait before next sync_barrier.
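
The reordering above issues the weight prefetch before the next-expert a2a scatter, so both DMAs are in flight while the current expert's FFN runs; the scratch buffers stay double-buffered, indexed by a slot id that flips each iteration. The flip, written after the kernel's lax.select pattern:

    import jax.numpy as jnp
    from jax import lax

    def next_slot(e_sem_id):
        # Equivalent to (e_sem_id + 1) % 2.
        return lax.select(e_sem_id == 0, jnp.int32(1), jnp.int32(0))

    assert int(next_slot(jnp.int32(0))) == 1
    assert int(next_slot(jnp.int32(1))) == 0
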
@@ -1104,7 +1141,10 @@ def _fused_ep_moe_kernel(
                          e_sem_id,
                          unroll=False)

+    # Wait to receive a2a gather for ALL experts.
     wait_a2a_gather_recv_all()
+
+    # Accumulate results for current batch.
     output = bt_acc(bt_id, top_k_logits_lst)

     # Make sure it is safe to overwrite output buffer.
@@ -1158,18 +1198,18 @@ def fused_ep_moe(
     w2: jax.Array,  # (num_experts, intermediate_size, hidden_size)
     gating_output: jax.Array,  # (num_tokens, num_experts)
     top_k: int,
+    *,
     renormalize_topk_logits: bool = False,
     act_fn: str = "silu",
-    *,
     subc_quant_wsz: int | None = None,
     w1_scale: (
         jax.Array | None
-    ) = None,  # (num_experts, 2, cdiv(hidden_size, subc_quant_wsz), intermediate_size)
+    ) = None,  # F32(num_experts, 2, hidden_size // subc_quant_wsz, 1, intermediate_size)
     w2_scale: (
         jax.Array | None
-    ) = None,  # (num_experts, cdiv(intermediate_size, subc_quant_wsz), hidden_size)
-    b1: jax.Array | None = None,  # (num_experts, 2, intermediate_size)
-    b2: jax.Array | None = None,  # (num_experts, hidden_size)
+    ) = None,  # F32(num_experts, intermediate_size // subc_quant_wsz, 1, hidden_size)
+    b1: jax.Array | None = None,  # F32(num_experts, 2, 1, intermediate_size)
+    b2: jax.Array | None = None,  # F32(num_experts, 1, hidden_size)
     # Kernel tuning parameters.
     bt: int,
     bf: int,
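
Moving the bare * up makes renormalize_topk_logits and act_fn keyword-only, so existing positional call sites fail loudly instead of silently binding to the wrong parameter. A generic illustration of the boundary (toy function, not the real signature):

    def f(x, top_k, *, renormalize_topk_logits=False, act_fn="silu"):
        return act_fn

    f(0, 2, act_fn="gelu")     # OK
    try:
        f(0, 2, True, "gelu")  # positional use is now a TypeError
    except TypeError:
        pass
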
@@ -1182,75 +1222,159 @@ def fused_ep_moe(
     ep_axis_name: str = "model",
 ):
     # TODO(jevinjiang): move all these assertions to validation function.
-    # Assert all other axes have length of 1
-    assert len(mesh.shape) == 2, "Expect 2D mesh"
-    assert ("data" in mesh.shape
-            and mesh.shape["data"] == 1), "Expect data axis size of 1"
+    if len(mesh.shape) != 2:
+        raise NotImplementedError("Only 2D mesh is supported.")
+
+    for axis_name in mesh.axis_names:
+        if axis_name == ep_axis_name:
+            continue
+        if mesh.shape[axis_name] != 1:
+            raise NotImplementedError(
+                f"Expected all non-ep axis to have size 1 in {mesh.shape=}")

     ep_size = mesh.shape[ep_axis_name]
     num_devices = ep_size

-    num_tokens, actual_hidden_size = tokens.shape
-    num_experts, actual_intermediate_size, _ = w2.shape
+    num_tokens, hidden_size = tokens.shape
+    num_experts, intermediate_size, _ = w2.shape
+
+    if w1.shape != (num_experts, 2, hidden_size, intermediate_size):
+        raise ValueError(
+            f"Expected {w1.shape=} to be"
+            f" {(num_experts, 2, hidden_size, intermediate_size)}.")
+
+    if w2.shape != (num_experts, intermediate_size, hidden_size):
+        raise ValueError(f"Expected {w2.shape=} to be"
+                         f" {(num_experts, intermediate_size, hidden_size)}.")

-    assert num_tokens % ep_size == 0
-    assert num_experts % ep_size == 0
+    if gating_output.shape != (num_tokens, num_experts):
+        raise ValueError(
+            f"Expected {gating_output.shape=} to be {(num_tokens, num_experts)}."
+        )
+
+    if not (0 < top_k <= num_experts):
+        raise ValueError(
+            f"Expected {top_k=} to be in range (0, {num_experts=}].")
+
+    if hidden_size % 128 != 0 or intermediate_size % 128 != 0:
+        raise ValueError(
+            f"Expected {hidden_size=} and {intermediate_size=} to be aligned to"
+            " 128. Did you pad them with zeros outside the kernel?")
+    if num_tokens % ep_size != 0:
+        raise ValueError(
+            f"Expected {num_tokens=} to be aligned to {ep_size=}.")
+    if num_experts % ep_size != 0:
+        raise ValueError(
+            f"Expected {num_experts=} to be aligned to {ep_size=}.")

     local_num_tokens = num_tokens // ep_size
     # local_num_experts = num_experts // ep_size
     padded_num_experts = align_to(num_experts, 128)
+    padded_top_k = align_to(top_k, 128)
     t_dtype = tokens.dtype
     t_packing = get_dtype_packing(t_dtype)

+    # Override bt
+    if local_num_tokens <= t_packing * 8:
+        bt = local_num_tokens
+        btc = bt
+    bt = min(local_num_tokens, bt)
+    # The worst case is that all devices send bt to one device.
+    btc = min(bt, btc, bt * num_devices)
+
+    if local_num_tokens % t_packing != 0:
+        raise ValueError(
+            f"Expected {local_num_tokens=} to be aligned to {t_packing=}.")
+
+    if bt % t_packing != 0:
+        raise ValueError(f"Expected {bt=} to be aligned to {t_packing=}.")
+    if local_num_tokens % bt != 0:
+        raise ValueError(
+            f"Expected {local_num_tokens=} to be aligned to {bt=}.")
+
     if subc_quant_wsz is not None:
+        if subc_quant_wsz <= 0:
+            raise ValueError(f"Expected {subc_quant_wsz=} to be positive.")
         if subc_quant_wsz % 256 != 0:
-            raise NotImplementedError(
-                "Sub-quantized window is not aligned to 256.")
-        # We force compute size of contracting dim to subc_quant_wsz. So we can
+            raise ValueError(
+                f"Expected {subc_quant_wsz=} to be aligned to 256.")
+        if hidden_size % subc_quant_wsz != 0:
+            raise ValueError(
+                f"Expected {hidden_size=} to be aligned to {subc_quant_wsz=}.")
+        if intermediate_size % subc_quant_wsz != 0:
+            raise ValueError(
+                f"Expected {intermediate_size=} to be aligned to {subc_quant_wsz=}."
+            )
+        # We force compute size of contracting dim to be subc_quant_wsz. So we can
         # apply same scale after matmul and accumulation.
         bd1c = subc_quant_wsz * t_packing
         bfc = subc_quant_wsz

-    assert bfc % 128 == 0
-    assert bd1c % (t_packing * 128) == 0
-    assert bd2c % (t_packing * 128) == 0
-    assert bf % bfc == 0
-    assert bd1 % bd1c == 0
-    assert bd2 % bd2c == 0
-
-    btc = min(btc, bt * num_devices)
-    hidden_size = align_to(actual_hidden_size, 128 * t_packing)
-    # TODO(jevinjiang): instead of padding outside the kernel, we can try dynammic
-    # masking inside the kernel.
-    hidden_size = align_to(hidden_size, bd1)
-    hidden_size = align_to(hidden_size, bd2)
-    intermediate_size = align_to(actual_intermediate_size, bf)
-
-    # TODO(jevinjiang): we should dump scale as the kernel expected shape in the
+    if bfc % 128 != 0:
+        raise ValueError(f"Expected {bfc=} to be aligned to 128.")
+    if bd1c % (t_packing * 128) != 0:
+        raise ValueError(
+            f"Expected {bd1c=} to be aligned to {t_packing * 128}.")
+    if bd2c % (t_packing * 128) != 0:
+        raise ValueError(
+            f"Expected {bd2c=} to be aligned to {t_packing * 128}.")
+    if bf % bfc != 0:
+        raise ValueError(f"Expected {bf=} to be aligned to {bfc=}.")
+    if bd1 % bd1c != 0:
+        raise ValueError(f"Expected {bd1=} to be aligned to {bd1c=}.")
+    if bd2 % bd2c != 0:
+        raise ValueError(f"Expected {bd2=} to be aligned to {bd2c=}.")
+    if hidden_size % bd1 != 0 or hidden_size % bd2 != 0:
+        raise ValueError(
+            f"Expected {hidden_size=} to be aligned to {bd1=} and {bd2=}.")
+    if intermediate_size % bf != 0:
+        raise ValueError(
+            f"Expected {intermediate_size=} to be aligned to {bf=}.")
+
+    # Note: we should dump scale as the kernel expected shape in the
     # checkpoint offline or reshape right after weight loading.
     if w1_scale is not None:
-        assert w1_scale.shape[0] == w1.shape[0]
-        assert w1_scale.shape[1] == w1.shape[1] == 2
-        assert w1_scale.shape[2] == cdiv(w1.shape[2], subc_quant_wsz)
-        assert w1_scale.shape[3] == w1.shape[3]
-        w1_scale = jnp.expand_dims(w1_scale.astype(jnp.float32), axis=-2)
+        expected_w1_scale_shape = (
+            num_experts,
+            2,
+            hidden_size // subc_quant_wsz,
+            1,
+            intermediate_size,
+        )
+        if w1_scale.shape != expected_w1_scale_shape:
+            raise ValueError(
+                f"Expected {w1_scale.shape=} to be {expected_w1_scale_shape}.")
+        if w1_scale.dtype != jnp.float32:
+            w1_scale = w1_scale.astype(jnp.float32)

     if w2_scale is not None:
-        assert w2_scale.shape[0] == w2.shape[0]
-        assert w2_scale.shape[1] == cdiv(w2.shape[1], subc_quant_wsz)
-        assert w2_scale.shape[2] == w2.shape[2]
-        w2_scale = jnp.expand_dims(w2_scale.astype(jnp.float32), axis=-2)
+        expected_w2_scale_shape = (
+            num_experts,
+            intermediate_size // subc_quant_wsz,
+            1,
+            hidden_size,
+        )
+        if w2_scale.shape != expected_w2_scale_shape:
+            raise ValueError(
+                f"Expected {w2_scale.shape=} to be {expected_w2_scale_shape}.")
+        if w2_scale.dtype != jnp.float32:
+            w2_scale = w2_scale.astype(jnp.float32)

     if b1 is not None:
-        assert b1.shape[0] == w1.shape[0]
-        assert b1.shape[1] == w1.shape[1] == 2
-        assert b1.shape[2] == w1.shape[3]
-        b1 = jnp.expand_dims(b1.astype(jnp.float32), axis=-2)
+        expected_b1_shape = (num_experts, 2, 1, intermediate_size)
+        if b1.shape != expected_b1_shape:
+            raise ValueError(
+                f"Expected {b1.shape=} to be {expected_b1_shape}.")
+        if b1.dtype != jnp.float32:
+            b1 = b1.astype(jnp.float32)

     if b2 is not None:
-        assert b2.shape[0] == w2.shape[0]
-        assert b2.shape[1] == w2.shape[2]
-        b2 = jnp.expand_dims(b2.astype(jnp.float32), axis=-2)
+        expected_b2_shape = (num_experts, 1, hidden_size)
+        if b2.shape != expected_b2_shape:
+            raise ValueError(
+                f"Expected {b2.shape=} to be {expected_b2_shape}.")
+        if b2.dtype != jnp.float32:
+            b2 = b2.astype(jnp.float32)

     # Prepare inputs for the kernel.
     if padded_num_experts != gating_output.shape[-1]:
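
All the raise-based checks above reduce to divisibility against the 128-lane TPU tiling. Assuming align_to has the usual round-up definition (it is defined near the top of this file; cdiv is plain ceiling division), the padded sizes behave like:

    def cdiv(a: int, b: int) -> int:
        return (a + b - 1) // b

    def align_to(x: int, a: int) -> int:
        return cdiv(x, a) * a

    assert align_to(300, 128) == 384  # padded_num_experts for 300 experts
    assert align_to(2, 128) == 128    # padded_top_k for top_k = 2
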
@@ -1260,248 +1384,171 @@ def fused_ep_moe(
             constant_values=-jnp.inf,
         )

-    if (hidden_size != actual_hidden_size
-            or intermediate_size != actual_intermediate_size):
-        tokens = jnp.pad(
-            tokens,
-            ((0, 0), (0, hidden_size - actual_hidden_size)),
-            constant_values=0,
-        )
-        w1 = jnp.pad(
-            w1,
-            (
-                (0, 0),
-                (0, 0),
-                (0, hidden_size - actual_hidden_size),
-                (0, intermediate_size - actual_intermediate_size),
-            ),
-            constant_values=0,
-        )
-        w2 = jnp.pad(
-            w2,
-            (
-                (0, 0),
-                (0, intermediate_size - actual_intermediate_size),
-                (0, hidden_size - actual_hidden_size),
-            ),
-            constant_values=0,
-        )
-        if w1_scale is not None:
-            w1_scale = jnp.pad(
-                w1_scale,
-                (
-                    (0, 0),
-                    (0, 0),
-                    (0,
-                     cdiv(hidden_size, subc_quant_wsz) - w1_scale.shape[-3]),
-                    (0, 0),
-                    (0, intermediate_size - w1_scale.shape[-1]),
-                ),
-                constant_values=0,
-            )
-        if w2_scale is not None:
-            w2_scale = jnp.pad(
-                w2_scale,
-                (
-                    (0, 0),
-                    (0, cdiv(intermediate_size, subc_quant_wsz) -
-                     w2_scale.shape[-3]),
-                    (0, 0),
-                    (0, hidden_size - w2_scale.shape[-1]),
-                ),
-                constant_values=0,
-            )
-        if b1 is not None:
-            b1 = jnp.pad(
-                b1,
-                (
-                    (0, 0),
-                    (0, 0),
-                    (0, 0),
-                    (0, intermediate_size - b1.shape[-1]),
-                ),
-                constant_values=0,
-            )
-        if b2 is not None:
-            b2 = jnp.pad(
-                b2,
-                (
-                    (0, 0),
-                    (0, 0),
-                    (0, hidden_size - b2.shape[-1]),
-                ),
-                constant_values=0,
-            )
-
     tokens = tokens.reshape(-1, t_packing, hidden_size // t_packing)

     hbm_block_spec = pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM)
-    scope_name = f"fused_moe_k-{top_k}_renorm-{renormalize_topk_logits}_bt-{bt}-{btc}_bf-{bf}-{bfc}_bd1-{bd1}-{bd1c}_bd2-{bd2}-{bd2c}"
-    fused_moe = jax.named_scope(scope_name)(
-        pl.pallas_call(
-            functools.partial(
-                _fused_ep_moe_kernel,
-                top_k=top_k,
-                renormalize_topk_logits=renormalize_topk_logits,
-                ep_axis_name=ep_axis_name,
-                act_fn=act_fn,
-                subc_quant_wsz=subc_quant_wsz,
-                bt=bt,
-                bf=bf,
-                bd1=bd1,
-                bd2=bd2,
-                btc=btc,
-                bfc=bfc,
-                bd1c=bd1c,
-                bd2c=bd2c,
-            ),
-            out_shape=jax.ShapeDtypeStruct((local_num_tokens, hidden_size),
-                                           t_dtype),
-            grid_spec=pltpu.PrefetchScalarGridSpec(
-                num_scalar_prefetch=0,
-                in_specs=[
-                    hbm_block_spec,  # tokens_hbm
-                    hbm_block_spec,  # w1_hbm
-                    hbm_block_spec,  # w2_hbm
-                    None
-                    if w1_scale is None else hbm_block_spec,  # w1_scale_hbm
-                    None
-                    if w2_scale is None else hbm_block_spec,  # w2_scale_hbm
-                    None if b1 is None else hbm_block_spec,  # b1_hbm
-                    None if b2 is None else hbm_block_spec,  # b2_hbm
-                    hbm_block_spec,  # gating_output_hbm
-                    hbm_block_spec,  # a2a_g_hbm
-                ],
-                out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM),
-                scratch_shapes=([
-                    # t2e_routing_x2_smem
-                    pltpu.SMEM((2, bt, padded_num_experts), jnp.int32),
-                    # d2e_count_x2_smem
-                    pltpu.SMEM((2, num_devices, 1, padded_num_experts),
-                               jnp.int32),
-                    # expert_offsets_x2_smem
-                    pltpu.SMEM((2, 2, padded_num_experts), jnp.int32),
-                    # expert_starts_x2_smem
-                    pltpu.SMEM((2, 1, padded_num_experts), jnp.int32),
-                    # expert_sizes_x2_smem
-                    pltpu.SMEM((2, 1, padded_num_experts), jnp.int32),
-                    # a2a_s_sends_x2_smem
-                    pltpu.SMEM((2, ), jnp.int32),
-                    # a2a_s_x2_vmem
-                    pltpu.VMEM(
-                        (
-                            2,
-                            bt * num_devices,
-                            t_packing,
-                            hidden_size // t_packing,
-                        ),
-                        t_dtype,
+    renorm_str = "-renorm_k" if renormalize_topk_logits else ""
+    scope_name = f"fused-moe-k_{top_k}{renorm_str}-bt_{bt}_{btc}-bf_{bf}_{bfc}-bd1_{bd1}_{bd1c}-bd2_{bd2}_{bd2c}"
+    fused_moe = pl.pallas_call(
+        functools.partial(
+            _fused_ep_moe_kernel,
+            top_k=top_k,
+            renormalize_topk_logits=renormalize_topk_logits,
+            ep_axis_name=ep_axis_name,
+            act_fn=act_fn,
+            subc_quant_wsz=subc_quant_wsz,
+            bt=bt,
+            bf=bf,
+            bd1=bd1,
+            bd2=bd2,
+            btc=btc,
+            bfc=bfc,
+            bd1c=bd1c,
+            bd2c=bd2c,
+        ),
+        out_shape=jax.ShapeDtypeStruct((local_num_tokens, hidden_size),
+                                       t_dtype),
+        grid_spec=pltpu.PrefetchScalarGridSpec(
+            num_scalar_prefetch=0,
+            in_specs=[
+                hbm_block_spec,  # tokens_hbm
+                hbm_block_spec,  # w1_hbm
+                hbm_block_spec,  # w2_hbm
+                None if w1_scale is None else hbm_block_spec,  # w1_scale_hbm
+                None if w2_scale is None else hbm_block_spec,  # w2_scale_hbm
+                None if b1 is None else hbm_block_spec,  # b1_hbm
+                None if b2 is None else hbm_block_spec,  # b2_hbm
+                hbm_block_spec,  # gating_output_hbm
+                hbm_block_spec,  # a2a_g_hbm
+            ],
+            out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM),
+            scratch_shapes=([
+                # t2e_routing_x2_smem
+                pltpu.SMEM((2, bt, padded_top_k), jnp.int32),
+                # d2e_count_x2_smem
+                pltpu.SMEM((2, num_devices, 1, padded_num_experts), jnp.int32),
+                # expert_offsets_x2_smem
+                pltpu.SMEM((2, 2, padded_num_experts), jnp.int32),
+                # expert_starts_x2_smem
+                pltpu.SMEM((2, 1, padded_num_experts), jnp.int32),
+                # expert_sizes_x2_smem
+                pltpu.SMEM((2, 1, padded_num_experts), jnp.int32),
+                # a2a_s_sends_x2_smem
+                pltpu.SMEM((2, ), jnp.int32),
+                # a2a_s_x2_vmem
+                pltpu.VMEM(
+                    (
+                        2,
+                        bt * num_devices,
+                        t_packing,
+                        hidden_size // t_packing,
                     ),
-                    # a2a_s_acc_x2_vmem
-                    pltpu.VMEM(
-                        (
-                            2,
-                            bt * num_devices,
-                            t_packing,
-                            hidden_size // t_packing,
-                        ),
-                        t_dtype,
+                    t_dtype,
+                ),
+                # a2a_s_acc_x2_vmem
+                pltpu.VMEM(
+                    (
+                        2,
+                        bt * num_devices,
+                        t_packing,
+                        hidden_size // t_packing,
                     ),
-                    # a2a_g_acc_vmem
-                    pltpu.VMEM(
-                        (top_k, bt, t_packing, hidden_size // t_packing),
-                        t_dtype),
-                    # b_gating_x2_vmem
-                    pltpu.VMEM((2, bt, padded_num_experts), t_dtype),
-                    # b_output_x2_vmem
-                    pltpu.VMEM((2, bt, hidden_size), t_dtype),
-                    # b_w1_x2_vmem
-                    pltpu.VMEM((2, t_packing, bd1 // t_packing, bf), w1.dtype),
-                    # b_w3_x2_vmem
-                    pltpu.VMEM((2, t_packing, bd1 // t_packing, bf), w1.dtype),
-                    # b_w2_x2_vmem
-                    pltpu.VMEM((2, t_packing, bf, bd2 // t_packing), w2.dtype),
-                    # b_w1_scale_x2_vmem
-                    (None if w1_scale is None else pltpu.VMEM(
-                        (
-                            2,
-                            t_packing,
-                            bd1 // t_packing // subc_quant_wsz,
-                            1,
-                            bf,
-                        ),
-                        jnp.float32,
-                    )),
-                    # b_w3_scale_x2_vmem
-                    (None if w1_scale is None else pltpu.VMEM(
-                        (
-                            2,
-                            t_packing,
-                            bd1 // t_packing // subc_quant_wsz,
-                            1,
-                            bf,
-                        ),
-                        jnp.float32,
-                    )),
-                    # b_w2_scale_x2_vmem
-                    (None if w2_scale is None else pltpu.VMEM(
-                        (
-                            2,
-                            t_packing,
-                            bf // subc_quant_wsz,
-                            1,
-                            bd2 // t_packing,
-                        ),
-                        jnp.float32,
-                    )),
-                    # b_b1_x2_vmem
-                    (None if b1 is None else pltpu.VMEM(
-                        (
-                            2,
-                            1,
-                            bf,
-                        ),
-                        jnp.float32,
-                    )),
-                    # b_b3_x2_vmem
-                    (None if b1 is None else pltpu.VMEM(
-                        (
-                            2,
-                            1,
-                            bf,
-                        ),
-                        jnp.float32,
-                    )),
-                    # b_b2_x2_vmem
-                    (None if b2 is None else pltpu.VMEM(
-                        (
-                            2,
-                            t_packing,
-                            1,
-                            bd2 // t_packing,
-                        ),
-                        jnp.float32,
-                    )),
-                    # b_acc_vmem
-                    pltpu.VMEM((bt * num_devices, 1, bf * 2), jnp.float32),
-                    # local_sems
-                    pltpu.SemaphoreType.DMA((2, 5)),
-                    # send_sems
-                    pltpu.SemaphoreType.DMA((2, )),
-                    # recv_sems
-                    pltpu.SemaphoreType.DMA((2, )),
-                    # a2a_gather_sem
-                    pltpu.SemaphoreType.DMA,
-                    # a2a_acc_sem
-                    pltpu.SemaphoreType.DMA,
-                ]),
-            ),
-            compiler_params=pltpu.CompilerParams(
-                collective_id=0,
-                vmem_limit_bytes=100 * 1024 * 1024,
-            ),
-            name=scope_name,
-        ))
+                    t_dtype,
+                ),
+                # a2a_g_acc_vmem
+                pltpu.VMEM((top_k, bt, t_packing, hidden_size // t_packing),
+                           t_dtype),
+                # b_gating_x2_vmem
+                pltpu.VMEM((2, bt, padded_num_experts), t_dtype),
+                # b_output_x2_vmem
+                pltpu.VMEM((2, bt, hidden_size), t_dtype),
+                # b_w1_x2_vmem
+                pltpu.VMEM((2, t_packing, bd1 // t_packing, bf), w1.dtype),
+                # b_w3_x2_vmem
+                pltpu.VMEM((2, t_packing, bd1 // t_packing, bf), w1.dtype),
+                # b_w2_x2_vmem
+                pltpu.VMEM((2, t_packing, bf, bd2 // t_packing), w2.dtype),
+                # b_w1_scale_x2_vmem
+                (None if w1_scale is None else pltpu.VMEM(
+                    (
+                        2,
+                        t_packing,
+                        bd1 // t_packing // subc_quant_wsz,
+                        1,
+                        bf,
+                    ),
+                    jnp.float32,
+                )),
+                # b_w3_scale_x2_vmem
+                (None if w1_scale is None else pltpu.VMEM(
+                    (
+                        2,
+                        t_packing,
+                        bd1 // t_packing // subc_quant_wsz,
+                        1,
+                        bf,
+                    ),
+                    jnp.float32,
+                )),
+                # b_w2_scale_x2_vmem
+                (None if w2_scale is None else pltpu.VMEM(
+                    (
+                        2,
+                        t_packing,
+                        bf // subc_quant_wsz,
+                        1,
+                        bd2 // t_packing,
+                    ),
+                    jnp.float32,
+                )),
+                # b_b1_x2_vmem
+                (None if b1 is None else pltpu.VMEM(
+                    (
+                        2,
+                        1,
+                        bf,
+                    ),
+                    jnp.float32,
+                )),
+                # b_b3_x2_vmem
+                (None if b1 is None else pltpu.VMEM(
+                    (
+                        2,
+                        1,
+                        bf,
+                    ),
+                    jnp.float32,
+                )),
+                # b_b2_x2_vmem
+                (None if b2 is None else pltpu.VMEM(
+                    (
+                        2,
+                        t_packing,
+                        1,
+                        bd2 // t_packing,
+                    ),
+                    jnp.float32,
+                )),
+                # b_acc_vmem
+                pltpu.VMEM((bt * num_devices, 1, bf * 2), jnp.float32),
+                # local_sems
+                pltpu.SemaphoreType.DMA((2, 5)),
+                # send_sems
+                pltpu.SemaphoreType.DMA((2, )),
+                # recv_sems
+                pltpu.SemaphoreType.DMA((2, )),
+                # a2a_gather_sem
+                pltpu.SemaphoreType.DMA,
+                # a2a_acc_sem
+                pltpu.SemaphoreType.DMA,
+            ]),
+        ),
+        compiler_params=pltpu.CompilerParams(
+            collective_id=0,
+            vmem_limit_bytes=100 * 1024 * 1024,
+        ),
+        name=scope_name,
+    )

     @jax.jit
     @jax.shard_map(
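
A back-of-the-envelope check on the scratch sizes above, for illustrative values that are not taken from the diff (bt=128, 8 devices, hidden_size=7168, bf16 tokens): the two largest double-buffered buffers, a2a_s_x2_vmem and a2a_s_acc_x2_vmem, each hold 2 * bt * num_devices * hidden_size elements (t_packing folds into the last two axes), and together they must fit under the 100 MiB vmem_limit_bytes:

    bt, num_devices, hidden_size = 128, 8, 7168  # assumed, not from the diff
    bytes_per_elem = 2  # bf16
    one_buffer = 2 * bt * num_devices * hidden_size * bytes_per_elem
    print(one_buffer / 2**20)  # ~28 MiB per buffer, ~56 MiB for the pair
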
@@ -1552,7 +1599,7 @@ def fused_ep_moe(

         a2a_g_hbm_scratch = pl.empty(
             (num_experts, bt, t_packing, hidden_size // t_packing), t_dtype)
-        results = kernel(
+        return kernel(
             tokens,
             w1,
             w2,
@@ -1563,4 +1610,3 @@ def fused_ep_moe(
             gating_output,
             a2a_g_hbm_scratch,
         )
-        return results[:, :actual_hidden_size]