tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.
Files changed (257)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +317 -34
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +406 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +320 -0
  64. tests/layers/vllm/test_unquantized.py +662 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +26 -6
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +110 -12
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +2 -45
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +15 -10
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +92 -8
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +25 -4
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +807 -230
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +218 -137
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +25 -12
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/fused_moe_gmm.py +506 -0
  154. tpu_inference/layers/common/quant_methods.py +15 -0
  155. tpu_inference/layers/common/quantization.py +282 -0
  156. tpu_inference/layers/common/sharding.py +32 -9
  157. tpu_inference/layers/common/utils.py +94 -0
  158. tpu_inference/layers/jax/__init__.py +13 -0
  159. tpu_inference/layers/jax/attention/__init__.py +13 -0
  160. tpu_inference/layers/jax/attention/attention.py +19 -6
  161. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  162. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  163. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  164. tpu_inference/layers/jax/base.py +14 -0
  165. tpu_inference/layers/jax/constants.py +13 -0
  166. tpu_inference/layers/jax/layers.py +14 -0
  167. tpu_inference/layers/jax/misc.py +14 -0
  168. tpu_inference/layers/jax/moe/__init__.py +13 -0
  169. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  170. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  171. tpu_inference/layers/jax/moe/moe.py +43 -3
  172. tpu_inference/layers/jax/pp_utils.py +53 -0
  173. tpu_inference/layers/jax/rope.py +14 -0
  174. tpu_inference/layers/jax/rope_interface.py +14 -0
  175. tpu_inference/layers/jax/sample/__init__.py +13 -0
  176. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  177. tpu_inference/layers/jax/sample/sampling.py +15 -1
  178. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  179. tpu_inference/layers/jax/transformer_block.py +14 -0
  180. tpu_inference/layers/vllm/__init__.py +13 -0
  181. tpu_inference/layers/vllm/attention.py +4 -4
  182. tpu_inference/layers/vllm/fused_moe.py +101 -494
  183. tpu_inference/layers/vllm/linear.py +64 -0
  184. tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
  185. tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
  186. tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
  187. tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
  188. tpu_inference/layers/vllm/quantization/__init__.py +19 -3
  189. tpu_inference/layers/vllm/quantization/awq.py +96 -82
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  191. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +23 -8
  192. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +172 -176
  193. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  194. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
  195. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
  196. tpu_inference/layers/vllm/quantization/{common.py → configs.py} +42 -25
  197. tpu_inference/layers/vllm/quantization/fp8.py +119 -0
  198. tpu_inference/layers/vllm/quantization/mxfp4.py +137 -178
  199. tpu_inference/layers/vllm/quantization/unquantized.py +157 -233
  200. tpu_inference/lora/__init__.py +13 -0
  201. tpu_inference/lora/torch_lora_ops.py +8 -13
  202. tpu_inference/models/__init__.py +13 -0
  203. tpu_inference/models/common/__init__.py +13 -0
  204. tpu_inference/models/common/model_loader.py +112 -35
  205. tpu_inference/models/jax/__init__.py +13 -0
  206. tpu_inference/models/jax/deepseek_v3.py +267 -157
  207. tpu_inference/models/jax/gpt_oss.py +26 -10
  208. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  209. tpu_inference/models/jax/llama3.py +99 -36
  210. tpu_inference/models/jax/llama4.py +14 -0
  211. tpu_inference/models/jax/llama_eagle3.py +18 -5
  212. tpu_inference/models/jax/llama_guard_4.py +15 -1
  213. tpu_inference/models/jax/qwen2.py +17 -2
  214. tpu_inference/models/jax/qwen2_5_vl.py +179 -51
  215. tpu_inference/models/jax/qwen3.py +17 -2
  216. tpu_inference/models/jax/utils/__init__.py +13 -0
  217. tpu_inference/models/jax/utils/file_utils.py +14 -0
  218. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  219. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  220. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +92 -32
  221. tpu_inference/models/jax/utils/weight_utils.py +234 -155
  222. tpu_inference/models/vllm/__init__.py +13 -0
  223. tpu_inference/models/vllm/vllm_model_wrapper.py +32 -8
  224. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  225. tpu_inference/platforms/__init__.py +14 -0
  226. tpu_inference/platforms/tpu_platform.py +51 -72
  227. tpu_inference/runner/__init__.py +13 -0
  228. tpu_inference/runner/compilation_manager.py +180 -80
  229. tpu_inference/runner/kv_cache.py +54 -20
  230. tpu_inference/runner/kv_cache_manager.py +55 -33
  231. tpu_inference/runner/lora_utils.py +16 -1
  232. tpu_inference/runner/multimodal_manager.py +16 -2
  233. tpu_inference/runner/persistent_batch_manager.py +54 -2
  234. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  235. tpu_inference/runner/structured_decoding_manager.py +16 -3
  236. tpu_inference/runner/tpu_runner.py +124 -61
  237. tpu_inference/runner/utils.py +2 -2
  238. tpu_inference/spec_decode/__init__.py +13 -0
  239. tpu_inference/spec_decode/jax/__init__.py +13 -0
  240. tpu_inference/spec_decode/jax/eagle3.py +84 -22
  241. tpu_inference/tpu_info.py +14 -0
  242. tpu_inference/utils.py +72 -44
  243. tpu_inference/worker/__init__.py +13 -0
  244. tpu_inference/worker/tpu_worker.py +66 -52
  245. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +8 -9
  246. tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
  247. tpu_inference/layers/vllm/linear_common.py +0 -186
  248. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  249. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  250. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  251. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  252. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  253. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  254. tpu_inference-0.11.1.dev202511220812.dist-info/RECORD +0 -174
  255. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
  256. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
  257. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,16 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 """TPU-Friendly Fused Mixture of Experts (MoE) kernel."""
 
 import functools
@@ -7,7 +20,6 @@ import jax.numpy as jnp
 from jax import lax
 from jax._src import dtypes
 from jax.experimental import pallas as pl
-from jax.experimental import shard_map
 from jax.experimental.pallas import tpu as pltpu
 
 P = jax.sharding.PartitionSpec
@@ -20,7 +32,8 @@ def align_to(x, a):
 
 
 def get_dtype_packing(dtype):
-    bits = dtypes.bit_width(dtype)
+    bits = (dtypes.bit_width(dtype)
+            if hasattr(dtypes, "bit_width") else dtypes.itemsize_bits(dtype))
     return 32 // bits
 
 
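A note on `get_dtype_packing`: TPU words are 32 bits wide, so the "packing" is simply how many values of a dtype fit into one 32-bit word. A minimal standalone sketch of the same computation (the `packing` helper below is illustrative, not from the package):

```python
import numpy as np
import jax.numpy as jnp

def packing(dtype) -> int:
    """Number of values of `dtype` that fit in one 32-bit word."""
    return 32 // (np.dtype(dtype).itemsize * 8)

assert packing(jnp.float32) == 1   # one f32 per word
assert packing(jnp.bfloat16) == 2  # two bf16 per word
assert packing(jnp.int8) == 4      # four int8 per word
```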
@@ -35,13 +48,50 @@ def broadcast_minor(src, shape):
                     axis=-1)[..., :shape[-1]]
 
 
+def swigluoai(gate: jax.Array,
+              up: jax.Array,
+              *,
+              alpha: float = 1.702,
+              limit: float = 7.0) -> jax.Array:
+    """Activation used in some models such as GPT-OSS."""
+    gate = jnp.clip(gate, a_max=limit)
+    up = jnp.clip(up, a_min=-limit, a_max=limit)
+    glu = gate * jax.nn.sigmoid(alpha * gate)
+    return (up + 1.0) * glu
+
+
+def activation_fn(acc1, acc3, act_fn):
+    if act_fn == "silu":
+        return jax.nn.silu(acc1) * acc3
+    elif act_fn == "gelu":
+        return jax.nn.gelu(acc1) * acc3
+    elif act_fn == "swigluoai":
+        return swigluoai(acc1, acc3)
+    else:
+        raise RuntimeError(f"Unsupported activation function: {act_fn}")
+
+
 def ref_moe(
-        tokens: jax.Array,  # (num_tokens, hidden_size)
-        w1: jax.Array,  # (num_experts, 2, hidden_size, intermediate_size)
-        w2: jax.Array,  # (num_experts, intermediate_size, hidden_size)
-        gating_output: jax.Array,  # (num_tokens, num_experts)
-        top_k: int,
-        activation="silu",
+    tokens: jax.Array,  # (num_tokens, hidden_size)
+    w1: jax.Array,  # (num_experts, 2, hidden_size, intermediate_size)
+    w2: jax.Array,  # (num_experts, intermediate_size, hidden_size)
+    gating_output: jax.Array,  # (num_tokens, num_experts)
+    top_k: int,
+    *,
+    renormalize_topk_logits: bool = False,
+    act_fn: str = "silu",
+    subc_quant_wsz: int | None = None,
+    w1_scale: (
+        jax.Array | None
+    ) = None,  # F32(num_experts, 2, hidden_size // subc_quant_wsz, 1, intermediate_size)
+    w2_scale: (
+        jax.Array | None
+    ) = None,  # F32(num_experts, intermediate_size // subc_quant_wsz, 1, hidden_size)
+    b1: jax.Array | None = None,  # F32(num_experts, 2, 1, intermediate_size)
+    b2: jax.Array | None = None,  # F32(num_experts, 1, hidden_size)
 ):
     n_tokens = tokens.shape[0]  # num_tokens
 
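The new `swigluoai` activation is the clamped SwiGLU variant used by GPT-OSS-style models: the gate is clipped from above only, the up-projection from both sides, and the gate passes through a sigmoid scaled by `alpha`. A standalone sketch of the same formula in plain JAX (toy shapes; the helper name is hypothetical):

```python
import jax
import jax.numpy as jnp

def swigluoai_ref(gate, up, alpha=1.702, limit=7.0):
    gate = jnp.clip(gate, None, limit)      # clip gate from above only
    up = jnp.clip(up, -limit, limit)        # clip up from both sides
    return (up + 1.0) * (gate * jax.nn.sigmoid(alpha * gate))

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
gate = jax.random.normal(k1, (4, 8)) * 10.0  # large values exercise the clip
up = jax.random.normal(k2, (4, 8)) * 10.0
assert swigluoai_ref(gate, up).shape == (4, 8)
```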
@@ -53,11 +103,16 @@ def ref_moe(
     top_k_logits, top_k_indices = lax.top_k(
         gating_logits, top_k)  # [num_tokens, top_k], [num_tokens, top_k]
 
+    if renormalize_topk_logits:
+        top_k_logits = top_k_logits / jnp.sum(
+            top_k_logits, axis=-1, keepdims=True)
+
     t_outputs = []
+    hidden_size, intermediate_size = w1.shape[-2:]
 
     # Process each token individually
     for i in range(n_tokens):
-        curr_token = jnp.expand_dims(tokens[i], axis=0)  # [1, d_model]
+        curr_token = jnp.expand_dims(tokens[i], axis=0)  # [1, hidden_size]
         assigned_expert_ids = top_k_indices[
             i]  # [top_k] - indices of selected experts for token i
         tok_expert_act = []
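The `renormalize_topk_logits` branch rescales each token's selected gating weights to sum to one, rather than keeping their raw softmax mass. A quick standalone check of that step:

```python
import jax
import jax.numpy as jnp
from jax import lax

gating = jax.nn.softmax(
    jax.random.normal(jax.random.PRNGKey(0), (3, 16)), axis=-1)
top_k_logits, top_k_indices = lax.top_k(gating, 2)

# Renormalize so each token's selected weights sum to 1.
renormed = top_k_logits / jnp.sum(top_k_logits, axis=-1, keepdims=True)
assert jnp.allclose(jnp.sum(renormed, axis=-1), 1.0)
```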
@@ -65,10 +120,24 @@ def ref_moe(
         # Process each selected expert for the current token
         for expert_id in assigned_expert_ids:
             # Get expert weights
+            expert_w1 = w1[expert_id, 0].astype(jnp.float32)
+            expert_w3 = w1[expert_id, 1].astype(jnp.float32)
+            if w1_scale is not None:
+                expert_w1 *= jnp.repeat(w1_scale[expert_id, 0, :, 0],
+                                        subc_quant_wsz,
+                                        axis=0)[:hidden_size]
+                expert_w3 *= jnp.repeat(w1_scale[expert_id, 1, :, 0],
+                                        subc_quant_wsz,
+                                        axis=0)[:hidden_size]
             expert_weight_1 = jnp.concat(
-                [w1[expert_id, 0], w1[expert_id, 1]],
-                axis=-1)  # [d_model, 2 * intermediate_size]
-            expert_weight_2 = w2[expert_id]  # [intermediate_size, d_model]
+                [expert_w1, expert_w3],
+                axis=-1)  # [hidden_size, 2 * intermediate_size]
+            expert_weight_2 = w2[expert_id].astype(
+                jnp.float32)  # [intermediate_size, hidden_size]
+            if w2_scale is not None:
+                expert_weight_2 *= jnp.repeat(w2_scale[expert_id, :, 0],
+                                              subc_quant_wsz,
+                                              axis=0)[:intermediate_size]
 
             # First linear layer with SwiGLU activation
             gmm_1_out = curr_token @ expert_weight_1  # [1, 2 * intermediate_size]
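The `jnp.repeat(..., subc_quant_wsz, axis=0)` pattern above expands one stored scale per sub-channel block of `subc_quant_wsz` rows back to a per-row scale before it multiplies the quantized weight. A toy dequantization sketch (sizes shrunk for illustration; the kernel itself requires `subc_quant_wsz` to be a multiple of 256):

```python
import jax
import jax.numpy as jnp

subc_quant_wsz, hidden_size, intermediate_size = 4, 8, 6  # toy sizes
w_q = jax.random.randint(jax.random.PRNGKey(0),
                         (hidden_size, intermediate_size),
                         -128, 128, dtype=jnp.int8)
# One f32 scale per (sub-channel block of rows, output column).
scale = jax.random.uniform(
    jax.random.PRNGKey(1),
    (hidden_size // subc_quant_wsz, intermediate_size))

# Broadcast each block scale over its subc_quant_wsz rows, then dequantize.
row_scale = jnp.repeat(scale, subc_quant_wsz, axis=0)[:hidden_size]
w = w_q.astype(jnp.float32) * row_scale
assert w.shape == (hidden_size, intermediate_size)
```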
@@ -77,37 +146,34 @@ def ref_moe(
             gmm1_w1_proj, gmm1_w3_proj = jnp.split(
                 gmm_1_out, 2,
                 axis=-1)  # [1, intermediate_size], [1, intermediate_size]
+            if b1 is not None:
+                gmm1_w1_proj += b1[expert_id:expert_id + 1, 0, 0]
+                gmm1_w3_proj += b1[expert_id:expert_id + 1, 1, 0]
 
             # Apply gated activation: activation(gate) * up
-            if activation == "silu":
-                act = jax.nn.silu(
-                    gmm1_w1_proj) * gmm1_w3_proj  # [1, intermediate_size]
-            elif activation == "gelu":
-                act = jax.nn.gelu(
-                    gmm1_w1_proj) * gmm1_w3_proj  # [1, intermediate_size]
-            else:
-                raise ValueError(
-                    f"Unsupported activation: {activation}. Use 'silu' or 'gelu'."
-                )
+            act = activation_fn(gmm1_w1_proj, gmm1_w3_proj, act_fn)
 
             # Second linear layer (down projection)
-            gmm_2_out = act @ expert_weight_2  # [1, d_model]
+            gmm_2_out = act @ expert_weight_2  # [1, hidden_size]
+            if b2 is not None:
+                gmm_2_out += b2[expert_id:expert_id + 1, 0]
             tok_expert_act.append(gmm_2_out)
 
         # Combine outputs from all selected experts
         experts_act = jnp.concatenate(tok_expert_act,
-                                      axis=0)  # [top_k, d_model]
+                                      axis=0)  # [top_k, hidden_size]
 
         # Weighted sum using top-k gating weights
         top_k_weights = top_k_logits[i]  # [top_k]
         top_k_weights = jnp.expand_dims(top_k_weights, axis=1)  # [top_k, 1]
         weighted_output = jnp.sum(experts_act * top_k_weights,
                                   axis=0,
-                                  keepdims=True)  # [1, d_model]
+                                  keepdims=True)  # [1, hidden_size]
 
-        t_outputs.append(weighted_output)
+        t_outputs.append(weighted_output.astype(tokens.dtype))
 
-    return jnp.concatenate(t_outputs, axis=0)  # [num_tokens, d_model]
+    return jnp.concatenate(t_outputs,
+                           axis=0)  # [actual_num_tokens, hidden_size]
 
 
 def _fused_ep_moe_kernel(
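Before the kernel internals, a hedged usage sketch of the reference path: `ref_moe` (defined above in `tpu_inference/kernels/fused_moe/v1/kernel.py`) can be exercised eagerly on toy shapes to see the new keyword-only arguments; the sizes below are illustrative only:

```python
import jax
import jax.numpy as jnp
from tpu_inference.kernels.fused_moe.v1.kernel import ref_moe

num_tokens, hidden_size, intermediate_size = 4, 16, 32
num_experts, top_k = 8, 2
k0, k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 4)

tokens = jax.random.normal(k0, (num_tokens, hidden_size), dtype=jnp.bfloat16)
w1 = jax.random.normal(k1, (num_experts, 2, hidden_size, intermediate_size),
                       dtype=jnp.bfloat16)
w2 = jax.random.normal(k2, (num_experts, intermediate_size, hidden_size),
                       dtype=jnp.bfloat16)
gating = jax.random.normal(k3, (num_tokens, num_experts))

out = ref_moe(tokens, w1, w2, gating, top_k,
              renormalize_topk_logits=True, act_fn="silu")
assert out.shape == (num_tokens, hidden_size)
```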
@@ -115,12 +181,19 @@ def _fused_ep_moe_kernel(
     tokens_hbm,  # (local_num_tokens, t_packing, hidden_size // t_packing)
     w1_hbm,  # (local_num_experts, 2, hidden_size, intermediate_size)
     w2_hbm,  # (local_num_experts, intermediate_size, hidden_size)
+    # TODO(jevinjiang): We choose F32 scale for easier slicing. The extra
+    # latency should be hidden in the pipeline overlapping. But is there a
+    # better way to do this?
+    w1_scale_hbm,  # None | F32(local_num_experts, 2, cdiv(hidden_size, subc_quant_wsz), 1, intermediate_size)
+    w2_scale_hbm,  # None | F32(local_num_experts, cdiv(intermediate_size, subc_quant_wsz), 1, hidden_size)
+    b1_hbm,  # None | F32(local_num_experts, 2, 1, intermediate_size)
+    b2_hbm,  # None | F32(local_num_experts, 1, hidden_size)
     gating_hbm,  # (local_num_tokens, padded_num_experts)
     a2a_g_hbm,  # (num_experts, bt, t_packing, hidden_size // t_packing)
     # Output
     output_hbm,  # (local_num_tokens, hidden_size)
     # Scratch
-    t2e_routing_x2_smem,  # <bt_sem_id> (2, bt, padded_num_experts)
+    t2e_routing_x2_smem,  # <bt_sem_id> (2, bt, padded_top_k)
     d2e_count_x2_smem,  # <bt_sem_id> (2, num_devices, 1, padded_num_experts)
     expert_offsets_x2_smem,  # <bt_sem_id> (2, 2, padded_num_experts): for a2a_s and a2a_g
     expert_starts_x2_smem,  # <bt_sem_id> (2, 1, padded_num_experts)
@@ -136,6 +209,12 @@ def _fused_ep_moe_kernel(
     b_w1_x2_vmem,  # <bw_sem_id> (2, t_packing, bd1 // t_packing, bf)
     b_w3_x2_vmem,  # <bw_sem_id> (2, t_packing, bd1 // t_packing, bf)
     b_w2_x2_vmem,  # <bw_sem_id> (2, t_packing, bf, bd2 // t_packing)
+    b_w1_scale_x2_vmem,  # None | <bw_sem_id> (2, t_packing, bd1 // t_packing // subc_quant_wsz, 1, bf)
+    b_w3_scale_x2_vmem,  # None | <bw_sem_id> (2, t_packing, bd1 // t_packing // subc_quant_wsz, 1, bf)
+    b_w2_scale_x2_vmem,  # None | <bw_sem_id> (2, t_packing, bf // subc_quant_wsz, 1, bd2 // t_packing)
+    b_b1_x2_vmem,  # None | <bw_sem_id> (2, 1, bf)
+    b_b3_x2_vmem,  # None | <bw_sem_id> (2, 1, bf)
+    b_b2_x2_vmem,  # None | <bw_sem_id> (2, t_packing, 1, bd2 // t_packing)
     b_acc_vmem,  # F32(bt * num_devices, 1, bf * 2)
     ### Semaphores:
     local_sems,  # (2, 5): 2 x [b_gating_sem, b_w1_sem, b_w2_sem, b_w3_sem, b_output_sem]
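All of the `*_x2_*` scratch buffers, including the new scale and bias buffers, are double-buffered: one slot is computed on while the other is being filled by DMA, and the `<bw_sem_id>`/`<bt_sem_id>` index flips each step. A plain-Python sketch of the ping-pong pattern (illustrative only, no Pallas):

```python
# Ping-pong (double) buffering: overlap the fetch of block i+1 with compute
# on block i, alternating between two buffer slots.
def process_blocks(num_blocks, start_fetch, wait_fetch, compute):
    start_fetch(slot=0, block=0)                 # prime the pipeline
    for i in range(num_blocks):
        cur, nxt = i % 2, (i + 1) % 2
        if i + 1 < num_blocks:
            start_fetch(slot=nxt, block=i + 1)   # prefetch the next block
        wait_fetch(slot=cur, block=i)            # wait for this block's DMA
        compute(slot=cur, block=i)               # compute while next DMA runs

log = []
process_blocks(
    3,
    start_fetch=lambda slot, block: log.append(f"fetch b{block}->s{slot}"),
    wait_fetch=lambda slot, block: log.append(f"wait b{block}@s{slot}"),
    compute=lambda slot, block: log.append(f"compute b{block}@s{slot}"),
)
print(log)
```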
@@ -145,7 +224,10 @@ def _fused_ep_moe_kernel(
     a2a_acc_sem,
     *,
     top_k: int,
+    renormalize_topk_logits: bool,
     ep_axis_name: str,
+    act_fn: str,
+    subc_quant_wsz: int | None = None,
     # Kernel tuning params.
     bt: int,  # Block size of local_num_tokens.
     bf: int,  # Block size of intermediate_size.
@@ -160,34 +242,58 @@ def _fused_ep_moe_kernel(
     num_devices = lax.axis_size(ep_axis_name)
     local_num_tokens = tokens_hbm.shape[0]
     local_num_experts, intermediate_size, hidden_size = w2_hbm.shape
-    # num_experts = local_num_experts * num_devices
-    # padded_num_experts = expert_starts_x2_smem.shape[-1]
     right_id = (my_id + 1) % num_devices
+    num_experts = a2a_g_hbm.shape[0]
+    padded_num_experts = d2e_count_x2_smem.shape[-1]
+    padded_top_k = t2e_routing_x2_smem.shape[-1]
+    assert padded_num_experts == align_to(num_experts, 128)
+    assert padded_top_k == align_to(top_k, 128)
 
     t_dtype = tokens_hbm.dtype
     t_packing = get_dtype_packing(t_dtype)
     t_bitwidth = 32 // t_packing
     assert a2a_g_hbm.dtype == t_dtype
-    assert w1_hbm.dtype == t_dtype
-    assert w2_hbm.dtype == t_dtype
+    assert w1_hbm.dtype == w2_hbm.dtype
 
-    h_per_packing = hidden_size // t_packing
-    assert tokens_hbm.shape[-1] == h_per_packing
-    bd1_per_packing = bd1 // t_packing
-    bd2_per_packing = bd2 // t_packing
-    bd1c_per_packing = bd1c // t_packing
-    bd2c_per_packing = bd2c // t_packing
+    assert bd1 % bd1c == 0
+    assert bd2 % bd2c == 0
+    assert bf % bfc == 0
+    assert hidden_size % t_packing == 0
+    assert bd1 % t_packing == 0
+    assert bd2 % t_packing == 0
+    assert bd1c % t_packing == 0
+    assert bd2c % t_packing == 0
+
+    h_per_t_packing = hidden_size // t_packing
+    assert tokens_hbm.shape[-1] == h_per_t_packing
+    bd1_per_t_packing = bd1 // t_packing
+    bd2_per_t_packing = bd2 // t_packing
+    bd1c_per_t_packing = bd1c // t_packing
+    bd2c_per_t_packing = bd2c // t_packing
+
+    if subc_quant_wsz is not None:
+        assert subc_quant_wsz % 256 == 0
+        assert bd1c_per_t_packing == subc_quant_wsz
+        assert bfc == subc_quant_wsz
+        assert bd1 % subc_quant_wsz == 0
+        assert bf % subc_quant_wsz == 0
+        assert bd1_per_t_packing % subc_quant_wsz == 0
+        assert h_per_t_packing % subc_quant_wsz == 0
 
     num_bt = cdiv(local_num_tokens, bt)
     num_bf = cdiv(intermediate_size, bf)
     num_bd1 = cdiv(hidden_size, bd1)
     num_bd2 = cdiv(hidden_size, bd2)
 
+    def get_mesh_device_id(ep_rank):
+        dp_rank = jax.lax.axis_index("data")
+        return (dp_rank, ep_rank)
+
     def sync_barrier():
         barrier_sem = pltpu.get_barrier_semaphore()
         pltpu.semaphore_signal(
             barrier_sem,
-            device_id=(0, right_id),
+            device_id=get_mesh_device_id(right_id),
             device_id_type=pltpu.DeviceIdType.MESH,
         )
         pltpu.semaphore_wait(barrier_sem, 1)
@@ -212,30 +318,44 @@ def _fused_ep_moe_kernel(
             sem=b_gating_sem,
         ).wait()
 
-    def get_top_k(input, top_k):
+    def get_top_k(input, top_k, renormalize_topk_logits):
         assert len(input.shape) == 2, input.shape
         input = input.astype(jnp.float32)
+        padded_k_shape = (input.shape[0], padded_top_k)
         top_k_logits_lst = []
         top_k_indices_lst = []
         t2e = jnp.zeros(input.shape, dtype=jnp.int32)
-        t2e_routing = jnp.zeros(input.shape, dtype=jnp.int32)
+        t2e_routing = jnp.zeros(padded_k_shape, dtype=jnp.int32)
         iota = jax.lax.broadcasted_iota(jnp.int32, input.shape, 1)
+        padded_k_iota = jax.lax.broadcasted_iota(jnp.int32, padded_k_shape, 1)
+        top_k_logits_sum = jnp.zeros(padded_k_shape, jnp.float32)
+
        for k_id in range(top_k):
-            # TODO(jevinjiang): return both top_k values and indices in op in Mosaic
+            # TODO(jevinjiang): return both top_k values and indices in Mosaic
             top_k_logits = jnp.broadcast_to(
-                jnp.max(input, axis=1, keepdims=True),
-                (input.shape[0], 128)).astype(input.dtype)
+                jnp.max(input[:, :num_experts], axis=1, keepdims=True),
+                padded_k_shape,
+            ).astype(input.dtype)
             top_k_logits_lst.append(top_k_logits)
+            if renormalize_topk_logits:
+                top_k_logits_sum += top_k_logits
             # TODO(jevinjiang): support bf16 argmax in Mosaic
             top_k_indices = jnp.broadcast_to(
-                jnp.argmax(input, axis=1, keepdims=True), input.shape)
+                jnp.argmax(input[:, :num_experts], axis=1, keepdims=True),
+                padded_k_shape,
+            )
             top_k_indices_lst.append(top_k_indices)
-            t2e_routing = jnp.where(iota == k_id, top_k_indices, t2e_routing)
-            mask = iota == top_k_indices
+            t2e_routing = jnp.where(padded_k_iota == k_id, top_k_indices,
+                                    t2e_routing)
+            mask = iota == broadcast_minor(top_k_indices, input.shape)
             t2e += mask.astype(jnp.int32)
             if k_id != top_k - 1:
                 input = jnp.where(mask, -jnp.inf, input)
 
+        if renormalize_topk_logits:
+            for k_id in range(top_k):
+                top_k_logits_lst[k_id] /= top_k_logits_sum
+
         expert_sizes = jnp.sum(t2e, axis=0, keepdims=True)
         expert_starts = jnp.zeros_like(expert_sizes)
         return top_k_logits_lst, t2e_routing, expert_sizes, expert_starts
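`get_top_k` emulates a top-k by running `top_k` rounds of max/argmax and masking out each winner, since Mosaic does not yet expose a fused top-k op (see the TODOs above). An equivalence sketch against `lax.top_k` in plain JAX:

```python
import jax
import jax.numpy as jnp
from jax import lax

x = jax.random.normal(jax.random.PRNGKey(0), (4, 16))
top_k = 3

vals, idxs, cur = [], [], x
for _ in range(top_k):
    vals.append(jnp.max(cur, axis=1))
    idx = jnp.argmax(cur, axis=1)
    idxs.append(idx)
    # Mask out the winner so the next round finds the runner-up.
    cur = jnp.where(jax.nn.one_hot(idx, cur.shape[1], dtype=bool),
                    -jnp.inf, cur)

ref_vals, ref_idxs = lax.top_k(x, top_k)
assert jnp.allclose(jnp.stack(vals, axis=1), ref_vals)
assert (jnp.stack(idxs, axis=1) == ref_idxs).all()
```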
@@ -277,7 +397,7 @@ def _fused_ep_moe_kernel(
                 dst_ref=d2e_count_vmem.at[row_id],
                 send_sem=send_sem,
                 recv_sem=recv_sem,
-                device_id=(0, right_id),
+                device_id=get_mesh_device_id(right_id),
                 device_id_type=pltpu.DeviceIdType.MESH,
             ).wait()
             row_id = (row_id + num_devices - 1) % num_devices
@@ -359,10 +479,8 @@ def _fused_ep_moe_kernel(
                     pl.ds(start, remote_sz)],
                 send_sem=send_sems.at[e_sem_id],
                 recv_sem=recv_sems.at[e_sem_id],
-                device_id=(
-                    0,
-                    recv_id,
-                ),
+                device_id=get_mesh_device_id(recv_id),
+                device_id_type=pltpu.DeviceIdType.MESH,
             ).start()
             a2a_s_sends_x2_smem[e_sem_id] = send_sz
 
@@ -406,7 +524,8 @@ def _fused_ep_moe_kernel(
                 dst_ref=a2a_g_hbm.at[my_e_id, pl.ds(0, remote_sz)],
                 send_sem=send_sems.at[e_sem_id],
                 recv_sem=a2a_gather_sem,
-                device_id=(0, recv_id),
+                device_id=get_mesh_device_id(recv_id),
+                device_id_type=pltpu.DeviceIdType.MESH,
             ).start()
             start += sz
 
@@ -435,68 +554,173 @@ def _fused_ep_moe_kernel(
 
     def start_fetch_bw1(local_e_id, bw1_sem_id, bf_id, bd1_id):
         for p in range(t_packing):
-            offset = p * h_per_packing + bd1_id * bd1_per_packing
+            offset = p * h_per_t_packing + bd1_id * bd1_per_t_packing
             pltpu.make_async_copy(
                 src_ref=w1_hbm.at[
                     local_e_id,
                     0,
-                    pl.ds(offset, bd1_per_packing),
+                    pl.ds(offset, bd1_per_t_packing),
                     pl.ds(bf_id * bf, bf),
                 ],
                 dst_ref=b_w1_x2_vmem.at[bw1_sem_id, p],
                 sem=local_sems.at[bw1_sem_id, 1],
             ).start()
+            if w1_scale_hbm is not None:
+                assert subc_quant_wsz is not None
+                pltpu.make_async_copy(
+                    src_ref=w1_scale_hbm.at[
+                        local_e_id,
+                        0,
+                        pl.ds(
+                            offset // subc_quant_wsz,
+                            bd1_per_t_packing // subc_quant_wsz,
+                        ),
+                        pl.ds(0, 1),
+                        pl.ds(bf_id * bf, bf),
+                    ],
+                    dst_ref=b_w1_scale_x2_vmem.at[bw1_sem_id, p],
+                    sem=local_sems.at[bw1_sem_id, 1],
+                ).start()
+        if b1_hbm is not None and bd1_id == 0:
+            pltpu.make_async_copy(
+                src_ref=b1_hbm.at[local_e_id, 0,
+                                  pl.ds(0, 1),
+                                  pl.ds(bf_id * bf, bf)],
+                dst_ref=b_b1_x2_vmem.at[bf_id % 2],
+                sem=local_sems.at[bw1_sem_id, 1],
+            ).start()
 
     def start_fetch_bw2(local_e_id, bw2_sem_id, bf_id, bd2_id):
         for p in range(t_packing):
-            offset = p * h_per_packing + bd2_id * bd2_per_packing
+            offset = p * h_per_t_packing + bd2_id * bd2_per_t_packing
             pltpu.make_async_copy(
                 src_ref=w2_hbm.at[
                     local_e_id,
                     pl.ds(bf_id * bf, bf),
-                    pl.ds(offset, bd2_per_packing),
+                    pl.ds(offset, bd2_per_t_packing),
                 ],
                 dst_ref=b_w2_x2_vmem.at[bw2_sem_id, p],
                 sem=local_sems.at[bw2_sem_id, 2],
             ).start()
+            if w2_scale_hbm is not None:
+                assert subc_quant_wsz is not None
+                pltpu.make_async_copy(
+                    src_ref=w2_scale_hbm.at[
+                        local_e_id,
+                        pl.ds(bf_id * bf // subc_quant_wsz,
+                              bf // subc_quant_wsz),
+                        pl.ds(0, 1),
+                        pl.ds(offset, bd2_per_t_packing),
+                    ],
+                    dst_ref=b_w2_scale_x2_vmem.at[bw2_sem_id, p],
+                    sem=local_sems.at[bw2_sem_id, 2],
+                ).start()
+            if b2_hbm is not None and bf_id == 0:
+                pltpu.make_async_copy(
+                    src_ref=b2_hbm.at[local_e_id,
+                                      pl.ds(0, 1),
+                                      pl.ds(offset, bd2_per_t_packing)],
+                    dst_ref=b_b2_x2_vmem.at[bd2_id % 2, p],
+                    sem=local_sems.at[bw2_sem_id, 2],
+                ).start()
 
     def start_fetch_bw3(local_e_id, bw3_sem_id, bf_id, bd3_id):
         for p in range(t_packing):
-            offset = p * h_per_packing + bd3_id * bd1_per_packing
+            offset = p * h_per_t_packing + bd3_id * bd1_per_t_packing
             pltpu.make_async_copy(
                 src_ref=w1_hbm.at[
                     local_e_id,
                     1,
-                    pl.ds(offset, bd1_per_packing),
+                    pl.ds(offset, bd1_per_t_packing),
                     pl.ds(bf_id * bf, bf),
                 ],
                 dst_ref=b_w3_x2_vmem.at[bw3_sem_id, p],
                 sem=local_sems.at[bw3_sem_id, 3],
             ).start()
+            if w1_scale_hbm is not None:
+                assert subc_quant_wsz is not None
+                pltpu.make_async_copy(
+                    src_ref=w1_scale_hbm.at[
+                        local_e_id,
+                        1,
+                        pl.ds(
+                            offset // subc_quant_wsz,
+                            bd1_per_t_packing // subc_quant_wsz,
+                        ),
+                        pl.ds(0, 1),
+                        pl.ds(bf_id * bf, bf),
+                    ],
+                    dst_ref=b_w3_scale_x2_vmem.at[bw3_sem_id, p],
+                    sem=local_sems.at[bw3_sem_id, 3],
+                ).start()
+        if b1_hbm is not None and bd3_id == 0:
+            pltpu.make_async_copy(
+                src_ref=b1_hbm.at[local_e_id, 1,
+                                  pl.ds(0, 1),
+                                  pl.ds(bf_id * bf, bf)],
+                dst_ref=b_b3_x2_vmem.at[bf_id % 2],
+                sem=local_sems.at[bw3_sem_id, 3],
+            ).start()
 
     def wait_fetch_bw1(local_e_id, bw1_sem_id, bf_id, bd1_id):
-        del local_e_id, bf_id, bd1_id
+        del local_e_id
         pltpu.make_async_copy(
             src_ref=b_w1_x2_vmem.at[bw1_sem_id],
             dst_ref=b_w1_x2_vmem.at[bw1_sem_id],
             sem=local_sems.at[bw1_sem_id, 1],
         ).wait()
+        if w1_scale_hbm is not None:
+            pltpu.make_async_copy(
+                src_ref=b_w1_scale_x2_vmem.at[bw1_sem_id],
+                dst_ref=b_w1_scale_x2_vmem.at[bw1_sem_id],
+                sem=local_sems.at[bw1_sem_id, 1],
+            ).wait()
+        if b1_hbm is not None and bd1_id == 0:
+            pltpu.make_async_copy(
+                src_ref=b_b1_x2_vmem.at[bf_id % 2],
+                dst_ref=b_b1_x2_vmem.at[bf_id % 2],
+                sem=local_sems.at[bw1_sem_id, 1],
+            ).wait()
 
     def wait_fetch_bw2(local_e_id, bw2_sem_id, bf_id, bd2_id):
-        del local_e_id, bf_id, bd2_id
+        del local_e_id
         pltpu.make_async_copy(
             src_ref=b_w2_x2_vmem.at[bw2_sem_id],
             dst_ref=b_w2_x2_vmem.at[bw2_sem_id],
             sem=local_sems.at[bw2_sem_id, 2],
         ).wait()
+        if w2_scale_hbm is not None:
+            pltpu.make_async_copy(
+                src_ref=b_w2_scale_x2_vmem.at[bw2_sem_id],
+                dst_ref=b_w2_scale_x2_vmem.at[bw2_sem_id],
+                sem=local_sems.at[bw2_sem_id, 2],
+            ).wait()
+        if b2_hbm is not None and bf_id == 0:
+            pltpu.make_async_copy(
+                src_ref=b_b2_x2_vmem.at[bd2_id % 2],
+                dst_ref=b_b2_x2_vmem.at[bd2_id % 2],
+                sem=local_sems.at[bw2_sem_id, 2],
+            ).wait()
 
     def wait_fetch_bw3(local_e_id, bw3_sem_id, bf_id, bd3_id):
-        del local_e_id, bf_id, bd3_id
+        del local_e_id
         pltpu.make_async_copy(
             src_ref=b_w3_x2_vmem.at[bw3_sem_id],
             dst_ref=b_w3_x2_vmem.at[bw3_sem_id],
             sem=local_sems.at[bw3_sem_id, 3],
         ).wait()
+        if w1_scale_hbm is not None:
+            pltpu.make_async_copy(
+                src_ref=b_w3_scale_x2_vmem.at[bw3_sem_id],
+                dst_ref=b_w3_scale_x2_vmem.at[bw3_sem_id],
+                sem=local_sems.at[bw3_sem_id, 3],
+            ).wait()
+        if b1_hbm is not None and bd3_id == 0:
+            pltpu.make_async_copy(
+                src_ref=b_b3_x2_vmem.at[bf_id % 2],
+                dst_ref=b_b3_x2_vmem.at[bf_id % 2],
+                sem=local_sems.at[bw3_sem_id, 3],
+            ).wait()
 
     def start_fetch_next_bw(local_e_id, bw_sem_id, bf_id, bd1_id, bd2_id):
         next_bd1_id = bd1_id + 1
@@ -520,18 +744,38 @@ def _fused_ep_moe_kernel(
     def dynamic_ffn1(
         t_b32_vmem,
         w1_vmem,
+        w1_scale_vmem,
+        b1_vmem,
         w3_vmem,
+        w3_scale_vmem,
+        b3_vmem,
         acc1_vmem,
         acc3_vmem,
         dyn_sz,
         should_init,
     ):
         assert t_b32_vmem.shape == (bt * num_devices, bd1 // t_packing)
-        assert w1_vmem.shape == w3_vmem.shape == (t_packing, bd1_per_packing,
+        assert w1_vmem.shape == w3_vmem.shape == (t_packing, bd1_per_t_packing,
                                                   bf)
         assert acc1_vmem.shape == acc3_vmem.shape == (bt * num_devices, bf)
         assert bd1 % (t_packing * 128) == 0, (bd1, t_packing)
         assert bd1c % (t_packing * 128) == 0, (bd1c, t_packing)
+        if w1_scale_vmem is not None:
+            assert w1_scale_vmem.shape == (
+                t_packing,
+                bd1_per_t_packing // subc_quant_wsz,
+                1,
+                bf,
+            )
+            assert bd1c_per_t_packing == subc_quant_wsz
+        if w3_scale_vmem is not None:
+            assert w3_scale_vmem.shape == (
+                t_packing,
+                bd1_per_t_packing // subc_quant_wsz,
+                1,
+                bf,
+            )
+            assert bd1c_per_t_packing == subc_quant_wsz
 
         num_loops = cdiv(dyn_sz, btc)
         repack_ty = jnp.dtype(f"int{t_bitwidth}")
@@ -540,7 +784,7 @@ def _fused_ep_moe_kernel(
             for bd1c_id in range(cdiv(bd1, bd1c)):
                 t_b32 = t_b32_vmem[
                     pl.ds(btc_id * btc, btc),
-                    pl.ds(bd1c_id * bd1c_per_packing, bd1c_per_packing),
+                    pl.ds(bd1c_id * bd1c_per_t_packing, bd1c_per_t_packing),
                 ]
                 for p_id in range(t_packing):
                     t = pltpu.bitcast(t_b32.astype(repack_ty), t_dtype)
@@ -548,21 +792,64 @@ def _fused_ep_moe_kernel(
                     for bfc_id in range(cdiv(bf, bfc)):
                         w_slices = (
                             p_id,
-                            pl.ds(bd1c_id * bd1c_per_packing,
-                                  bd1c_per_packing),
+                            pl.ds(bd1c_id * bd1c_per_t_packing,
+                                  bd1c_per_t_packing),
                             pl.ds(bfc_id * bfc, bfc),
                         )
                         w1 = w1_vmem[*w_slices]
                         acc1 = jnp.dot(t,
                                        w1,
                                        preferred_element_type=jnp.float32)
+
+                        if w1_scale_vmem is not None:
+                            w1_scale_slices = (
+                                p_id,
+                                bd1c_id,
+                                pl.ds(0, 1),
+                                pl.ds(bfc_id * bfc, bfc),
+                            )
+                            # TODO(jevinjiang): can use mosaic to load with stride 0.
+                            w1_scale = jnp.broadcast_to(
+                                w1_scale_vmem[*w1_scale_slices], acc1.shape)
+                            acc1 *= w1_scale
+
                         w3 = w3_vmem[*w_slices]
+
                         acc3 = jnp.dot(t,
                                        w3,
                                        preferred_element_type=jnp.float32)
+
+                        if w3_scale_vmem is not None:
+                            w3_scale_slices = (
+                                p_id,
+                                bd1c_id,
+                                pl.ds(0, 1),
+                                pl.ds(bfc_id * bfc, bfc),
+                            )
+                            w3_scale = jnp.broadcast_to(
+                                w3_scale_vmem[*w3_scale_slices], acc3.shape)
+                            acc3 *= w3_scale
+
                         acc_slices = (pl.ds(btc_id * btc,
                                             btc), pl.ds(bfc_id * bfc, bfc))
                         if should_init and p_id == bd1c_id == 0:
+                            if b1_vmem is not None:
+                                b1_scale_slices = (
+                                    pl.ds(0, 1),
+                                    pl.ds(bfc_id * bfc, bfc),
+                                )
+                                b1 = jnp.broadcast_to(
+                                    b1_vmem[*b1_scale_slices], acc1.shape)
+                                acc1 += b1
+                            if b3_vmem is not None:
+                                b3_scale_slices = (
+                                    pl.ds(0, 1),
+                                    pl.ds(bfc_id * bfc, bfc),
+                                )
+                                b3 = jnp.broadcast_to(
+                                    b3_vmem[*b3_scale_slices], acc1.shape)
+                                acc3 += b3
+
                             acc1_vmem[*acc_slices] = acc1
                             acc3_vmem[*acc_slices] = acc3
                         else:
@@ -575,22 +862,28 @@ def _fused_ep_moe_kernel(
         acc1_vmem,
         acc3_vmem,
         w2_vmem,
+        w2_scale_vmem,
+        b2_vmem,
         res_b32_vmem,
         dyn_sz,
         should_init,
     ):
-        assert res_b32_vmem.shape == (bt * num_devices, bd2_per_packing)
-        assert w2_vmem.shape == (t_packing, bf, bd2_per_packing), (
-            w2_vmem.shape,
-            t_packing,
-            bf,
-            bd2_per_packing,
-        )
+        assert res_b32_vmem.shape == (bt * num_devices, bd2_per_t_packing)
+        assert w2_vmem.shape == (t_packing, bf, bd2_per_t_packing)
         assert acc1_vmem.shape == acc3_vmem.shape == (bt * num_devices, bf)
         assert bd2 % (t_packing * 128) == 0, (bd2, t_packing)
         assert bd2c % (t_packing * 128) == 0, (bd2c, t_packing)
         assert t_dtype in (jnp.float32, jnp.bfloat16)
 
+        if w2_scale_vmem is not None:
+            assert w2_scale_vmem.shape == (
+                t_packing,
+                bf // subc_quant_wsz,
+                1,
+                bd2_per_t_packing,
+            )
+            assert bfc == subc_quant_wsz
+
         num_loops = cdiv(dyn_sz, btc)
         assert bd2c % (t_packing * 128) == 0, (bd2c, t_packing)
 
@@ -598,22 +891,47 @@ def _fused_ep_moe_kernel(
             for bd2c_id in range(cdiv(bd2, bd2c)):
                 res_lst = []
                 for p_id in range(t_packing):
-                    res = jnp.zeros((btc, bd2c_per_packing), dtype=jnp.float32)
+                    res = jnp.zeros((btc, bd2c_per_t_packing),
+                                    dtype=jnp.float32)
+
+                    if b2_vmem is not None and should_init:
+                        b2_scale_slices = (
+                            p_id,
+                            pl.ds(0, 1),
+                            pl.ds(bd2c_id * bd2c_per_t_packing,
+                                  bd2c_per_t_packing),
+                        )
+                        b2 = jnp.broadcast_to(b2_vmem[*b2_scale_slices],
+                                              res.shape)
+                        res += b2
+
                     for bfc_id in range(cdiv(bf, bfc)):
                         acc_slices = (pl.ds(btc_id * btc,
                                             btc), pl.ds(bfc_id * bfc, bfc))
                         acc1 = acc1_vmem[*acc_slices]
                         acc3 = acc3_vmem[*acc_slices]
-                        act = jax.nn.silu(acc1) * acc3
+                        act = activation_fn(acc1, acc3, act_fn)
                         w2 = w2_vmem[
                             p_id,
                             pl.ds(bfc_id * bfc, bfc),
                             pl.ds(bd2c_id *
-                                  bd2c_per_packing, bd2c_per_packing),
+                                  bd2c_per_t_packing, bd2c_per_t_packing),
                         ]
-                        res += jnp.dot(act,
-                                       w2,
-                                       preferred_element_type=jnp.float32)
+                        acc = jnp.dot(act,
+                                      w2,
+                                      preferred_element_type=jnp.float32)
+                        if w2_scale_vmem is not None:
+                            w2_scale_slices = (
+                                p_id,
+                                bfc_id,
+                                pl.ds(0, 1),
+                                pl.ds(bd2c_id * bd2c_per_t_packing,
+                                      bd2c_per_t_packing),
+                            )
+                            w2_scale = jnp.broadcast_to(
+                                w2_scale_vmem[*w2_scale_slices], acc.shape)
+                            acc *= w2_scale
+                        res += acc
                     res = pltpu.bitcast(res, jnp.uint32)
                     if t_packing == 2:
                         res = res >> 16 << (16 * p_id)
@@ -626,7 +944,7 @@ def _fused_ep_moe_kernel(
                         res |= res_lst[i]
                     sliced_res_vmem = res_b32_vmem.at[
                         pl.ds(btc_id * btc, btc),
-                        pl.ds(bd2c_id * bd2c_per_packing, bd2c_per_packing),
+                        pl.ds(bd2c_id * bd2c_per_t_packing, bd2c_per_t_packing),
                     ]
                     if should_init:
                         sliced_res_vmem[...] = res
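The `res >> 16 << (16 * p_id)` step above packs two bf16 results into one 32-bit word: each f32 partial result is truncated to its top 16 bits (its bf16 part), shifted into its lane, and the lanes are OR-ed together. A standalone sketch of the bit trick in plain JAX (truncating rather than rounding, as in the kernel):

```python
import jax
import jax.numpy as jnp

res_p0 = jnp.float32(1.5)   # partial result for packing lane 0
res_p1 = jnp.float32(-2.0)  # partial result for packing lane 1

b0 = jax.lax.bitcast_convert_type(res_p0, jnp.uint32)
b1 = jax.lax.bitcast_convert_type(res_p1, jnp.uint32)
# Keep the high 16 bits (the bf16 part) of each f32, pack two per word.
packed = (b0 >> 16 << 0) | (b1 >> 16 << 16)

lo = jax.lax.bitcast_convert_type(jnp.uint16(packed & 0xFFFF), jnp.bfloat16)
hi = jax.lax.bitcast_convert_type(jnp.uint16(packed >> 16), jnp.bfloat16)
assert lo == 1.5 and hi == -2.0  # both values are exact in bf16
```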
@@ -655,21 +973,33 @@ def _fused_ep_moe_kernel(
         e_id = my_id * local_num_experts + local_e_id
         dyn_sz = expert_sizes_x2_smem[bt_sem_id, 0, e_id]
 
-        bd1_per_packing = bd1 // t_packing
-        bd2_per_packing = bd2 // t_packing
+        bd1_per_t_packing = bd1 // t_packing
+        bd2_per_t_packing = bd2 // t_packing
 
         for bf_id in range(num_bf):
             for bd1_id in range(num_bd1):
                 start_fetch_next_bw(local_e_id, bw_sem_id, bf_id, bd1_id, 0)
+                w1_scale_vmem = (None if b_w1_scale_x2_vmem is None else
+                                 b_w1_scale_x2_vmem.at[bw_sem_id])
+                w3_scale_vmem = (None if b_w3_scale_x2_vmem is None else
+                                 b_w3_scale_x2_vmem.at[bw_sem_id])
+                b1_vmem = None if b_b1_x2_vmem is None else b_b1_x2_vmem.at[
+                    bf_id % 2]
+                b3_vmem = None if b_b3_x2_vmem is None else b_b3_x2_vmem.at[
+                    bf_id % 2]
                 wait_fetch_bw1(local_e_id, bw_sem_id, bf_id, bd1_id)
                 wait_fetch_bw3(local_e_id, bw_sem_id, bf_id, bd1_id)
 
                 dynamic_ffn1(
                     t_b32_vmem=a2a_s_b32_vmem.at[
                         ...,
-                        pl.ds(bd1_id * bd1_per_packing, bd1_per_packing)],
+                        pl.ds(bd1_id * bd1_per_t_packing, bd1_per_t_packing)],
                     w1_vmem=b_w1_x2_vmem.at[bw_sem_id],
+                    w1_scale_vmem=w1_scale_vmem,
+                    b1_vmem=b1_vmem,
                     w3_vmem=b_w3_x2_vmem.at[bw_sem_id],
+                    w3_scale_vmem=w3_scale_vmem,
+                    b3_vmem=b3_vmem,
                     acc1_vmem=b_acc1_vmem,
                     acc3_vmem=b_acc3_vmem,
                     dyn_sz=dyn_sz,
@@ -684,13 +1014,19 @@ def _fused_ep_moe_kernel(
                 if bf_id == bd2_id == 0:
                     wait_a2a_gather_send(bt_id, e_sem_id, local_e_id - 2)
 
+                w2_scale_vmem = (None if b_w2_scale_x2_vmem is None else
+                                 b_w2_scale_x2_vmem.at[bw_sem_id])
+                b2_vmem = None if b_b2_x2_vmem is None else b_b2_x2_vmem.at[
+                    bd2_id % 2]
                 dynamic_ffn2(
                     acc1_vmem=b_acc1_vmem,
                     acc3_vmem=b_acc3_vmem,
                     w2_vmem=b_w2_x2_vmem.at[bw_sem_id],
+                    w2_scale_vmem=w2_scale_vmem,
+                    b2_vmem=b2_vmem,
                     res_b32_vmem=a2a_s_acc_b32_vmem.at[
                         ...,
-                        pl.ds(bd2_id * bd2_per_packing, bd2_per_packing)],
+                        pl.ds(bd2_id * bd2_per_t_packing, bd2_per_t_packing)],
                     dyn_sz=dyn_sz,
                     should_init=(bf_id == 0),
                 )
@@ -757,31 +1093,42 @@ def _fused_ep_moe_kernel(
         b_gating = b_gating_x2_vmem[bt_sem_id]
         b_gating_score = jax.nn.softmax(b_gating, axis=-1)
         top_k_logits_lst, t2e_routing, expert_sizes, expert_starts = get_top_k(
-            b_gating_score, top_k)
+            b_gating_score, top_k, renormalize_topk_logits)
 
         all_reduce_metadata(bt_sem_id, t2e_routing, expert_starts,
                             expert_sizes)
+        sync_barrier()
 
+        # Start a2a scatter for first active expert.
         start_a2a_scatter(bt_id=bt_id, e_sem_id=e_sem_id, local_e_id=0)
 
         def run_per_expert(local_e_id, e_sem_id):
             sync_barrier()
+
+            # Prefetch weights for CURRENT active expert.
+            # TODO(jevinjiang): It is hard to prefetch weights in previous iteration
+            # because the expert_ffn keeps overwriting the buffers. Triple buffering
+            # could resolve this but it takes more VMEM scratch. Need further
+            # experiment on this.
+            start_fetch_bw1(local_e_id, bw1_sem_id=0, bf_id=0, bd1_id=0)
+            start_fetch_bw3(local_e_id, bw3_sem_id=0, bf_id=0, bd3_id=0)
+
+            # Next ids.
             next_e_sem_id = lax.select(e_sem_id == 0, 1, 0)
             next_local_e_id = local_e_id + 1
 
+            # Start a2a scatter for NEXT active expert.
             @pl.when(next_local_e_id < local_num_experts)
             def _():
                 start_a2a_scatter(bt_id, next_e_sem_id, next_local_e_id)
 
-            # Prefetch weights for active expert.
-            start_fetch_bw1(local_e_id, bw1_sem_id=0, bf_id=0, bd1_id=0)
-            start_fetch_bw3(local_e_id, bw3_sem_id=0, bf_id=0, bd3_id=0)
-
-            # Wait for a2a scatter and perform FFN for active expert.
+            # Wait a2a scatter for CURRENT active expert.
             wait_a2a_scatter_recv(bt_id, e_sem_id, local_e_id)
+
+            # Perform FFN for CURRENT active expert.
             expert_ffn(bt_id, e_sem_id, local_e_id)
 
-            # Wait for a2a gather to send back tokens for active expert.
+            # Start a2a gather to send back tokens for CURRENT active expert.
            start_a2a_gather(bt_id, e_sem_id, local_e_id)
 
             # A must-wait before next sync_barrier.
@@ -794,7 +1141,10 @@ def _fused_ep_moe_kernel(
                       e_sem_id,
                       unroll=False)
 
+        # Wait to receive a2a gather for ALL experts.
         wait_a2a_gather_recv_all()
+
+        # Accumulate results for current batch.
         output = bt_acc(bt_id, top_k_logits_lst)
 
         # Make sure it is safe to overwrite output buffer.
@@ -827,6 +1177,9 @@ def _fused_ep_moe_kernel(
     static_argnames=[
         "mesh",
         "top_k",
+        "renormalize_topk_logits",
+        "act_fn",
+        "subc_quant_wsz",
         "bt",
         "bf",
         "bd1",
@@ -846,6 +1199,17 @@ def fused_ep_moe(
     gating_output: jax.Array,  # (num_tokens, num_experts)
     top_k: int,
     *,
+    renormalize_topk_logits: bool = False,
+    act_fn: str = "silu",
+    subc_quant_wsz: int | None = None,
+    w1_scale: (
+        jax.Array | None
+    ) = None,  # F32(num_experts, 2, hidden_size // subc_quant_wsz, 1, intermediate_size)
+    w2_scale: (
+        jax.Array | None
+    ) = None,  # F32(num_experts, intermediate_size // subc_quant_wsz, 1, hidden_size)
+    b1: jax.Array | None = None,  # F32(num_experts, 2, 1, intermediate_size)
+    b2: jax.Array | None = None,  # F32(num_experts, 1, hidden_size)
     # Kernel tuning parameters.
     bt: int,
     bf: int,
@@ -855,52 +1219,164 @@ def fused_ep_moe(
855
1219
  bfc: int,
856
1220
  bd1c: int,
857
1221
  bd2c: int,
858
- ep_axis_name: str = 'model',
1222
+ ep_axis_name: str = "model",
859
1223
  ):
860
- # Assert all other axes have length of 1
861
- assert len(mesh.shape) == 2, "Expect 2D mesh in tpu-inference"
862
- assert 'data' in mesh.shape and mesh.shape['data'] == 1, \
863
- "Expect data axis size of 1 in tpu-inference"
1224
+ # TODO(jevinjiang): move all these assertions to validation function.
1225
+ if len(mesh.shape) != 2:
1226
+ raise NotImplementedError("Only 2D mesh is supported.")
1227
+
1228
+ for axis_name in mesh.axis_names:
1229
+ if axis_name == ep_axis_name:
1230
+ continue
1231
+ if mesh.shape[axis_name] != 1:
1232
+ raise NotImplementedError(
1233
+ f"Expected all non-ep axis to have size 1 in {mesh.shape=}")
864
1234
 
865
1235
  ep_size = mesh.shape[ep_axis_name]
866
1236
  num_devices = ep_size
867
1237
 
868
- num_tokens, actual_hidden_size = tokens.shape
1238
+ num_tokens, hidden_size = tokens.shape
869
1239
  num_experts, intermediate_size, _ = w2.shape
870
1240
 
871
- assert num_tokens % ep_size == 0
872
- assert num_experts % ep_size == 0
1241
+ if w1.shape != (num_experts, 2, hidden_size, intermediate_size):
1242
+ raise ValueError(
1243
+ f"Expected {w1.shape=} to be"
1244
+ f" {(num_experts, 2, hidden_size, intermediate_size)}.")
1245
+
1246
+ if w2.shape != (num_experts, intermediate_size, hidden_size):
1247
+ raise ValueError(f"Expected {w2.shape=} to be"
1248
+ f" {(num_experts, intermediate_size, hidden_size)}.")
1249
+
1250
+ if gating_output.shape != (num_tokens, num_experts):
1251
+ raise ValueError(
1252
+ f"Expected {gating_output.shape=} to be {(num_tokens, num_experts)}."
1253
+ )
1254
+
1255
+ if not (0 < top_k <= num_experts):
1256
+ raise ValueError(
1257
+ f"Expected {top_k=} to be in range (0, {num_experts=}].")
1258
+
1259
+ if hidden_size % 128 != 0 or intermediate_size % 128 != 0:
1260
+ raise ValueError(
1261
+ f"Expected {hidden_size=} and {intermediate_size=} to be aligned to"
1262
+ " 128. Did you pad them with zeros outside the kernel?")
1263
+ if num_tokens % ep_size != 0:
1264
+ raise ValueError(
1265
+ f"Expected {num_tokens=} to be aligned to {ep_size=}.")
1266
+ if num_experts % ep_size != 0:
1267
+ raise ValueError(
1268
+ f"Expected {num_experts=} to be aligned to {ep_size=}.")
873
1269
 
874
1270
  local_num_tokens = num_tokens // ep_size
875
1271
  # local_num_experts = num_experts // ep_size
876
1272
  padded_num_experts = align_to(num_experts, 128)
877
-
1273
+ padded_top_k = align_to(top_k, 128)
878
1274
  t_dtype = tokens.dtype
879
1275
  t_packing = get_dtype_packing(t_dtype)
880
- hidden_size = align_to(actual_hidden_size, 128 * t_packing)
881
- if hidden_size != actual_hidden_size:
882
- tokens = jnp.pad(
883
- tokens,
884
- ((0, 0), (0, hidden_size - actual_hidden_size)),
885
- constant_values=0,
886
- )
887
- tokens = tokens.reshape(-1, t_packing, hidden_size // t_packing)
888
- bt = min(bt, local_num_tokens)
889
- bf = min(bf, intermediate_size)
890
- bd1 = min(bd1, hidden_size)
891
- bd2 = min(bd2, hidden_size)
892
-
893
- btc = min(btc, bt * num_devices)
894
- bfc = min(bfc, bf)
895
- bd1c = min(bd1c, bd1)
896
- bd2c = min(bd2c, bd2)
897
- assert bfc % 128 == 0
898
- assert bd1c % (t_packing * 128) == 0
899
- assert bd2c % (t_packing * 128) == 0
900
- assert bf % bfc == 0
901
- assert bd1 % bd1c == 0
902
- assert bd2 % bd2c == 0
903
1276
 
1277
+    # Override bt
+    if local_num_tokens <= t_packing * 8:
+        bt = local_num_tokens
+        btc = bt
+    bt = min(local_num_tokens, bt)
+    # The worst case is that all devices send their bt tokens to one device.
+    btc = min(bt, btc, bt * num_devices)
+
+    if local_num_tokens % t_packing != 0:
+        raise ValueError(
+            f"Expected {local_num_tokens=} to be aligned to {t_packing=}.")
+
+    if bt % t_packing != 0:
+        raise ValueError(f"Expected {bt=} to be aligned to {t_packing=}.")
+    if local_num_tokens % bt != 0:
+        raise ValueError(
+            f"Expected {local_num_tokens=} to be aligned to {bt=}.")
+
+    if subc_quant_wsz is not None:
+        if subc_quant_wsz <= 0:
+            raise ValueError(f"Expected {subc_quant_wsz=} to be positive.")
+        if subc_quant_wsz % 256 != 0:
+            raise ValueError(
+                f"Expected {subc_quant_wsz=} to be aligned to 256.")
+        if hidden_size % subc_quant_wsz != 0:
+            raise ValueError(
+                f"Expected {hidden_size=} to be aligned to {subc_quant_wsz=}.")
+        if intermediate_size % subc_quant_wsz != 0:
+            raise ValueError(
+                f"Expected {intermediate_size=} to be aligned to {subc_quant_wsz=}."
+            )
+        # We force the compute size of the contracting dim to subc_quant_wsz
+        # so the same scale can be applied after matmul and accumulation.
+        bd1c = subc_quant_wsz * t_packing
+        bfc = subc_quant_wsz
+
+    if bfc % 128 != 0:
+        raise ValueError(f"Expected {bfc=} to be aligned to 128.")
+    if bd1c % (t_packing * 128) != 0:
+        raise ValueError(
+            f"Expected {bd1c=} to be aligned to {t_packing * 128}.")
+    if bd2c % (t_packing * 128) != 0:
+        raise ValueError(
+            f"Expected {bd2c=} to be aligned to {t_packing * 128}.")
+    if bf % bfc != 0:
+        raise ValueError(f"Expected {bf=} to be aligned to {bfc=}.")
+    if bd1 % bd1c != 0:
+        raise ValueError(f"Expected {bd1=} to be aligned to {bd1c=}.")
+    if bd2 % bd2c != 0:
+        raise ValueError(f"Expected {bd2=} to be aligned to {bd2c=}.")
+    if hidden_size % bd1 != 0 or hidden_size % bd2 != 0:
+        raise ValueError(
+            f"Expected {hidden_size=} to be aligned to {bd1=} and {bd2=}.")
+    if intermediate_size % bf != 0:
+        raise ValueError(
+            f"Expected {intermediate_size=} to be aligned to {bf=}.")
+
+    # Note: scales should be dumped in the kernel's expected shape in the
+    # checkpoint offline, or reshaped right after weight loading.
+    if w1_scale is not None:
+        expected_w1_scale_shape = (
+            num_experts,
+            2,
+            hidden_size // subc_quant_wsz,
+            1,
+            intermediate_size,
+        )
+        if w1_scale.shape != expected_w1_scale_shape:
+            raise ValueError(
+                f"Expected {w1_scale.shape=} to be {expected_w1_scale_shape}.")
+        if w1_scale.dtype != jnp.float32:
+            w1_scale = w1_scale.astype(jnp.float32)
+
+    if w2_scale is not None:
+        expected_w2_scale_shape = (
+            num_experts,
+            intermediate_size // subc_quant_wsz,
+            1,
+            hidden_size,
+        )
+        if w2_scale.shape != expected_w2_scale_shape:
+            raise ValueError(
+                f"Expected {w2_scale.shape=} to be {expected_w2_scale_shape}.")
+        if w2_scale.dtype != jnp.float32:
+            w2_scale = w2_scale.astype(jnp.float32)
+
+    if b1 is not None:
+        expected_b1_shape = (num_experts, 2, 1, intermediate_size)
+        if b1.shape != expected_b1_shape:
+            raise ValueError(
+                f"Expected {b1.shape=} to be {expected_b1_shape}.")
+        if b1.dtype != jnp.float32:
+            b1 = b1.astype(jnp.float32)
+
+    if b2 is not None:
+        expected_b2_shape = (num_experts, 1, hidden_size)
+        if b2.shape != expected_b2_shape:
+            raise ValueError(
+                f"Expected {b2.shape=} to be {expected_b2_shape}.")
+        if b2.dtype != jnp.float32:
+            b2 = b2.astype(jnp.float32)
+
+    # Prepare inputs for the kernel.
     if padded_num_experts != gating_output.shape[-1]:
         gating_output = jnp.pad(
             gating_output,
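
The bt/btc override in this hunk can be sanity-checked with a small standalone sketch; the sizes below are hypothetical, not taken from any model config:

    t_packing = 2
    num_devices = 4

    def override_bt(local_num_tokens, bt, btc):
        # Mirrors the override logic in the hunk above.
        if local_num_tokens <= t_packing * 8:
            bt = local_num_tokens
            btc = bt
        bt = min(local_num_tokens, bt)
        # Worst case: every device routes its bt tokens to one device.
        btc = min(bt, btc, bt * num_devices)
        return bt, btc

    print(override_bt(8, 64, 256))    # tiny batch  -> (8, 8)
    print(override_bt(512, 64, 256))  # large batch -> (64, 64)
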
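Two of the new checks encode how sub-channel quantization is consumed: the scales hold one float32 per subc_quant_wsz-wide group of the contracting dimension, and pinning bfc/bd1c to subc_quant_wsz means a whole group is accumulated before its single scale is applied. A self-contained NumPy sketch under hypothetical sizes (this is not the kernel code; the leading 2 in the w1 scale shape matches the two stacked projections the kernel buffers as b_w1/b_w3):

    import numpy as np

    num_experts, hidden_size, intermediate_size = 8, 4096, 14336
    subc_quant_wsz = 512  # positive multiple of 256 that divides both dims

    # Shape rules mirrored from the checks above: one scale per group of the
    # contracting dim (w1 contracts hidden_size; w2 contracts intermediate_size).
    w1_scale_shape = (num_experts, 2, hidden_size // subc_quant_wsz, 1, intermediate_size)
    w2_scale_shape = (num_experts, intermediate_size // subc_quant_wsz, 1, hidden_size)

    # Why the compute block of the contracting dim is pinned to subc_quant_wsz:
    # a full group can be accumulated first, then scaled once.
    x = np.random.randn(4, hidden_size).astype(np.float32)
    w_q = np.random.randint(-8, 8, (hidden_size, 16)).astype(np.float32)
    scale = np.random.rand(hidden_size // subc_quant_wsz, 16).astype(np.float32)

    out = np.zeros((4, 16), np.float32)
    for g in range(hidden_size // subc_quant_wsz):
        rows = slice(g * subc_quant_wsz, (g + 1) * subc_quant_wsz)
        out += (x[:, rows] @ w_q[rows]) * scale[g]  # scale applied after accumulation

    # Equivalent to a dense matmul against the dequantized weight.
    dequantized = w_q * np.repeat(scale, subc_quant_wsz, axis=0)
    assert np.allclose(out, x @ dequantized, rtol=1e-4, atol=1e-2)
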
@@ -908,128 +1384,229 @@ def fused_ep_moe(
             constant_values=-jnp.inf,
         )
 
-    scope_name = f"fused_moe_k-{top_k}_bt-{bt}-{btc}_bf-{bf}-{bfc}_bd1-{bd1}-{bd1c}_bd2-{bd2}-{bd2c}"
-    fused_moe = jax.named_scope(scope_name)(
-        pl.pallas_call(
-            functools.partial(
-                _fused_ep_moe_kernel,
-                top_k=top_k,
-                ep_axis_name=ep_axis_name,
-                bt=bt,
-                bf=bf,
-                bd1=bd1,
-                bd2=bd2,
-                btc=btc,
-                bfc=bfc,
-                bd1c=bd1c,
-                bd2c=bd2c,
-            ),
-            out_shape=jax.ShapeDtypeStruct((local_num_tokens, hidden_size),
-                                           t_dtype),
-            grid_spec=pltpu.PrefetchScalarGridSpec(
-                num_scalar_prefetch=0,
-                in_specs=[
-                    pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM),
-                    pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM),
-                    pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM),
-                    pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM),
-                    pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM),
-                ],
-                out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM),
-                scratch_shapes=([
-                    # t2e_routing_x2_smem
-                    pltpu.SMEM((2, bt, padded_num_experts), jnp.int32),
-                    # d2e_count_x2_smem
-                    pltpu.SMEM((2, num_devices, 1, padded_num_experts),
-                               jnp.int32),
-                    # expert_offsets_x2_smem
-                    pltpu.SMEM((2, 2, padded_num_experts), jnp.int32),
-                    # expert_starts_x2_smem
-                    pltpu.SMEM((2, 1, padded_num_experts), jnp.int32),
-                    # expert_sizes_x2_smem
-                    pltpu.SMEM((2, 1, padded_num_experts), jnp.int32),
-                    # a2a_s_sends_x2_smem
-                    pltpu.SMEM((2, ), jnp.int32),
-                    # a2a_s_x2_vmem
-                    pltpu.VMEM(
-                        (
-                            2,
-                            bt * num_devices,
-                            t_packing,
-                            hidden_size // t_packing,
-                        ),
-                        t_dtype,
+    tokens = tokens.reshape(-1, t_packing, hidden_size // t_packing)
+
+    hbm_block_spec = pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM)
+    renorm_str = "-renorm_k" if renormalize_topk_logits else ""
+    scope_name = f"fused-moe-k_{top_k}{renorm_str}-bt_{bt}_{btc}-bf_{bf}_{bfc}-bd1_{bd1}_{bd1c}-bd2_{bd2}_{bd2c}"
+    fused_moe = pl.pallas_call(
+        functools.partial(
+            _fused_ep_moe_kernel,
+            top_k=top_k,
+            renormalize_topk_logits=renormalize_topk_logits,
+            ep_axis_name=ep_axis_name,
+            act_fn=act_fn,
+            subc_quant_wsz=subc_quant_wsz,
+            bt=bt,
+            bf=bf,
+            bd1=bd1,
+            bd2=bd2,
+            btc=btc,
+            bfc=bfc,
+            bd1c=bd1c,
+            bd2c=bd2c,
+        ),
+        out_shape=jax.ShapeDtypeStruct((local_num_tokens, hidden_size),
+                                       t_dtype),
+        grid_spec=pltpu.PrefetchScalarGridSpec(
+            num_scalar_prefetch=0,
+            in_specs=[
+                hbm_block_spec,  # tokens_hbm
+                hbm_block_spec,  # w1_hbm
+                hbm_block_spec,  # w2_hbm
+                None if w1_scale is None else hbm_block_spec,  # w1_scale_hbm
+                None if w2_scale is None else hbm_block_spec,  # w2_scale_hbm
+                None if b1 is None else hbm_block_spec,  # b1_hbm
+                None if b2 is None else hbm_block_spec,  # b2_hbm
+                hbm_block_spec,  # gating_output_hbm
+                hbm_block_spec,  # a2a_g_hbm
+            ],
+            out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM),
+            scratch_shapes=([
+                # t2e_routing_x2_smem
+                pltpu.SMEM((2, bt, padded_top_k), jnp.int32),
+                # d2e_count_x2_smem
+                pltpu.SMEM((2, num_devices, 1, padded_num_experts), jnp.int32),
+                # expert_offsets_x2_smem
+                pltpu.SMEM((2, 2, padded_num_experts), jnp.int32),
+                # expert_starts_x2_smem
+                pltpu.SMEM((2, 1, padded_num_experts), jnp.int32),
+                # expert_sizes_x2_smem
+                pltpu.SMEM((2, 1, padded_num_experts), jnp.int32),
+                # a2a_s_sends_x2_smem
+                pltpu.SMEM((2, ), jnp.int32),
+                # a2a_s_x2_vmem
+                pltpu.VMEM(
+                    (
+                        2,
+                        bt * num_devices,
+                        t_packing,
+                        hidden_size // t_packing,
                     ),
-                # a2a_s_acc_x2_vmem
-                pltpu.VMEM(
-                    (
-                        2,
-                        bt * num_devices,
-                        t_packing,
-                        hidden_size // t_packing,
-                    ),
-                    t_dtype,
+                    t_dtype,
+                ),
+                # a2a_s_acc_x2_vmem
+                pltpu.VMEM(
+                    (
+                        2,
+                        bt * num_devices,
+                        t_packing,
+                        hidden_size // t_packing,
                     ),
-                # a2a_g_acc_vmem
-                pltpu.VMEM(
-                    (top_k, bt, t_packing, hidden_size // t_packing),
-                    t_dtype),
-                # b_gating_x2_vmem
-                pltpu.VMEM((2, bt, padded_num_experts), t_dtype),
-                # b_output_x2_vmem
-                pltpu.VMEM((2, bt, hidden_size), t_dtype),
-                # b_w1_x2_vmem
-                pltpu.VMEM((2, t_packing, bd1 // t_packing, bf), w1.dtype),
-                # b_w3_x2_vmem
-                pltpu.VMEM((2, t_packing, bd1 // t_packing, bf), w1.dtype),
-                # b_w2_x2_vmem
-                pltpu.VMEM((2, t_packing, bf, bd2 // t_packing), w2.dtype),
-                # b_acc_vmem
-                pltpu.VMEM((bt * num_devices, 1, bf * 2), jnp.float32),
-                # local_sems
-                pltpu.SemaphoreType.DMA((2, 5)),
-                # send_sems
-                pltpu.SemaphoreType.DMA((2, )),
-                # recv_sems
-                pltpu.SemaphoreType.DMA((2, )),
-                # a2a_gather_sem
-                pltpu.SemaphoreType.DMA,
-                # a2a_acc_sem
-                pltpu.SemaphoreType.DMA,
-            ]),
-        ),
-        compiler_params=pltpu.CompilerParams(
-            collective_id=0,
-            vmem_limit_bytes=100 * 1024 * 1024,
-        ),
-        name=scope_name,
-    ))
+                    t_dtype,
+                ),
+                # a2a_g_acc_vmem
+                pltpu.VMEM((top_k, bt, t_packing, hidden_size // t_packing),
+                           t_dtype),
+                # b_gating_x2_vmem
+                pltpu.VMEM((2, bt, padded_num_experts), t_dtype),
+                # b_output_x2_vmem
+                pltpu.VMEM((2, bt, hidden_size), t_dtype),
+                # b_w1_x2_vmem
+                pltpu.VMEM((2, t_packing, bd1 // t_packing, bf), w1.dtype),
+                # b_w3_x2_vmem
+                pltpu.VMEM((2, t_packing, bd1 // t_packing, bf), w1.dtype),
+                # b_w2_x2_vmem
+                pltpu.VMEM((2, t_packing, bf, bd2 // t_packing), w2.dtype),
+                # b_w1_scale_x2_vmem
+                (None if w1_scale is None else pltpu.VMEM(
+                    (
+                        2,
+                        t_packing,
+                        bd1 // t_packing // subc_quant_wsz,
+                        1,
+                        bf,
+                    ),
+                    jnp.float32,
+                )),
+                # b_w3_scale_x2_vmem
+                (None if w1_scale is None else pltpu.VMEM(
+                    (
+                        2,
+                        t_packing,
+                        bd1 // t_packing // subc_quant_wsz,
+                        1,
+                        bf,
+                    ),
+                    jnp.float32,
+                )),
+                # b_w2_scale_x2_vmem
+                (None if w2_scale is None else pltpu.VMEM(
+                    (
+                        2,
+                        t_packing,
+                        bf // subc_quant_wsz,
+                        1,
+                        bd2 // t_packing,
+                    ),
+                    jnp.float32,
+                )),
+                # b_b1_x2_vmem
+                (None if b1 is None else pltpu.VMEM(
+                    (
+                        2,
+                        1,
+                        bf,
+                    ),
+                    jnp.float32,
+                )),
+                # b_b3_x2_vmem
+                (None if b1 is None else pltpu.VMEM(
+                    (
+                        2,
+                        1,
+                        bf,
+                    ),
+                    jnp.float32,
+                )),
+                # b_b2_x2_vmem
+                (None if b2 is None else pltpu.VMEM(
+                    (
+                        2,
+                        t_packing,
+                        1,
+                        bd2 // t_packing,
+                    ),
+                    jnp.float32,
+                )),
+                # b_acc_vmem
+                pltpu.VMEM((bt * num_devices, 1, bf * 2), jnp.float32),
+                # local_sems
+                pltpu.SemaphoreType.DMA((2, 5)),
+                # send_sems
+                pltpu.SemaphoreType.DMA((2, )),
+                # recv_sems
+                pltpu.SemaphoreType.DMA((2, )),
+                # a2a_gather_sem
+                pltpu.SemaphoreType.DMA,
+                # a2a_acc_sem
+                pltpu.SemaphoreType.DMA,
+            ]),
+        ),
+        compiler_params=pltpu.CompilerParams(
+            collective_id=0,
+            vmem_limit_bytes=100 * 1024 * 1024,
+        ),
+        name=scope_name,
+    )
 
     @jax.jit
-    @functools.partial(
-        shard_map.shard_map,
+    @jax.shard_map(
         mesh=mesh,
-        in_specs=(P(ep_axis_name), P(ep_axis_name), P(ep_axis_name),
-                  P(ep_axis_name), P()),
+        in_specs=(
+            P(ep_axis_name),  # tokens_hbm
+            P(ep_axis_name),  # w1_hbm
+            P(ep_axis_name),  # w2_hbm
+            None if w1_scale is None else P(ep_axis_name),  # w1_scale_hbm
+            None if w2_scale is None else P(ep_axis_name),  # w2_scale_hbm
+            None if b1 is None else P(ep_axis_name),  # b1_hbm
+            None if b2 is None else P(ep_axis_name),  # b2_hbm
+            P(ep_axis_name),  # gating_output_hbm
+            P(),  # a2a_g_hbm
+        ),
         out_specs=P(ep_axis_name),
-        check_rep=False,
+        check_vma=False,
     )
-    def kernel(tokens, w1, w2, gating_output, a2a_g_hbm_scratch):
+    def kernel(
+        tokens,
+        w1,
+        w2,
+        w1_scale,
+        w2_scale,
+        b1,
+        b2,
+        gating_output,
+        a2a_g_hbm_scratch,
+    ):
         return fused_moe(
-            pltpu.with_memory_space_constraint(tokens, pltpu.HBM),
-            pltpu.with_memory_space_constraint(w1, pltpu.HBM),
-            pltpu.with_memory_space_constraint(w2, pltpu.HBM),
-            pltpu.with_memory_space_constraint(gating_output, pltpu.HBM),
-            pltpu.with_memory_space_constraint(a2a_g_hbm_scratch, pltpu.HBM),
+            pltpu.with_memory_space_constraint(tokens,
+                                               pltpu.HBM),  # tokens_hbm
+            pltpu.with_memory_space_constraint(w1, pltpu.HBM),  # w1_hbm
+            pltpu.with_memory_space_constraint(w2, pltpu.HBM),  # w2_hbm
+            (None if w1_scale is None else pltpu.with_memory_space_constraint(
+                w1_scale, pltpu.HBM)),  # w1_scale_hbm
+            (None if w2_scale is None else pltpu.with_memory_space_constraint(
+                w2_scale, pltpu.HBM)),  # w2_scale_hbm
+            (None if b1 is None else pltpu.with_memory_space_constraint(
+                b1, pltpu.HBM)),  # b1_hbm
+            (None if b2 is None else pltpu.with_memory_space_constraint(
+                b2, pltpu.HBM)),  # b2_hbm
+            pltpu.with_memory_space_constraint(gating_output,
+                                               pltpu.HBM),  # gating_output_hbm
+            pltpu.with_memory_space_constraint(a2a_g_hbm_scratch,
+                                               pltpu.HBM),  # a2a_g_hbm
        )
 
     a2a_g_hbm_scratch = pl.empty(
         (num_experts, bt, t_packing, hidden_size // t_packing), t_dtype)
-    results = kernel(
+    return kernel(
         tokens,
         w1,
         w2,
+        w1_scale,
+        w2_scale,
+        b1,
+        b2,
         gating_output,
         a2a_g_hbm_scratch,
     )
-    return results[:, :actual_hidden_size]
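
The rewritten wrapper threads the optional scales and biases through jax.shard_map, putting a None entry in in_specs wherever the operand itself is None (a None argument is an empty pytree, so it takes no PartitionSpec). A standalone sketch of that pattern with a toy body rather than the MoE kernel, assuming a JAX version that exposes jax.shard_map, as the diff now requires:

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.sharding import Mesh, PartitionSpec as P

    # 1-D device mesh over an "ep" axis, mirroring the ep_axis_name usage above.
    mesh = Mesh(np.array(jax.devices()), ("ep",))

    def run(x, scale):
        @jax.shard_map(
            mesh=mesh,
            in_specs=(P("ep"), None if scale is None else P("ep")),
            out_specs=P("ep"),
        )
        def body(x, scale):
            # The optional operand simply never enters the computation when absent.
            return x if scale is None else x * scale

        return body(x, scale)

    x = jnp.ones((len(jax.devices()) * 4, 8))
    print(run(x, None).shape)                   # unscaled path
    print(run(x, jnp.full_like(x, 0.5)).shape)  # scaled path
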