tpu-inference 0.12.0.dev20251222__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (260)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +67 -0
  6. tests/core/test_dp_scheduler.py +724 -0
  7. tests/core/test_init.py +63 -0
  8. tests/distributed/__init__.py +13 -0
  9. tests/distributed/test_distributed_utils.py +120 -0
  10. tests/distributed/test_tpu_connector.py +478 -0
  11. tests/e2e/__init__.py +13 -0
  12. tests/e2e/test_async_scheduler.py +211 -0
  13. tests/e2e/test_data_parallel.py +393 -0
  14. tests/e2e/test_local_disagg.py +257 -0
  15. tests/e2e/test_model_loader.py +268 -0
  16. tests/e2e/test_multi_modal_inference.py +111 -0
  17. tests/e2e/test_pipeline_parallel.py +265 -0
  18. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  19. tests/e2e/test_sampling_params.py +269 -0
  20. tests/e2e/test_speculative_decoding.py +291 -0
  21. tests/e2e/test_structured_decoding.py +46 -0
  22. tests/executors/__init__.py +13 -0
  23. tests/executors/test_ray_distributed_executor.py +199 -0
  24. tests/experimental/__init__.py +13 -0
  25. tests/experimental/test_llama3_jax_stashed.py +208 -0
  26. tests/kernels/__init__.py +13 -0
  27. tests/kernels/collectives/__init__.py +13 -0
  28. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  29. tests/kernels/fused_moe_v1_test.py +388 -0
  30. tests/kernels/gmm_test.py +205 -0
  31. tests/kernels/mla_v1_test.py +498 -0
  32. tests/kernels/quantized_matmul_kernel_test.py +159 -0
  33. tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
  34. tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
  35. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
  36. tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
  37. tests/layers/__init__.py +13 -0
  38. tests/layers/common/__init__.py +13 -0
  39. tests/layers/common/test_attention_interface.py +156 -0
  40. tests/layers/common/test_quantization.py +149 -0
  41. tests/layers/jax/__init__.py +13 -0
  42. tests/layers/jax/attention/__init__.py +13 -0
  43. tests/layers/jax/attention/test_common_attention.py +103 -0
  44. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  45. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  46. tests/layers/jax/moe/__init__.py +13 -0
  47. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  48. tests/layers/jax/sample/__init__.py +13 -0
  49. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  50. tests/layers/jax/sample/test_sampling.py +115 -0
  51. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  52. tests/layers/jax/test_layers.py +155 -0
  53. tests/layers/jax/test_qwix.py +969 -0
  54. tests/layers/jax/test_rope.py +93 -0
  55. tests/layers/jax/test_sharding.py +159 -0
  56. tests/layers/jax/test_transformer_block.py +152 -0
  57. tests/layers/vllm/__init__.py +13 -0
  58. tests/layers/vllm/test_attention.py +363 -0
  59. tests/layers/vllm/test_awq.py +405 -0
  60. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
  62. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
  63. tests/layers/vllm/test_fp8.py +17 -0
  64. tests/layers/vllm/test_mxfp4.py +297 -0
  65. tests/layers/vllm/test_unquantized.py +621 -0
  66. tests/layers/vllm/utils.py +72 -0
  67. tests/lora/__init__.py +13 -0
  68. tests/lora/conftest.py +46 -0
  69. tests/lora/test_bgmv.py +57 -0
  70. tests/lora/test_layers.py +666 -0
  71. tests/lora/test_lora.py +147 -0
  72. tests/lora/test_lora_perf.py +67 -0
  73. tests/lora/utils.py +88 -0
  74. tests/models/__init__.py +13 -0
  75. tests/models/common/__init__.py +13 -0
  76. tests/models/common/test_model_loader.py +455 -0
  77. tests/models/jax/__init__.py +13 -0
  78. tests/models/jax/test_deepseek_v3.py +401 -0
  79. tests/models/jax/test_llama3.py +184 -0
  80. tests/models/jax/test_llama4.py +298 -0
  81. tests/models/jax/test_llama_eagle3.py +197 -0
  82. tests/models/jax/test_llama_guard_4.py +242 -0
  83. tests/models/jax/test_qwen2.py +172 -0
  84. tests/models/jax/test_qwen2_5_vl.py +606 -0
  85. tests/models/jax/test_qwen3.py +169 -0
  86. tests/models/jax/test_weight_loading.py +180 -0
  87. tests/models/jax/utils/__init__.py +13 -0
  88. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  89. tests/platforms/__init__.py +13 -0
  90. tests/platforms/test_tpu_platform.py +54 -0
  91. tests/runner/__init__.py +13 -0
  92. tests/runner/test_block_table.py +395 -0
  93. tests/runner/test_input_batch.py +226 -0
  94. tests/runner/test_kv_cache.py +220 -0
  95. tests/runner/test_kv_cache_manager.py +498 -0
  96. tests/runner/test_multimodal_manager.py +429 -0
  97. tests/runner/test_persistent_batch_manager.py +84 -0
  98. tests/runner/test_speculative_decoding_manager.py +368 -0
  99. tests/runner/test_structured_decoding_manager.py +220 -0
  100. tests/runner/test_tpu_runner.py +202 -0
  101. tests/runner/test_tpu_runner_dp.py +1033 -0
  102. tests/runner/test_tpu_runner_mesh.py +200 -0
  103. tests/runner/test_utils.py +411 -0
  104. tests/spec_decode/__init__.py +13 -0
  105. tests/spec_decode/test_eagle3.py +311 -0
  106. tests/test_base.py +215 -0
  107. tests/test_envs.py +280 -0
  108. tests/test_tpu_info.py +134 -0
  109. tests/test_utils.py +193 -0
  110. tests/worker/__init__.py +13 -0
  111. tests/worker/tpu_worker_test.py +414 -0
  112. tpu_inference/__init__.py +67 -0
  113. tpu_inference/core/__init__.py +13 -0
  114. tpu_inference/core/core_tpu.py +786 -0
  115. tpu_inference/core/disagg_executor.py +118 -0
  116. tpu_inference/core/disagg_utils.py +49 -0
  117. tpu_inference/core/sched/__init__.py +13 -0
  118. tpu_inference/core/sched/dp_scheduler.py +814 -0
  119. tpu_inference/distributed/__init__.py +13 -0
  120. tpu_inference/distributed/jax_parallel_state.py +81 -0
  121. tpu_inference/distributed/tpu_connector.py +732 -0
  122. tpu_inference/distributed/utils.py +112 -0
  123. tpu_inference/env_override.py +9 -0
  124. tpu_inference/envs.py +191 -0
  125. tpu_inference/executors/__init__.py +13 -0
  126. tpu_inference/executors/ray_distributed_executor.py +399 -0
  127. tpu_inference/experimental/__init__.py +13 -0
  128. tpu_inference/experimental/llama3_jax_stashed.py +272 -0
  129. tpu_inference/kernels/__init__.py +13 -0
  130. tpu_inference/kernels/collectives/__init__.py +13 -0
  131. tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
  132. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
  133. tpu_inference/kernels/collectives/util.py +47 -0
  134. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  135. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  136. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  137. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  138. tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
  139. tpu_inference/kernels/megablox/__init__.py +13 -0
  140. tpu_inference/kernels/megablox/common.py +54 -0
  141. tpu_inference/kernels/megablox/gmm.py +646 -0
  142. tpu_inference/kernels/mla/__init__.py +13 -0
  143. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  144. tpu_inference/kernels/mla/v1/kernel.py +1340 -0
  145. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  146. tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
  147. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  148. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  149. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  150. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  151. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
  152. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
  153. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  154. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  155. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
  156. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
  157. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
  158. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
  159. tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
  160. tpu_inference/layers/__init__.py +13 -0
  161. tpu_inference/layers/common/__init__.py +13 -0
  162. tpu_inference/layers/common/attention_interface.py +403 -0
  163. tpu_inference/layers/common/attention_metadata.py +48 -0
  164. tpu_inference/layers/common/binary_search.py +295 -0
  165. tpu_inference/layers/common/quant_methods.py +23 -0
  166. tpu_inference/layers/common/quantization.py +270 -0
  167. tpu_inference/layers/common/sharding.py +600 -0
  168. tpu_inference/layers/jax/__init__.py +13 -0
  169. tpu_inference/layers/jax/attention/__init__.py +13 -0
  170. tpu_inference/layers/jax/attention/attention.py +268 -0
  171. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
  172. tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
  173. tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
  174. tpu_inference/layers/jax/base.py +165 -0
  175. tpu_inference/layers/jax/constants.py +101 -0
  176. tpu_inference/layers/jax/layers.py +315 -0
  177. tpu_inference/layers/jax/misc.py +30 -0
  178. tpu_inference/layers/jax/moe/__init__.py +13 -0
  179. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
  180. tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
  181. tpu_inference/layers/jax/moe/moe.py +249 -0
  182. tpu_inference/layers/jax/pp_utils.py +53 -0
  183. tpu_inference/layers/jax/rope.py +294 -0
  184. tpu_inference/layers/jax/rope_interface.py +228 -0
  185. tpu_inference/layers/jax/sample/__init__.py +13 -0
  186. tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
  187. tpu_inference/layers/jax/sample/sampling.py +110 -0
  188. tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
  189. tpu_inference/layers/jax/transformer_block.py +121 -0
  190. tpu_inference/layers/vllm/__init__.py +13 -0
  191. tpu_inference/layers/vllm/attention.py +221 -0
  192. tpu_inference/layers/vllm/fused_moe.py +502 -0
  193. tpu_inference/layers/vllm/linear_common.py +221 -0
  194. tpu_inference/layers/vllm/quantization/__init__.py +55 -0
  195. tpu_inference/layers/vllm/quantization/awq.py +221 -0
  196. tpu_inference/layers/vllm/quantization/common.py +124 -0
  197. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  198. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
  199. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
  200. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  201. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
  202. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
  203. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  204. tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
  205. tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
  206. tpu_inference/layers/vllm/sharding.py +244 -0
  207. tpu_inference/logger.py +10 -0
  208. tpu_inference/lora/__init__.py +13 -0
  209. tpu_inference/lora/torch_lora_ops.py +98 -0
  210. tpu_inference/lora/torch_punica_tpu.py +310 -0
  211. tpu_inference/models/__init__.py +13 -0
  212. tpu_inference/models/common/__init__.py +13 -0
  213. tpu_inference/models/common/model_loader.py +520 -0
  214. tpu_inference/models/jax/__init__.py +13 -0
  215. tpu_inference/models/jax/deepseek_v3.py +978 -0
  216. tpu_inference/models/jax/gpt_oss.py +508 -0
  217. tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
  218. tpu_inference/models/jax/llama3.py +436 -0
  219. tpu_inference/models/jax/llama4.py +643 -0
  220. tpu_inference/models/jax/llama_eagle3.py +350 -0
  221. tpu_inference/models/jax/llama_guard_4.py +375 -0
  222. tpu_inference/models/jax/qwen2.py +390 -0
  223. tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
  224. tpu_inference/models/jax/qwen3.py +318 -0
  225. tpu_inference/models/jax/utils/__init__.py +13 -0
  226. tpu_inference/models/jax/utils/file_utils.py +110 -0
  227. tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
  228. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  229. tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
  230. tpu_inference/models/jax/utils/weight_utils.py +621 -0
  231. tpu_inference/models/vllm/__init__.py +13 -0
  232. tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
  233. tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
  234. tpu_inference/platforms/__init__.py +16 -0
  235. tpu_inference/platforms/tpu_platform.py +258 -0
  236. tpu_inference/runner/__init__.py +13 -0
  237. tpu_inference/runner/block_table.py +122 -0
  238. tpu_inference/runner/compilation_manager.py +890 -0
  239. tpu_inference/runner/input_batch.py +435 -0
  240. tpu_inference/runner/kv_cache.py +166 -0
  241. tpu_inference/runner/kv_cache_manager.py +508 -0
  242. tpu_inference/runner/lora_utils.py +106 -0
  243. tpu_inference/runner/multimodal_manager.py +231 -0
  244. tpu_inference/runner/persistent_batch_manager.py +296 -0
  245. tpu_inference/runner/speculative_decoding_manager.py +262 -0
  246. tpu_inference/runner/structured_decoding_manager.py +101 -0
  247. tpu_inference/runner/tpu_runner.py +1768 -0
  248. tpu_inference/runner/utils.py +426 -0
  249. tpu_inference/spec_decode/__init__.py +13 -0
  250. tpu_inference/spec_decode/jax/__init__.py +13 -0
  251. tpu_inference/spec_decode/jax/eagle3.py +430 -0
  252. tpu_inference/tpu_info.py +92 -0
  253. tpu_inference/utils.py +345 -0
  254. tpu_inference/worker/__init__.py +13 -0
  255. tpu_inference/worker/tpu_worker.py +468 -0
  256. tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
  257. tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
  258. tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
  259. tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
  260. tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
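
The hunk below appears to correspond to entry 138 above, tpu_inference/kernels/fused_moe/v1/kernel.py (+1612 lines): a Pallas TPU kernel for expert-parallel fused MoE, together with a pure-JAX reference implementation (ref_moe) and a jitted entry point (fused_ep_moe) that expects a 2D device mesh and explicit block-size tuning parameters. As a quick orientation, here is a minimal sketch (not part of the package; shapes, dtypes, and values are illustrative assumptions) that exercises only the reference implementation defined in the diff:

    import jax
    import jax.numpy as jnp
    from tpu_inference.kernels.fused_moe.v1.kernel import ref_moe

    # Illustrative (assumed) sizes; the package does not prescribe these.
    num_tokens, hidden_size, intermediate_size = 8, 256, 512
    num_experts, top_k = 4, 2

    key = jax.random.PRNGKey(0)
    k1, k2, k3, k4 = jax.random.split(key, 4)
    tokens = jax.random.normal(k1, (num_tokens, hidden_size), dtype=jnp.bfloat16)
    # w1 stacks the gate and up projections along axis 1; w2 is the down projection.
    w1 = jax.random.normal(k2, (num_experts, 2, hidden_size, intermediate_size), dtype=jnp.bfloat16)
    w2 = jax.random.normal(k3, (num_experts, intermediate_size, hidden_size), dtype=jnp.bfloat16)
    gating = jax.random.normal(k4, (num_tokens, num_experts), dtype=jnp.float32)

    # ref_moe loops over tokens and their top-k experts, so it serves as a
    # correctness baseline for the fused kernel, not as a fast path.
    out = ref_moe(tokens, w1, w2, gating, top_k, renormalize_topk_logits=True)
    assert out.shape == (num_tokens, hidden_size)
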
@@ -0,0 +1,1612 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """TPU-Friendly Fused Mixture of Experts (MoE) kernel."""
15
+
16
+ import functools
17
+
18
+ import jax
19
+ import jax.numpy as jnp
20
+ from jax import lax
21
+ from jax._src import dtypes
22
+ from jax.experimental import pallas as pl
23
+ from jax.experimental.pallas import tpu as pltpu
24
+
25
+ P = jax.sharding.PartitionSpec
26
+
27
+ cdiv = pl.cdiv
28
+
29
+
30
+ def align_to(x, a):
31
+ return cdiv(x, a) * a
32
+
33
+
34
+ def get_dtype_packing(dtype):
35
+ bits = (dtypes.bit_width(dtype)
36
+ if hasattr(dtypes, "bit_width") else dtypes.itemsize_bits(dtype))
37
+ return 32 // bits
38
+
39
+
40
+ def broadcast_minor(src, shape):
41
+ if src.shape == shape:
42
+ return src
43
+ assert src.shape[:-1] == shape[:-1]
44
+ assert src.shape[-1] % 128 == 0
45
+ target_minor = align_to(shape[-1], src.shape[-1])
46
+ # no-op concatenation.
47
+ return jnp.concatenate([src for _ in range(target_minor // src.shape[-1])],
48
+ axis=-1)[..., :shape[-1]]
49
+
50
+
51
+ def swigluoai(gate: jax.Array,
52
+ up: jax.Array,
53
+ *,
54
+ alpha: float = 1.702,
55
+ limit: float = 7.0) -> jax.Array:
56
+ """Activation used in some models such as GPT-OSS."""
57
+ gate = jnp.clip(gate, a_max=limit)
58
+ up = jnp.clip(up, a_min=-limit, a_max=limit)
59
+ glu = gate * jax.nn.sigmoid(alpha * gate)
60
+ return (up + 1.0) * glu
61
+
62
+
63
+ def activation_fn(acc1, acc3, act_fn):
64
+ if act_fn == "silu":
65
+ return jax.nn.silu(acc1) * acc3
66
+ elif act_fn == "gelu":
67
+ return jax.nn.gelu(acc1) * acc3
68
+ elif act_fn == "swigluoai":
69
+ return swigluoai(acc1, acc3)
70
+ else:
71
+ raise RuntimeError(f"Unsupported activation function: {act_fn}")
72
+
73
+
74
+ def ref_moe(
75
+ tokens: jax.Array, # (num_tokens, hidden_size)
76
+ w1: jax.Array, # (num_experts, 2, hidden_size, intermediate_size)
77
+ w2: jax.Array, # (num_experts, intermediate_size, hidden_size)
78
+ gating_output: jax.Array, # (num_tokens, num_experts)
79
+ top_k: int,
80
+ *,
81
+ renormalize_topk_logits: bool = False,
82
+ act_fn: str = "silu",
83
+ subc_quant_wsz: int | None = None,
84
+ w1_scale:
85
+ (
86
+ jax.Array | None
87
+ ) = None, # F32(num_experts, 2, hidden_size //subc_quant_wsz, 1, intermediate_size)
88
+ w2_scale:
89
+ (
90
+ jax.Array | None
91
+ ) = None, # F32(num_experts, intermediate_size // subc_quant_wsz, 1, hidden_size)
92
+ b1: jax.Array
93
+ | None = None, # F32(num_experts, 2, 1, intermediate_size)
94
+ b2: jax.Array | None = None, # F32(num_experts, 1, hidden_size)
95
+ ):
96
+ n_tokens = tokens.shape[0] # num_tokens
97
+
98
+ # Compute gating scores for all experts
99
+ gating_logits = jax.nn.softmax(gating_output,
100
+ axis=-1) # [num_tokens, n_experts]
101
+
102
+ # Select top-k experts per token
103
+ top_k_logits, top_k_indices = lax.top_k(
104
+ gating_logits, top_k) # [num_tokens, top_k], [num_tokens, top_k]
105
+
106
+ if renormalize_topk_logits:
107
+ top_k_logits = top_k_logits / jnp.sum(
108
+ top_k_logits, axis=-1, keepdims=True)
109
+
110
+ t_outputs = []
111
+ hidden_size, intermediate_size = w1.shape[-2:]
112
+
113
+ # Process each token individually
114
+ for i in range(n_tokens):
115
+ curr_token = jnp.expand_dims(tokens[i], axis=0) # [1, hidden_size]
116
+ assigned_expert_ids = top_k_indices[
117
+ i] # [top_k] - indices of selected experts for token i
118
+ tok_expert_act = []
119
+
120
+ # Process each selected expert for the current token
121
+ for expert_id in assigned_expert_ids:
122
+ # Get expert weights
123
+ expert_w1 = w1[expert_id, 0].astype(jnp.float32)
124
+ expert_w3 = w1[expert_id, 1].astype(jnp.float32)
125
+ if w1_scale is not None:
126
+ expert_w1 *= jnp.repeat(w1_scale[expert_id, 0, :, 0],
127
+ subc_quant_wsz,
128
+ axis=0)[:hidden_size]
129
+ expert_w3 *= jnp.repeat(w1_scale[expert_id, 1, :, 0],
130
+ subc_quant_wsz,
131
+ axis=0)[:hidden_size]
132
+ expert_weight_1 = jnp.concat(
133
+ [expert_w1, expert_w3],
134
+ axis=-1) # [hidden_size, 2 * intermediate_size]
135
+ expert_weight_2 = w2[expert_id].astype(
136
+ jnp.float32) # [intermediate_size, hidden_size]
137
+ if w2_scale is not None:
138
+ expert_weight_2 *= jnp.repeat(w2_scale[expert_id, :, 0],
139
+ subc_quant_wsz,
140
+ axis=0)[:intermediate_size]
141
+
142
+ # First linear layer with SwiGLU activation
143
+ gmm_1_out = curr_token @ expert_weight_1 # [1, 2 * intermediate_size]
144
+
145
+ # Split into gate and up projections for SwiGLU
146
+ gmm1_w1_proj, gmm1_w3_proj = jnp.split(
147
+ gmm_1_out, 2,
148
+ axis=-1) # [1, intermediate_size], [1, intermediate_size]
149
+ if b1 is not None:
150
+ gmm1_w1_proj += b1[expert_id:expert_id + 1, 0, 0]
151
+ gmm1_w3_proj += b1[expert_id:expert_id + 1, 1, 0]
152
+
153
+ # Apply gated activation: activation(gate) * up
154
+ act = activation_fn(gmm1_w1_proj, gmm1_w3_proj, act_fn)
155
+
156
+ # Second linear layer (down projection)
157
+ gmm_2_out = act @ expert_weight_2 # [1, hidden_size]
158
+ if b2 is not None:
159
+ gmm_2_out += b2[expert_id:expert_id + 1, 0]
160
+ tok_expert_act.append(gmm_2_out)
161
+
162
+ # Combine outputs from all selected experts
163
+ experts_act = jnp.concatenate(tok_expert_act,
164
+ axis=0) # [top_k, hidden_size]
165
+
166
+ # Weighted sum using top-k gating weights
167
+ top_k_weights = top_k_logits[i] # [top_k]
168
+ top_k_weights = jnp.expand_dims(top_k_weights, axis=1) # [top_k, 1]
169
+ weighted_output = jnp.sum(experts_act * top_k_weights,
170
+ axis=0,
171
+ keepdims=True) # [1, hidden_size]
172
+
173
+ t_outputs.append(weighted_output.astype(tokens.dtype))
174
+
175
+ return jnp.concatenate(t_outputs,
176
+ axis=0) # [actual_num_tokens, hidden_size]
177
+
178
+
179
+ def _fused_ep_moe_kernel(
180
+ # Input
181
+ tokens_hbm, # (local_num_tokens, t_packing, hidden_size // t_packing)
182
+ w1_hbm, # (local_num_experts, 2, hidden_size, intermediate_size)
183
+ w2_hbm, # (local_num_experts, intermediate_size, hidden_size)
184
+ # TODO(jevinjiang): We choose F32 scale for easier slicing. The extra
185
+ # latency should be hidden in the pipeline overlaping. But is there a better
186
+     # latency should be hidden in the pipeline overlapping. But is there a better
187
+ w1_scale_hbm, # None | F32(local_num_experts, 2, cdiv(hidden_size, subc_quant_wsz), 1, intermediate_size)
188
+     # way to do this?
189
+ b1_hbm, # None | F32(local_num_experts, 2, 1, intermediate_size)
190
+ b2_hbm, # None | F32(local_num_experts, 1, hidden_size)
191
+ gating_hbm, # (local_num_tokens, padded_num_experts)
192
+ a2a_g_hbm, # (num_experts, bt, t_packing, hidden_size // t_packing)
193
+ # Output
194
+ output_hbm, # (local_num_tokens, hidden_size)
195
+ # Scratch
196
+ t2e_routing_x2_smem, # <bt_sem_id> (2, bt, padded_top_k)
197
+ d2e_count_x2_smem, # <bt_sem_id> (2, num_devices, 1, padded_num_experts)
198
+ expert_offsets_x2_smem, # <bt_sem_id> (2, 2, padded_num_experts): for a2a_s and a2a_g
199
+ expert_starts_x2_smem, # <bt_sem_id> (2, 1, padded_num_experts)
200
+ expert_sizes_x2_smem, # <bt_sem_id> (2, 1, padded_num_experts)
201
+ a2a_s_sends_x2_smem, # <e_sem_id> (2,)
202
+ a2a_s_x2_vmem, # <e_sem_id> (2, bt * num_devices, t_packing, hidden_size // t_packing)
203
+ a2a_s_acc_x2_vmem, # <e_sem_id> (2, bt * num_devices, t_packing, hidden_size // t_packing)
204
+ ### Accumulation for gathered tokens:
205
+ a2a_g_acc_vmem, # (top_k, bt, t_packing, hidden_size // t_packing)
206
+ ### Expert weight double buffering:
207
+ b_gating_x2_vmem, # <bt_sem_id> (2, bt, padded_num_experts)
208
+ b_output_x2_vmem, # <bt_sem_id> (2, bt, hidden_size)
209
+ b_w1_x2_vmem, # <bw_sem_id> (2, t_packing, bd1 // t_packing, bf)
210
+ b_w3_x2_vmem, # <bw_sem_id> (2, t_packing, bd1 // t_packing, bf)
211
+ b_w2_x2_vmem, # <bw_sem_id> (2, t_packing, bf, bd2 // t_packing)
212
+ b_w1_scale_x2_vmem, # None | <bw_sem_id> (2, t_packing, bd1 // t_packing // subc_quant_wsz, 1, bf)
213
+ b_w3_scale_x2_vmem, # None | <bw_sem_id> (2, t_packing, bd1 // t_packing // subc_quant_wsz, 1, bf)
214
+ b_w2_scale_x2_vmem, # None | <bw_sem_id> (2, t_packing, bf // subc_quant_wsz, 1, bd2 // t_packing)
215
+ b_b1_x2_vmem, # None | <bw_sem_id> (2, 1, bf)
216
+ b_b3_x2_vmem, # None | <bw_sem_id> (2, 1, bf)
217
+ b_b2_x2_vmem, # None | <bw_sem_id> (2, t_packing, 1, bd2 // t_packing)
218
+ b_acc_vmem, # F32(bt * num_devices, 1, bf * 2)
219
+ ### Semaphores:
220
+ local_sems, # (2, 5): 2 x [b_gating_sem, b_w1_sem, b_w2_sem, b_w3_sem, b_output_sem]
221
+ send_sems, # <e_sem_id> (2,)
222
+ recv_sems, # <e_sem_id> (2,)
223
+ a2a_gather_sem,
224
+ a2a_acc_sem,
225
+ *,
226
+ top_k: int,
227
+ renormalize_topk_logits: bool,
228
+ ep_axis_name: str,
229
+ act_fn: str,
230
+ subc_quant_wsz: int | None = None,
231
+ # Kernel tuning params.
232
+ bt: int, # Block size of local_num_tokens.
233
+ bf: int, # Block size of intermediate_size.
234
+ bd1: int, # Block size of hidden_size in w1.
235
+ bd2: int, # Block size of hidden_size in w2.
236
+ btc: int, # Compute size of block tokens for active expert.
237
+ bfc: int, # Compute size of block intermediate_size.
238
+ bd1c: int, # Compute size of block hidden_size.
239
+ bd2c: int, # Compute size of block hidden_size.
240
+ ):
241
+ my_id = lax.axis_index(ep_axis_name)
242
+ num_devices = lax.axis_size(ep_axis_name)
243
+ local_num_tokens = tokens_hbm.shape[0]
244
+ local_num_experts, intermediate_size, hidden_size = w2_hbm.shape
245
+ right_id = (my_id + 1) % num_devices
246
+ num_experts = a2a_g_hbm.shape[0]
247
+ padded_num_experts = d2e_count_x2_smem.shape[-1]
248
+ padded_top_k = t2e_routing_x2_smem.shape[-1]
249
+ assert padded_num_experts == align_to(num_experts, 128)
250
+ assert padded_top_k == align_to(top_k, 128)
251
+
252
+ t_dtype = tokens_hbm.dtype
253
+ t_packing = get_dtype_packing(t_dtype)
254
+ t_bitwidth = 32 // t_packing
255
+ assert a2a_g_hbm.dtype == t_dtype
256
+ assert w1_hbm.dtype == w2_hbm.dtype
257
+
258
+ assert bd1 % bd1c == 0
259
+ assert bd2 % bd2c == 0
260
+ assert bf % bfc == 0
261
+ assert hidden_size % t_packing == 0
262
+ assert bd1 % t_packing == 0
263
+ assert bd2 % t_packing == 0
264
+ assert bd1c % t_packing == 0
265
+ assert bd2c % t_packing == 0
266
+
267
+ h_per_t_packing = hidden_size // t_packing
268
+ assert tokens_hbm.shape[-1] == h_per_t_packing
269
+ bd1_per_t_packing = bd1 // t_packing
270
+ bd2_per_t_packing = bd2 // t_packing
271
+ bd1c_per_t_packing = bd1c // t_packing
272
+ bd2c_per_t_packing = bd2c // t_packing
273
+
274
+ if subc_quant_wsz is not None:
275
+ assert subc_quant_wsz % 256 == 0
276
+ assert bd1c_per_t_packing == subc_quant_wsz
277
+ assert bfc == subc_quant_wsz
278
+ assert bd1 % subc_quant_wsz == 0
279
+ assert bf % subc_quant_wsz == 0
280
+ assert bd1_per_t_packing % subc_quant_wsz == 0
281
+ assert h_per_t_packing % subc_quant_wsz == 0
282
+
283
+ num_bt = cdiv(local_num_tokens, bt)
284
+ num_bf = cdiv(intermediate_size, bf)
285
+ num_bd1 = cdiv(hidden_size, bd1)
286
+ num_bd2 = cdiv(hidden_size, bd2)
287
+
288
+ def get_mesh_device_id(ep_rank):
289
+ dp_rank = jax.lax.axis_index("data")
290
+ return (dp_rank, ep_rank)
291
+
292
+ def sync_barrier():
293
+ barrier_sem = pltpu.get_barrier_semaphore()
294
+ pltpu.semaphore_signal(
295
+ barrier_sem,
296
+ device_id=get_mesh_device_id(right_id),
297
+ device_id_type=pltpu.DeviceIdType.MESH,
298
+ )
299
+ pltpu.semaphore_wait(barrier_sem, 1)
300
+
301
+ def start_fetch_b_gating(bt_id, priority=0):
302
+ is_valid = jnp.logical_and(0 <= bt_id, bt_id < num_bt)
303
+ sz = pl.multiple_of(lax.select(is_valid, bt, 0), bt)
304
+ bt_sem_id = (bt_id + 2) % 2
305
+ b_gating_sem = local_sems.at[bt_sem_id, 0]
306
+ pltpu.make_async_copy(
307
+ src_ref=gating_hbm.at[pl.ds(bt_id * bt, sz)],
308
+ dst_ref=b_gating_x2_vmem.at[bt_sem_id, pl.ds(0, sz)],
309
+ sem=b_gating_sem,
310
+ ).start(priority=priority)
311
+
312
+ def wait_fetch_b_gating(bt_id):
313
+ bt_sem_id = bt_id % 2
314
+ b_gating_sem = local_sems.at[bt_sem_id, 0]
315
+ pltpu.make_async_copy(
316
+ src_ref=b_gating_x2_vmem.at[bt_sem_id],
317
+ dst_ref=b_gating_x2_vmem.at[bt_sem_id],
318
+ sem=b_gating_sem,
319
+ ).wait()
320
+
321
+ def get_top_k(input, top_k, renormalize_topk_logits):
322
+ assert len(input.shape) == 2, input.shape
323
+ input = input.astype(jnp.float32)
324
+ padded_k_shape = (input.shape[0], padded_top_k)
325
+ top_k_logits_lst = []
326
+ top_k_indices_lst = []
327
+ t2e = jnp.zeros(input.shape, dtype=jnp.int32)
328
+ t2e_routing = jnp.zeros(padded_k_shape, dtype=jnp.int32)
329
+ iota = jax.lax.broadcasted_iota(jnp.int32, input.shape, 1)
330
+ padded_k_iota = jax.lax.broadcasted_iota(jnp.int32, padded_k_shape, 1)
331
+ top_k_logits_sum = jnp.zeros(padded_k_shape, jnp.float32)
332
+
333
+ for k_id in range(top_k):
334
+ # TODO(jevinjiang): return both top_k values and indices in Mosaic
335
+ top_k_logits = jnp.broadcast_to(
336
+ jnp.max(input[:, :num_experts], axis=1, keepdims=True),
337
+ padded_k_shape,
338
+ ).astype(input.dtype)
339
+ top_k_logits_lst.append(top_k_logits)
340
+ if renormalize_topk_logits:
341
+ top_k_logits_sum += top_k_logits
342
+ # TODO(jevinjiang): support bf16 argmax in Mosaic
343
+ top_k_indices = jnp.broadcast_to(
344
+ jnp.argmax(input[:, :num_experts], axis=1, keepdims=True),
345
+ padded_k_shape,
346
+ )
347
+ top_k_indices_lst.append(top_k_indices)
348
+ t2e_routing = jnp.where(padded_k_iota == k_id, top_k_indices,
349
+ t2e_routing)
350
+ mask = iota == broadcast_minor(top_k_indices, input.shape)
351
+ t2e += mask.astype(jnp.int32)
352
+ if k_id != top_k - 1:
353
+ input = jnp.where(mask, -jnp.inf, input)
354
+
355
+ if renormalize_topk_logits:
356
+ for k_id in range(top_k):
357
+ top_k_logits_lst[k_id] /= top_k_logits_sum
358
+
359
+ expert_sizes = jnp.sum(t2e, axis=0, keepdims=True)
360
+ expert_starts = jnp.zeros_like(expert_sizes)
361
+ return top_k_logits_lst, t2e_routing, expert_sizes, expert_starts
362
+
363
+ def all_reduce_metadata(bt_sem_id, t2e_routing, starts, sizes):
364
+ send_sem = send_sems.at[0]
365
+ recv_sem = recv_sems.at[0]
366
+
367
+ # All-reduce to accumulate starts and sizes and transfer to SMEM.
368
+ def _all_reduce_metadata(
369
+ t2e_routing_vmem,
370
+ d2e_count_vmem,
371
+ offsets_vmem,
372
+ starts_vmem,
373
+ sizes_vmem,
374
+ ):
375
+ offsets_vmem[...] = jnp.zeros_like(offsets_vmem)
376
+ # TODO(jevinjiang): check how slow is VMEM -> SMEM.
377
+ offsets_copy = pltpu.async_copy(
378
+ src_ref=offsets_vmem,
379
+ dst_ref=expert_offsets_x2_smem.at[bt_sem_id],
380
+ sem=send_sem,
381
+ )
382
+ t2e_routing_vmem[...] = t2e_routing
383
+ t2e_routing_copy = pltpu.async_copy(
384
+ src_ref=t2e_routing_vmem,
385
+ dst_ref=t2e_routing_x2_smem.at[bt_sem_id],
386
+ sem=send_sem,
387
+ )
388
+ reduced_sizes = sizes
389
+ reduced_starts = starts
390
+ row_id = my_id
391
+ d2e_count_vmem[row_id] = sizes
392
+ for i in range(num_devices - 1):
393
+ sync_barrier()
394
+ # TODO(jevinjiang): we can use double buffering to improve AR if needed.
395
+ pltpu.async_remote_copy(
396
+ src_ref=d2e_count_vmem.at[row_id],
397
+ dst_ref=d2e_count_vmem.at[row_id],
398
+ send_sem=send_sem,
399
+ recv_sem=recv_sem,
400
+ device_id=get_mesh_device_id(right_id),
401
+ device_id_type=pltpu.DeviceIdType.MESH,
402
+ ).wait()
403
+ row_id = (row_id + num_devices - 1) % num_devices
404
+ new_sizes = d2e_count_vmem[row_id]
405
+ reduced_sizes += new_sizes
406
+ reduced_starts += lax.select(my_id > i, new_sizes,
407
+ jnp.zeros_like(new_sizes))
408
+ starts_vmem[...] = reduced_starts
409
+ sizes_vmem[...] = reduced_sizes
410
+
411
+ starts_copy = pltpu.async_copy(
412
+ src_ref=starts_vmem,
413
+ dst_ref=expert_starts_x2_smem.at[bt_sem_id],
414
+ sem=send_sem,
415
+ )
416
+ sizes_copy = pltpu.async_copy(
417
+ src_ref=sizes_vmem,
418
+ dst_ref=expert_sizes_x2_smem.at[bt_sem_id],
419
+ sem=send_sem,
420
+ )
421
+
422
+ # TODO(jevinjiang): if d2e_count is too big, we can store in HBM and fetch
423
+ # to SMEM partially.
424
+ d2e_count_copy = pltpu.async_copy(
425
+ src_ref=d2e_count_vmem,
426
+ dst_ref=d2e_count_x2_smem.at[bt_sem_id],
427
+ sem=send_sem,
428
+ )
429
+
430
+ t2e_routing_copy.wait()
431
+ d2e_count_copy.wait()
432
+ offsets_copy.wait()
433
+ starts_copy.wait()
434
+ sizes_copy.wait()
435
+
436
+ pl.run_scoped(
437
+ _all_reduce_metadata,
438
+ pltpu.VMEM(t2e_routing_x2_smem.shape[1:],
439
+ t2e_routing_x2_smem.dtype),
440
+ pltpu.VMEM(d2e_count_x2_smem.shape[1:], d2e_count_x2_smem.dtype),
441
+ pltpu.VMEM(expert_offsets_x2_smem.shape[1:],
442
+ expert_offsets_x2_smem.dtype),
443
+ pltpu.VMEM(expert_starts_x2_smem.shape[1:],
444
+ expert_starts_x2_smem.dtype),
445
+ pltpu.VMEM(expert_sizes_x2_smem.shape[1:],
446
+ expert_sizes_x2_smem.dtype),
447
+ )
448
+
449
+ def start_a2a_scatter(bt_id, e_sem_id, local_e_id):
450
+ bt_sem_id = bt_id % 2
451
+
452
+ # Counting the number of remote sends from the current device.
453
+ send_sz = 0
454
+ for bt_t_id in range(bt):
455
+ for k_id in range(top_k):
456
+ e_id = t2e_routing_x2_smem[bt_sem_id, bt_t_id, k_id]
457
+ is_active_expert = e_id % local_num_experts == local_e_id
458
+ recv_id = e_id // local_num_experts
459
+ offset = expert_offsets_x2_smem[bt_sem_id, 0, e_id]
460
+ sz = lax.select(is_active_expert, 1, 0)
461
+ is_local = recv_id == my_id
462
+ local_sz = lax.select(is_local, sz, 0)
463
+ remote_sz = lax.select(is_local, 0, sz)
464
+ send_sz += remote_sz
465
+ expert_offsets_x2_smem[bt_sem_id, 0,
466
+ e_id] = (offset + local_sz + remote_sz)
467
+ start = expert_starts_x2_smem[bt_sem_id, 0, e_id] + offset
468
+ t_id = bt * bt_id + bt_t_id
469
+ # TODO(jevinjiang): compare the perf when using branches.
470
+ pltpu.make_async_copy(
471
+ src_ref=tokens_hbm.at[pl.ds(t_id, local_sz)],
472
+ dst_ref=a2a_s_x2_vmem.at[e_sem_id,
473
+ pl.ds(start, local_sz)],
474
+ sem=recv_sems.at[e_sem_id],
475
+ ).start()
476
+ pltpu.make_async_remote_copy(
477
+ src_ref=tokens_hbm.at[pl.ds(t_id, remote_sz)],
478
+ dst_ref=a2a_s_x2_vmem.at[e_sem_id,
479
+ pl.ds(start, remote_sz)],
480
+ send_sem=send_sems.at[e_sem_id],
481
+ recv_sem=recv_sems.at[e_sem_id],
482
+ device_id=get_mesh_device_id(recv_id),
483
+ device_id_type=pltpu.DeviceIdType.MESH,
484
+ ).start()
485
+ a2a_s_sends_x2_smem[e_sem_id] = send_sz
486
+
487
+ def wait_a2a_scatter_recv(bt_id, e_sem_id, local_e_id):
488
+ bt_sem_id = bt_id % 2
489
+ e_id = my_id * local_num_experts + local_e_id
490
+ sz = expert_sizes_x2_smem[bt_sem_id, 0, e_id]
491
+ pltpu.make_async_copy(
492
+ src_ref=a2a_s_x2_vmem.at[e_sem_id, pl.ds(0, sz)],
493
+ dst_ref=a2a_s_x2_vmem.at[e_sem_id, pl.ds(0, sz)],
494
+ sem=recv_sems.at[e_sem_id],
495
+ ).wait()
496
+
497
+ def wait_a2a_scatter_send(bt_id, e_sem_id, local_e_id):
498
+ del bt_id, local_e_id
499
+ sz = a2a_s_sends_x2_smem[e_sem_id]
500
+ pltpu.make_async_copy(
501
+ src_ref=a2a_s_x2_vmem.at[e_sem_id, pl.ds(0, sz)],
502
+ dst_ref=a2a_s_x2_vmem.at[e_sem_id, pl.ds(0, sz)],
503
+ sem=send_sems.at[e_sem_id],
504
+ ).wait()
505
+
506
+ def start_a2a_gather(bt_id, e_sem_id, local_e_id):
507
+ my_e_id = my_id * local_num_experts + local_e_id
508
+ bt_sem_id = bt_id % 2
509
+ start = 0
510
+ for recv_id in range(num_devices):
511
+ sz = d2e_count_x2_smem[bt_sem_id, recv_id, 0, my_e_id]
512
+ is_local = recv_id == my_id
513
+ local_sz = lax.select(is_local, sz, 0)
514
+ remote_sz = lax.select(is_local, 0, sz)
515
+ pltpu.make_async_copy(
516
+ src_ref=a2a_s_acc_x2_vmem.at[e_sem_id,
517
+ pl.ds(start, local_sz)],
518
+ dst_ref=a2a_g_hbm.at[my_e_id, pl.ds(0, local_sz)],
519
+ sem=a2a_gather_sem,
520
+ ).start()
521
+ pltpu.make_async_remote_copy(
522
+ src_ref=a2a_s_acc_x2_vmem.at[e_sem_id,
523
+ pl.ds(start, remote_sz)],
524
+ dst_ref=a2a_g_hbm.at[my_e_id, pl.ds(0, remote_sz)],
525
+ send_sem=send_sems.at[e_sem_id],
526
+ recv_sem=a2a_gather_sem,
527
+ device_id=get_mesh_device_id(recv_id),
528
+ device_id_type=pltpu.DeviceIdType.MESH,
529
+ ).start()
530
+ start += sz
531
+
532
+ def wait_a2a_gather_send(bt_id, e_sem_id, local_e_id):
533
+ my_e_id = my_id * local_num_experts + local_e_id
534
+ bt_sem_id = bt_id % 2
535
+ sz = expert_sizes_x2_smem[bt_sem_id, 0, my_e_id]
536
+ local_sz = d2e_count_x2_smem[bt_sem_id, my_id, 0, my_e_id]
537
+ remote_sz = sz - local_sz
538
+ is_valid = jnp.logical_and(0 <= local_e_id, local_e_id
539
+ < local_num_experts)
540
+ remote_sz = lax.select(is_valid, remote_sz, 0)
541
+ pltpu.make_async_copy(
542
+ src_ref=a2a_g_hbm.at[0, pl.ds(0, remote_sz)],
543
+ dst_ref=a2a_g_hbm.at[0, pl.ds(0, remote_sz)],
544
+ sem=send_sems.at[e_sem_id],
545
+ ).wait()
546
+
547
+ def wait_a2a_gather_recv_all():
548
+ sz = top_k * bt
549
+ pltpu.make_async_copy(
550
+ src_ref=a2a_g_hbm.at[0, pl.ds(0, sz)],
551
+ dst_ref=a2a_g_hbm.at[0, pl.ds(0, sz)],
552
+ sem=a2a_gather_sem,
553
+ ).wait()
554
+
555
+ def start_fetch_bw1(local_e_id, bw1_sem_id, bf_id, bd1_id):
556
+ for p in range(t_packing):
557
+ offset = p * h_per_t_packing + bd1_id * bd1_per_t_packing
558
+ pltpu.make_async_copy(
559
+ src_ref=w1_hbm.at[
560
+ local_e_id,
561
+ 0,
562
+ pl.ds(offset, bd1_per_t_packing),
563
+ pl.ds(bf_id * bf, bf),
564
+ ],
565
+ dst_ref=b_w1_x2_vmem.at[bw1_sem_id, p],
566
+ sem=local_sems.at[bw1_sem_id, 1],
567
+ ).start()
568
+ if w1_scale_hbm is not None:
569
+ assert subc_quant_wsz is not None
570
+ pltpu.make_async_copy(
571
+ src_ref=w1_scale_hbm.at[
572
+ local_e_id,
573
+ 0,
574
+ pl.ds(
575
+ offset // subc_quant_wsz,
576
+ bd1_per_t_packing // subc_quant_wsz,
577
+ ),
578
+ pl.ds(0, 1),
579
+ pl.ds(bf_id * bf, bf),
580
+ ],
581
+ dst_ref=b_w1_scale_x2_vmem.at[bw1_sem_id, p],
582
+ sem=local_sems.at[bw1_sem_id, 1],
583
+ ).start()
584
+ if b1_hbm is not None and bd1_id == 0:
585
+ pltpu.make_async_copy(
586
+ src_ref=b1_hbm.at[local_e_id, 0,
587
+ pl.ds(0, 1),
588
+ pl.ds(bf_id * bf, bf)],
589
+ dst_ref=b_b1_x2_vmem.at[bf_id % 2],
590
+ sem=local_sems.at[bw1_sem_id, 1],
591
+ ).start()
592
+
593
+ def start_fetch_bw2(local_e_id, bw2_sem_id, bf_id, bd2_id):
594
+ for p in range(t_packing):
595
+ offset = p * h_per_t_packing + bd2_id * bd2_per_t_packing
596
+ pltpu.make_async_copy(
597
+ src_ref=w2_hbm.at[
598
+ local_e_id,
599
+ pl.ds(bf_id * bf, bf),
600
+ pl.ds(offset, bd2_per_t_packing),
601
+ ],
602
+ dst_ref=b_w2_x2_vmem.at[bw2_sem_id, p],
603
+ sem=local_sems.at[bw2_sem_id, 2],
604
+ ).start()
605
+ if w2_scale_hbm is not None:
606
+ assert subc_quant_wsz is not None
607
+ pltpu.make_async_copy(
608
+ src_ref=w2_scale_hbm.at[
609
+ local_e_id,
610
+ pl.ds(bf_id * bf // subc_quant_wsz, bf //
611
+ subc_quant_wsz),
612
+ pl.ds(0, 1),
613
+ pl.ds(offset, bd2_per_t_packing),
614
+ ],
615
+ dst_ref=b_w2_scale_x2_vmem.at[bw2_sem_id, p],
616
+ sem=local_sems.at[bw2_sem_id, 2],
617
+ ).start()
618
+ if b2_hbm is not None and bf_id == 0:
619
+ pltpu.make_async_copy(
620
+ src_ref=b2_hbm.at[local_e_id,
621
+ pl.ds(0, 1),
622
+ pl.ds(offset, bd2_per_t_packing)],
623
+ dst_ref=b_b2_x2_vmem.at[bd2_id % 2, p],
624
+ sem=local_sems.at[bw2_sem_id, 2],
625
+ ).start()
626
+
627
+ def start_fetch_bw3(local_e_id, bw3_sem_id, bf_id, bd3_id):
628
+ for p in range(t_packing):
629
+ offset = p * h_per_t_packing + bd3_id * bd1_per_t_packing
630
+ pltpu.make_async_copy(
631
+ src_ref=w1_hbm.at[
632
+ local_e_id,
633
+ 1,
634
+ pl.ds(offset, bd1_per_t_packing),
635
+ pl.ds(bf_id * bf, bf),
636
+ ],
637
+ dst_ref=b_w3_x2_vmem.at[bw3_sem_id, p],
638
+ sem=local_sems.at[bw3_sem_id, 3],
639
+ ).start()
640
+ if w1_scale_hbm is not None:
641
+ assert subc_quant_wsz is not None
642
+ pltpu.make_async_copy(
643
+ src_ref=w1_scale_hbm.at[
644
+ local_e_id,
645
+ 1,
646
+ pl.ds(
647
+ offset // subc_quant_wsz,
648
+ bd1_per_t_packing // subc_quant_wsz,
649
+ ),
650
+ pl.ds(0, 1),
651
+ pl.ds(bf_id * bf, bf),
652
+ ],
653
+ dst_ref=b_w3_scale_x2_vmem.at[bw3_sem_id, p],
654
+ sem=local_sems.at[bw3_sem_id, 3],
655
+ ).start()
656
+ if b1_hbm is not None and bd3_id == 0:
657
+ pltpu.make_async_copy(
658
+ src_ref=b1_hbm.at[local_e_id, 1,
659
+ pl.ds(0, 1),
660
+ pl.ds(bf_id * bf, bf)],
661
+ dst_ref=b_b3_x2_vmem.at[bf_id % 2],
662
+ sem=local_sems.at[bw3_sem_id, 3],
663
+ ).start()
664
+
665
+ def wait_fetch_bw1(local_e_id, bw1_sem_id, bf_id, bd1_id):
666
+ del local_e_id
667
+ pltpu.make_async_copy(
668
+ src_ref=b_w1_x2_vmem.at[bw1_sem_id],
669
+ dst_ref=b_w1_x2_vmem.at[bw1_sem_id],
670
+ sem=local_sems.at[bw1_sem_id, 1],
671
+ ).wait()
672
+ if w1_scale_hbm is not None:
673
+ pltpu.make_async_copy(
674
+ src_ref=b_w1_scale_x2_vmem.at[bw1_sem_id],
675
+ dst_ref=b_w1_scale_x2_vmem.at[bw1_sem_id],
676
+ sem=local_sems.at[bw1_sem_id, 1],
677
+ ).wait()
678
+ if b1_hbm is not None and bd1_id == 0:
679
+ pltpu.make_async_copy(
680
+ src_ref=b_b1_x2_vmem.at[bf_id % 2],
681
+ dst_ref=b_b1_x2_vmem.at[bf_id % 2],
682
+ sem=local_sems.at[bw1_sem_id, 1],
683
+ ).wait()
684
+
685
+ def wait_fetch_bw2(local_e_id, bw2_sem_id, bf_id, bd2_id):
686
+ del local_e_id
687
+ pltpu.make_async_copy(
688
+ src_ref=b_w2_x2_vmem.at[bw2_sem_id],
689
+ dst_ref=b_w2_x2_vmem.at[bw2_sem_id],
690
+ sem=local_sems.at[bw2_sem_id, 2],
691
+ ).wait()
692
+ if w2_scale_hbm is not None:
693
+ pltpu.make_async_copy(
694
+ src_ref=b_w2_scale_x2_vmem.at[bw2_sem_id],
695
+ dst_ref=b_w2_scale_x2_vmem.at[bw2_sem_id],
696
+ sem=local_sems.at[bw2_sem_id, 2],
697
+ ).wait()
698
+ if b2_hbm is not None and bf_id == 0:
699
+ pltpu.make_async_copy(
700
+ src_ref=b_b2_x2_vmem.at[bd2_id % 2],
701
+ dst_ref=b_b2_x2_vmem.at[bd2_id % 2],
702
+ sem=local_sems.at[bw2_sem_id, 2],
703
+ ).wait()
704
+
705
+ def wait_fetch_bw3(local_e_id, bw3_sem_id, bf_id, bd3_id):
706
+ del local_e_id
707
+ pltpu.make_async_copy(
708
+ src_ref=b_w3_x2_vmem.at[bw3_sem_id],
709
+ dst_ref=b_w3_x2_vmem.at[bw3_sem_id],
710
+ sem=local_sems.at[bw3_sem_id, 3],
711
+ ).wait()
712
+ if w1_scale_hbm is not None:
713
+ pltpu.make_async_copy(
714
+ src_ref=b_w3_scale_x2_vmem.at[bw3_sem_id],
715
+ dst_ref=b_w3_scale_x2_vmem.at[bw3_sem_id],
716
+ sem=local_sems.at[bw3_sem_id, 3],
717
+ ).wait()
718
+ if b1_hbm is not None and bd3_id == 0:
719
+ pltpu.make_async_copy(
720
+ src_ref=b_b3_x2_vmem.at[bf_id % 2],
721
+ dst_ref=b_b3_x2_vmem.at[bf_id % 2],
722
+ sem=local_sems.at[bw3_sem_id, 3],
723
+ ).wait()
724
+
725
+ def start_fetch_next_bw(local_e_id, bw_sem_id, bf_id, bd1_id, bd2_id):
726
+ next_bd1_id = bd1_id + 1
727
+ next_bd2_id = bd2_id + 1
728
+ next_sem_id = (bw_sem_id + 1) % 2
729
+
730
+ if bf_id >= num_bf:
731
+ return
732
+ if next_bd1_id < num_bd1:
733
+ start_fetch_bw1(local_e_id, next_sem_id, bf_id, next_bd1_id)
734
+ start_fetch_bw3(local_e_id, next_sem_id, bf_id, next_bd1_id)
735
+ elif next_bd1_id == num_bd1:
736
+ start_fetch_bw2(local_e_id, next_sem_id, bf_id, 0)
737
+ elif next_bd2_id < num_bd2:
738
+ start_fetch_bw2(local_e_id, next_sem_id, bf_id, next_bd2_id)
739
+ elif next_bd2_id == num_bd2:
740
+ start_fetch_next_bw(local_e_id, bw_sem_id, bf_id + 1, -1, -1)
741
+ else:
742
+ raise RuntimeError("Unreachable")
743
+
744
+ def dynamic_ffn1(
745
+ t_b32_vmem,
746
+ w1_vmem,
747
+ w1_scale_vmem,
748
+ b1_vmem,
749
+ w3_vmem,
750
+ w3_scale_vmem,
751
+ b3_vmem,
752
+ acc1_vmem,
753
+ acc3_vmem,
754
+ dyn_sz,
755
+ should_init,
756
+ ):
757
+ assert t_b32_vmem.shape == (bt * num_devices, bd1 // t_packing)
758
+ assert w1_vmem.shape == w3_vmem.shape == (t_packing, bd1_per_t_packing,
759
+ bf)
760
+ assert acc1_vmem.shape == acc3_vmem.shape == (bt * num_devices, bf)
761
+ assert bd1 % (t_packing * 128) == 0, (bd1, t_packing)
762
+ assert bd1c % (t_packing * 128) == 0, (bd1c, t_packing)
763
+ if w1_scale_vmem is not None:
764
+ assert w1_scale_vmem.shape == (
765
+ t_packing,
766
+ bd1_per_t_packing // subc_quant_wsz,
767
+ 1,
768
+ bf,
769
+ )
770
+ assert bd1c_per_t_packing == subc_quant_wsz
771
+ if w3_scale_vmem is not None:
772
+ assert w3_scale_vmem.shape == (
773
+ t_packing,
774
+ bd1_per_t_packing // subc_quant_wsz,
775
+ 1,
776
+ bf,
777
+ )
778
+ assert bd1c_per_t_packing == subc_quant_wsz
779
+
780
+ num_loops = cdiv(dyn_sz, btc)
781
+ repack_ty = jnp.dtype(f"int{t_bitwidth}")
782
+
783
+ def body(btc_id, _):
784
+ for bd1c_id in range(cdiv(bd1, bd1c)):
785
+ t_b32 = t_b32_vmem[
786
+ pl.ds(btc_id * btc, btc),
787
+ pl.ds(bd1c_id * bd1c_per_t_packing, bd1c_per_t_packing),
788
+ ]
789
+ for p_id in range(t_packing):
790
+ t = pltpu.bitcast(t_b32.astype(repack_ty), t_dtype)
791
+ t_b32 = t_b32 >> t_bitwidth
792
+ for bfc_id in range(cdiv(bf, bfc)):
793
+ w_slices = (
794
+ p_id,
795
+ pl.ds(bd1c_id * bd1c_per_t_packing,
796
+ bd1c_per_t_packing),
797
+ pl.ds(bfc_id * bfc, bfc),
798
+ )
799
+ w1 = w1_vmem[*w_slices]
800
+ acc1 = jnp.dot(t,
801
+ w1,
802
+ preferred_element_type=jnp.float32)
803
+
804
+ if w1_scale_vmem is not None:
805
+ w1_scale_slices = (
806
+ p_id,
807
+ bd1c_id,
808
+ pl.ds(0, 1),
809
+ pl.ds(bfc_id * bfc, bfc),
810
+ )
811
+ # TODO(jevinjiang): can use mosaic to load with stride 0.
812
+ w1_scale = jnp.broadcast_to(
813
+ w1_scale_vmem[*w1_scale_slices], acc1.shape)
814
+ acc1 *= w1_scale
815
+
816
+ w3 = w3_vmem[*w_slices]
817
+
818
+ acc3 = jnp.dot(t,
819
+ w3,
820
+ preferred_element_type=jnp.float32)
821
+
822
+ if w3_scale_vmem is not None:
823
+ w3_scale_slices = (
824
+ p_id,
825
+ bd1c_id,
826
+ pl.ds(0, 1),
827
+ pl.ds(bfc_id * bfc, bfc),
828
+ )
829
+ w3_scale = jnp.broadcast_to(
830
+ w3_scale_vmem[*w3_scale_slices], acc3.shape)
831
+ acc3 *= w3_scale
832
+
833
+ acc_slices = (pl.ds(btc_id * btc,
834
+ btc), pl.ds(bfc_id * bfc, bfc))
835
+ if should_init and p_id == bd1c_id == 0:
836
+ if b1_vmem is not None:
837
+ b1_scale_slices = (
838
+ pl.ds(0, 1),
839
+ pl.ds(bfc_id * bfc, bfc),
840
+ )
841
+ b1 = jnp.broadcast_to(
842
+ b1_vmem[*b1_scale_slices], acc1.shape)
843
+ acc1 += b1
844
+ if b3_vmem is not None:
845
+ b3_scale_slices = (
846
+ pl.ds(0, 1),
847
+ pl.ds(bfc_id * bfc, bfc),
848
+ )
849
+ b3 = jnp.broadcast_to(
850
+ b3_vmem[*b3_scale_slices], acc1.shape)
851
+ acc3 += b3
852
+
853
+ acc1_vmem[*acc_slices] = acc1
854
+ acc3_vmem[*acc_slices] = acc3
855
+ else:
856
+ acc1_vmem[*acc_slices] += acc1
857
+ acc3_vmem[*acc_slices] += acc3
858
+
859
+ lax.fori_loop(0, num_loops, body, None)
860
+
861
+ def dynamic_ffn2(
862
+ acc1_vmem,
863
+ acc3_vmem,
864
+ w2_vmem,
865
+ w2_scale_vmem,
866
+ b2_vmem,
867
+ res_b32_vmem,
868
+ dyn_sz,
869
+ should_init,
870
+ ):
871
+ assert res_b32_vmem.shape == (bt * num_devices, bd2_per_t_packing)
872
+ assert w2_vmem.shape == (t_packing, bf, bd2_per_t_packing)
873
+ assert acc1_vmem.shape == acc3_vmem.shape == (bt * num_devices, bf)
874
+ assert bd2 % (t_packing * 128) == 0, (bd2, t_packing)
875
+ assert bd2c % (t_packing * 128) == 0, (bd2c, t_packing)
876
+ assert t_dtype in (jnp.float32, jnp.bfloat16)
877
+
878
+ if w2_scale_vmem is not None:
879
+ assert w2_scale_vmem.shape == (
880
+ t_packing,
881
+ bf // subc_quant_wsz,
882
+ 1,
883
+ bd2_per_t_packing,
884
+ )
885
+ assert bfc == subc_quant_wsz
886
+
887
+ num_loops = cdiv(dyn_sz, btc)
888
+ assert bd2c % (t_packing * 128) == 0, (bd2c, t_packing)
889
+
890
+ def body(btc_id, _):
891
+ for bd2c_id in range(cdiv(bd2, bd2c)):
892
+ res_lst = []
893
+ for p_id in range(t_packing):
894
+ res = jnp.zeros((btc, bd2c_per_t_packing),
895
+ dtype=jnp.float32)
896
+
897
+ if b2_vmem is not None and should_init:
898
+ b2_scale_slices = (
899
+ p_id,
900
+ pl.ds(0, 1),
901
+ pl.ds(bd2c_id * bd2c_per_t_packing,
902
+ bd2c_per_t_packing),
903
+ )
904
+ b2 = jnp.broadcast_to(b2_vmem[*b2_scale_slices],
905
+ res.shape)
906
+ res += b2
907
+
908
+ for bfc_id in range(cdiv(bf, bfc)):
909
+ acc_slices = (pl.ds(btc_id * btc,
910
+ btc), pl.ds(bfc_id * bfc, bfc))
911
+ acc1 = acc1_vmem[*acc_slices]
912
+ acc3 = acc3_vmem[*acc_slices]
913
+ act = activation_fn(acc1, acc3, act_fn)
914
+ w2 = w2_vmem[
915
+ p_id,
916
+ pl.ds(bfc_id * bfc, bfc),
917
+ pl.ds(bd2c_id *
918
+ bd2c_per_t_packing, bd2c_per_t_packing),
919
+ ]
920
+ acc = jnp.dot(act,
921
+ w2,
922
+ preferred_element_type=jnp.float32)
923
+ if w2_scale_vmem is not None:
924
+ w2_scale_slices = (
925
+ p_id,
926
+ bfc_id,
927
+ pl.ds(0, 1),
928
+ pl.ds(bd2c_id * bd2c_per_t_packing,
929
+ bd2c_per_t_packing),
930
+ )
931
+ w2_scale = jnp.broadcast_to(
932
+ w2_scale_vmem[*w2_scale_slices], acc.shape)
933
+ acc *= w2_scale
934
+ res += acc
935
+ res = pltpu.bitcast(res, jnp.uint32)
936
+ if t_packing == 2:
937
+ res = res >> 16 << (16 * p_id)
938
+ else:
939
+ assert t_packing == 1
940
+ res_lst.append(res)
941
+ res = res_lst[0]
942
+ # TODO(jevinjiang): use interleaved packing when it is exposed to Pallas
943
+ for i in range(1, t_packing):
944
+ res |= res_lst[i]
945
+ sliced_res_vmem = res_b32_vmem.at[
946
+ pl.ds(btc_id * btc, btc),
947
+ pl.ds(bd2c_id * bd2c_per_t_packing, bd2c_per_t_packing),
948
+ ]
949
+ if should_init:
950
+ sliced_res_vmem[...] = res
951
+ else:
952
+ sliced_res_vmem[...] = pltpu.bitcast(
953
+ sliced_res_vmem.bitcast(t_dtype)[...] +
954
+ pltpu.bitcast(res, t_dtype),
955
+ sliced_res_vmem.dtype,
956
+ )
957
+
958
+ lax.fori_loop(0, num_loops, body, None)
959
+
960
+ def expert_ffn(bt_id, e_sem_id, local_e_id):
961
+ bt_sem_id = bt_id % 2
962
+ bw_sem_id = 0
963
+ # start_fetch_bw1(local_e_id, bw_sem_id, 0, 0)
964
+ # start_fetch_bw3(local_e_id, bw_sem_id, 0, 0)
965
+ a2a_s_b32_vmem = (a2a_s_x2_vmem.bitcast(jnp.uint32).reshape(
966
+ 2, bt * num_devices, hidden_size // t_packing).at[e_sem_id])
967
+ a2a_s_acc_b32_vmem = (a2a_s_acc_x2_vmem.bitcast(jnp.uint32).reshape(
968
+ 2, bt * num_devices, hidden_size // t_packing).at[e_sem_id])
969
+ b_acc_vmem_2d = b_acc_vmem.reshape(bt * num_devices, bf * 2)
970
+ b_acc1_vmem = b_acc_vmem_2d.at[:, :bf]
971
+ b_acc3_vmem = b_acc_vmem_2d.at[:, bf:]
972
+
973
+ e_id = my_id * local_num_experts + local_e_id
974
+ dyn_sz = expert_sizes_x2_smem[bt_sem_id, 0, e_id]
975
+
976
+ bd1_per_t_packing = bd1 // t_packing
977
+ bd2_per_t_packing = bd2 // t_packing
978
+
979
+ for bf_id in range(num_bf):
980
+ for bd1_id in range(num_bd1):
981
+ start_fetch_next_bw(local_e_id, bw_sem_id, bf_id, bd1_id, 0)
982
+ w1_scale_vmem = (None if b_w1_scale_x2_vmem is None else
983
+ b_w1_scale_x2_vmem.at[bw_sem_id])
984
+ w3_scale_vmem = (None if b_w3_scale_x2_vmem is None else
985
+ b_w3_scale_x2_vmem.at[bw_sem_id])
986
+ b1_vmem = None if b_b1_x2_vmem is None else b_b1_x2_vmem.at[
987
+ bf_id % 2]
988
+ b3_vmem = None if b_b3_x2_vmem is None else b_b3_x2_vmem.at[
989
+ bf_id % 2]
990
+ wait_fetch_bw1(local_e_id, bw_sem_id, bf_id, bd1_id)
991
+ wait_fetch_bw3(local_e_id, bw_sem_id, bf_id, bd1_id)
992
+
993
+ dynamic_ffn1(
994
+ t_b32_vmem=a2a_s_b32_vmem.at[
995
+ ...,
996
+ pl.ds(bd1_id * bd1_per_t_packing, bd1_per_t_packing)],
997
+ w1_vmem=b_w1_x2_vmem.at[bw_sem_id],
998
+ w1_scale_vmem=w1_scale_vmem,
999
+ b1_vmem=b1_vmem,
1000
+ w3_vmem=b_w3_x2_vmem.at[bw_sem_id],
1001
+ w3_scale_vmem=w3_scale_vmem,
1002
+ b3_vmem=b3_vmem,
1003
+ acc1_vmem=b_acc1_vmem,
1004
+ acc3_vmem=b_acc3_vmem,
1005
+ dyn_sz=dyn_sz,
1006
+ should_init=(bd1_id == 0),
1007
+ )
1008
+ bw_sem_id = (bw_sem_id + 1) % 2
1009
+
1010
+ for bd2_id in range(num_bd2):
1011
+ start_fetch_next_bw(local_e_id, bw_sem_id, bf_id, num_bd1,
1012
+ bd2_id)
1013
+ wait_fetch_bw2(local_e_id, bw_sem_id, bf_id, bd2_id)
1014
+ if bf_id == bd2_id == 0:
1015
+ wait_a2a_gather_send(bt_id, e_sem_id, local_e_id - 2)
1016
+
1017
+ w2_scale_vmem = (None if b_w2_scale_x2_vmem is None else
1018
+ b_w2_scale_x2_vmem.at[bw_sem_id])
1019
+ b2_vmem = None if b_b2_x2_vmem is None else b_b2_x2_vmem.at[
1020
+ bd2_id % 2]
1021
+ dynamic_ffn2(
1022
+ acc1_vmem=b_acc1_vmem,
1023
+ acc3_vmem=b_acc3_vmem,
1024
+ w2_vmem=b_w2_x2_vmem.at[bw_sem_id],
1025
+ w2_scale_vmem=w2_scale_vmem,
1026
+ b2_vmem=b2_vmem,
1027
+ res_b32_vmem=a2a_s_acc_b32_vmem.at[
1028
+ ...,
1029
+ pl.ds(bd2_id * bd2_per_t_packing, bd2_per_t_packing)],
1030
+ dyn_sz=dyn_sz,
1031
+ should_init=(bf_id == 0),
1032
+ )
1033
+ bw_sem_id = (bw_sem_id + 1) % 2
1034
+
1035
+ def bt_acc(bt_id, top_k_logits_lst):
1036
+ bt_sem_id = bt_id % 2
1037
+ for bt_t_id in range(bt):
1038
+ for k_id in range(top_k):
1039
+ e_id = t2e_routing_x2_smem[bt_sem_id, bt_t_id, k_id]
1040
+ offset = expert_offsets_x2_smem[bt_sem_id, 1, e_id]
1041
+ expert_offsets_x2_smem[bt_sem_id, 1, e_id] = offset + 1
1042
+ pltpu.make_async_copy(
1043
+ src_ref=a2a_g_hbm.at[e_id, pl.ds(offset, 1)],
1044
+ dst_ref=a2a_g_acc_vmem.at[k_id, pl.ds(bt_t_id, 1)],
1045
+ sem=a2a_acc_sem,
1046
+ ).start()
1047
+ pltpu.make_async_copy(
1048
+ src_ref=a2a_g_acc_vmem,
1049
+ dst_ref=a2a_g_acc_vmem,
1050
+ sem=a2a_acc_sem,
1051
+ ).wait()
1052
+ output = None
1053
+ for k_id in range(top_k):
1054
+ acc = a2a_g_acc_vmem[k_id].reshape(bt, hidden_size)
1055
+ logits = broadcast_minor(top_k_logits_lst[k_id], acc.shape)
1056
+ acc *= logits
1057
+ if output is None:
1058
+ output = acc
1059
+ else:
1060
+ output += acc
1061
+ assert output is not None
1062
+ return output.astype(output_hbm.dtype)
1063
+
1064
+ def start_send_bo(bt_id, priority=0):
1065
+ bt_sem_id = bt_id % 2
1066
+ b_output_sem = local_sems.at[bt_sem_id, 4]
1067
+ pltpu.make_async_copy(
1068
+ src_ref=b_output_x2_vmem.at[bt_sem_id],
1069
+ dst_ref=output_hbm.at[pl.ds(bt_id * bt, bt)],
1070
+ sem=b_output_sem,
1071
+ ).start(priority=priority)
1072
+
1073
+ def wait_send_bo(bt_id):
1074
+ is_valid = jnp.logical_and(0 <= bt_id, bt_id < num_bt)
1075
+ sz = pl.multiple_of(lax.select(is_valid, bt, 0), bt)
1076
+ bt_sem_id = (bt_id + 2) % 2
1077
+ b_output_sem = local_sems.at[bt_sem_id, 4]
1078
+ pltpu.make_async_copy(
1079
+ src_ref=output_hbm.at[pl.ds(0, sz)],
1080
+ dst_ref=output_hbm.at[pl.ds(0, sz)],
1081
+ sem=b_output_sem,
1082
+ ).wait()
1083
+
1084
+ ### ------- Kernel start ------- ###
1085
+ start_fetch_b_gating(bt_id=0)
1086
+
1087
+ def run_per_bt(bt_id, e_sem_id):
1088
+ bt_sem_id = bt_id % 2
1089
+ next_bt_id = bt_id + 1
1090
+ start_fetch_b_gating(next_bt_id)
1091
+ wait_fetch_b_gating(bt_id)
1092
+
1093
+ b_gating = b_gating_x2_vmem[bt_sem_id]
1094
+ b_gating_score = jax.nn.softmax(b_gating, axis=-1)
1095
+ top_k_logits_lst, t2e_routing, expert_sizes, expert_starts = get_top_k(
1096
+ b_gating_score, top_k, renormalize_topk_logits)
1097
+
1098
+ all_reduce_metadata(bt_sem_id, t2e_routing, expert_starts,
1099
+ expert_sizes)
1100
+ sync_barrier()
1101
+
1102
+ # Start a2a scatter for first active expert.
1103
+ start_a2a_scatter(bt_id=bt_id, e_sem_id=e_sem_id, local_e_id=0)
1104
+
1105
+ def run_per_expert(local_e_id, e_sem_id):
1106
+ sync_barrier()
1107
+
1108
+ # Prefetch weights for CURRENT active expert.
1109
+ # TODO(jevinjiang): It is hard to prefetch weights in previous iteration
1110
+ # because the expert_ffn keeps overwriting the buffers. Triple buffering
1111
+ # could resolve this but it takes more VMEM scratch. Need further
1112
+ # experiment on this.
1113
+ start_fetch_bw1(local_e_id, bw1_sem_id=0, bf_id=0, bd1_id=0)
1114
+ start_fetch_bw3(local_e_id, bw3_sem_id=0, bf_id=0, bd3_id=0)
1115
+
1116
+ # Next ids.
1117
+ next_e_sem_id = lax.select(e_sem_id == 0, 1, 0)
1118
+ next_local_e_id = local_e_id + 1
1119
+
1120
+ # Start a2a scatter for NEXT active expert.
1121
+ @pl.when(next_local_e_id < local_num_experts)
1122
+ def _():
1123
+ start_a2a_scatter(bt_id, next_e_sem_id, next_local_e_id)
1124
+
1125
+ # Wait a2a scatter for CURRENT active expert.
1126
+ wait_a2a_scatter_recv(bt_id, e_sem_id, local_e_id)
1127
+
1128
+ # Perform FFN for CURRENT active expert.
1129
+ expert_ffn(bt_id, e_sem_id, local_e_id)
1130
+
1131
+ # Start a2a gather to send back tokens for CURRENT active expert.
1132
+ start_a2a_gather(bt_id, e_sem_id, local_e_id)
1133
+
1134
+ # A must-wait before next sync_barrier.
1135
+ wait_a2a_scatter_send(bt_id, e_sem_id, local_e_id)
1136
+ return next_e_sem_id
1137
+
1138
+ e_sem_id = lax.fori_loop(0,
1139
+ local_num_experts,
1140
+ run_per_expert,
1141
+ e_sem_id,
1142
+ unroll=False)
1143
+
1144
+ # Wait to receive a2a gather for ALL experts.
1145
+ wait_a2a_gather_recv_all()
1146
+
1147
+ # Accumulate results for current batch.
1148
+ output = bt_acc(bt_id, top_k_logits_lst)
1149
+
1150
+ # Make sure it is safe to overwrite output buffer.
1151
+ wait_send_bo(bt_id=bt_id - 2)
1152
+ b_output_x2_vmem[bt_sem_id] = output
1153
+
1154
+ start_send_bo(bt_id)
1155
+
1156
+ wait_a2a_gather_send(
1157
+ bt_id,
1158
+ e_sem_id=e_sem_id,
1159
+ local_e_id=local_num_experts - 2,
1160
+ )
1161
+ wait_a2a_gather_send(
1162
+ bt_id,
1163
+ e_sem_id=lax.select(e_sem_id == 0, 1, 0),
1164
+ local_e_id=local_num_experts - 1,
1165
+ )
1166
+ return e_sem_id
1167
+
1168
+ lax.fori_loop(0, num_bt, run_per_bt, 0, unroll=False)
1169
+ wait_send_bo(bt_id=num_bt - 2)
1170
+ wait_send_bo(bt_id=num_bt - 1)
1171
+
1172
+ ### ------- Kernel end ------- ###
1173
+
1174
+
1175
+ @functools.partial(
+     jax.jit,
+     static_argnames=[
+         "mesh",
+         "top_k",
+         "renormalize_topk_logits",
+         "act_fn",
+         "subc_quant_wsz",
+         "bt",
+         "bf",
+         "bd1",
+         "bd2",
+         "btc",
+         "bfc",
+         "bd1c",
+         "bd2c",
+         "ep_axis_name",
+     ],
+ )
+ def fused_ep_moe(
+     mesh: jax.sharding.Mesh,
+     tokens: jax.Array,  # (num_tokens, hidden_size)
+     w1: jax.Array,  # (num_experts, 2, hidden_size, intermediate_size)
+     w2: jax.Array,  # (num_experts, intermediate_size, hidden_size)
+     gating_output: jax.Array,  # (num_tokens, num_experts)
+     top_k: int,
+     *,
+     renormalize_topk_logits: bool = False,
+     act_fn: str = "silu",
+     subc_quant_wsz: int | None = None,
+     w1_scale: jax.Array | None = None,  # F32(num_experts, 2, hidden_size // subc_quant_wsz, 1, intermediate_size)
+     w2_scale: jax.Array | None = None,  # F32(num_experts, intermediate_size // subc_quant_wsz, 1, hidden_size)
+     b1: jax.Array | None = None,  # F32(num_experts, 2, 1, intermediate_size)
+     b2: jax.Array | None = None,  # F32(num_experts, 1, hidden_size)
+     # Kernel tuning parameters.
+     bt: int,
+     bf: int,
+     bd1: int,
+     bd2: int,
+     btc: int,
+     bfc: int,
+     bd1c: int,
+     bd2c: int,
+     ep_axis_name: str = "model",
+ ):
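+     # Editorial note (not in the original source): as far as can be inferred
+     # from the shape checks below, bt blocks the per-device tokens, bf blocks
+     # intermediate_size, and bd1/bd2 block hidden_size for the two matmuls;
+     # btc/bfc/bd1c/bd2c appear to be the corresponding compute sub-block sizes
+     # used inside the kernel.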
+     # TODO(jevinjiang): move all these assertions to a validation function.
+     if len(mesh.shape) != 2:
+         raise NotImplementedError("Only 2D mesh is supported.")
+
+     for axis_name in mesh.axis_names:
+         if axis_name == ep_axis_name:
+             continue
+         if mesh.shape[axis_name] != 1:
+             raise NotImplementedError(
+                 f"Expected all non-EP axes to have size 1 in {mesh.shape=}")
+
+     ep_size = mesh.shape[ep_axis_name]
+     num_devices = ep_size
+
+     num_tokens, hidden_size = tokens.shape
+     num_experts, intermediate_size, _ = w2.shape
+
+     if w1.shape != (num_experts, 2, hidden_size, intermediate_size):
+         raise ValueError(
+             f"Expected {w1.shape=} to be"
+             f" {(num_experts, 2, hidden_size, intermediate_size)}.")
+
+     if w2.shape != (num_experts, intermediate_size, hidden_size):
+         raise ValueError(f"Expected {w2.shape=} to be"
+                          f" {(num_experts, intermediate_size, hidden_size)}.")
+
+     if gating_output.shape != (num_tokens, num_experts):
+         raise ValueError(
+             f"Expected {gating_output.shape=} to be {(num_tokens, num_experts)}."
+         )
+
+     if not (0 < top_k <= num_experts):
+         raise ValueError(
+             f"Expected {top_k=} to be in range (0, {num_experts=}].")
+
+     if hidden_size % 128 != 0 or intermediate_size % 128 != 0:
+         raise ValueError(
+             f"Expected {hidden_size=} and {intermediate_size=} to be aligned to"
+             " 128. Did you pad them with zeros outside the kernel?")
+     if num_tokens % ep_size != 0:
+         raise ValueError(
+             f"Expected {num_tokens=} to be aligned to {ep_size=}.")
+     if num_experts % ep_size != 0:
+         raise ValueError(
+             f"Expected {num_experts=} to be aligned to {ep_size=}.")
+
+     local_num_tokens = num_tokens // ep_size
+     # local_num_experts = num_experts // ep_size
+     padded_num_experts = align_to(num_experts, 128)
+     padded_top_k = align_to(top_k, 128)
+     t_dtype = tokens.dtype
+     t_packing = get_dtype_packing(t_dtype)
+
+     # Override bt.
+     if local_num_tokens <= t_packing * 8:
+         bt = local_num_tokens
+         btc = bt
+     bt = min(local_num_tokens, bt)
+     # The worst case is that all devices send bt tokens to one device.
+     btc = min(bt, btc, bt * num_devices)
+
+     if local_num_tokens % t_packing != 0:
+         raise ValueError(
+             f"Expected {local_num_tokens=} to be aligned to {t_packing=}.")
+
+     if bt % t_packing != 0:
+         raise ValueError(f"Expected {bt=} to be aligned to {t_packing=}.")
+     if local_num_tokens % bt != 0:
+         raise ValueError(
+             f"Expected {local_num_tokens=} to be aligned to {bt=}.")
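+     # Illustrative numbers (editorial note, not in the original source):
+     # assuming get_dtype_packing returns how many elements fit in a 32-bit
+     # word, bf16 tokens give t_packing == 2, so bt must be even and the
+     # tokens are later reshaped to (num_tokens, 2, hidden_size // 2) before
+     # they enter the kernel.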
+
+     if subc_quant_wsz is not None:
+         if subc_quant_wsz <= 0:
+             raise ValueError(f"Expected {subc_quant_wsz=} to be positive.")
+         if subc_quant_wsz % 256 != 0:
+             raise ValueError(
+                 f"Expected {subc_quant_wsz=} to be aligned to 256.")
+         if hidden_size % subc_quant_wsz != 0:
+             raise ValueError(
+                 f"Expected {hidden_size=} to be aligned to {subc_quant_wsz=}.")
+         if intermediate_size % subc_quant_wsz != 0:
+             raise ValueError(
+                 f"Expected {intermediate_size=} to be aligned to {subc_quant_wsz=}."
+             )
+         # We force the compute size of the contracting dim to be subc_quant_wsz
+         # so the same scale can be applied after matmul and accumulation.
+         bd1c = subc_quant_wsz * t_packing
+         bfc = subc_quant_wsz
+
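+     # Worked example (editorial note, not in the original source): with
+     # subc_quant_wsz == 256 and bf16 activations (t_packing == 2), the forced
+     # compute blocks are bd1c == 512 and bfc == 256, so each contraction block
+     # spans exactly one quantization sub-channel and a single scale can be
+     # applied after the matmul accumulation, as the comment above notes.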
+     if bfc % 128 != 0:
+         raise ValueError(f"Expected {bfc=} to be aligned to 128.")
+     if bd1c % (t_packing * 128) != 0:
+         raise ValueError(
+             f"Expected {bd1c=} to be aligned to {t_packing * 128}.")
+     if bd2c % (t_packing * 128) != 0:
+         raise ValueError(
+             f"Expected {bd2c=} to be aligned to {t_packing * 128}.")
+     if bf % bfc != 0:
+         raise ValueError(f"Expected {bf=} to be aligned to {bfc=}.")
+     if bd1 % bd1c != 0:
+         raise ValueError(f"Expected {bd1=} to be aligned to {bd1c=}.")
+     if bd2 % bd2c != 0:
+         raise ValueError(f"Expected {bd2=} to be aligned to {bd2c=}.")
+     if hidden_size % bd1 != 0 or hidden_size % bd2 != 0:
+         raise ValueError(
+             f"Expected {hidden_size=} to be aligned to {bd1=} and {bd2=}.")
+     if intermediate_size % bf != 0:
+         raise ValueError(
+             f"Expected {intermediate_size=} to be aligned to {bf=}.")
+
+     # Note: we should dump the scales in the kernel-expected shape in the
+     # checkpoint offline, or reshape them right after weight loading.
+     if w1_scale is not None:
+         expected_w1_scale_shape = (
+             num_experts,
+             2,
+             hidden_size // subc_quant_wsz,
+             1,
+             intermediate_size,
+         )
+         if w1_scale.shape != expected_w1_scale_shape:
+             raise ValueError(
+                 f"Expected {w1_scale.shape=} to be {expected_w1_scale_shape}.")
+         if w1_scale.dtype != jnp.float32:
+             w1_scale = w1_scale.astype(jnp.float32)
+
+     if w2_scale is not None:
+         expected_w2_scale_shape = (
+             num_experts,
+             intermediate_size // subc_quant_wsz,
+             1,
+             hidden_size,
+         )
+         if w2_scale.shape != expected_w2_scale_shape:
+             raise ValueError(
+                 f"Expected {w2_scale.shape=} to be {expected_w2_scale_shape}.")
+         if w2_scale.dtype != jnp.float32:
+             w2_scale = w2_scale.astype(jnp.float32)
+
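+     # Editorial sketch (not in the original source): a checkpoint that stores
+     # per-sub-channel scales without the singleton broadcast axis, e.g. with
+     # shape (num_experts, 2, hidden_size // subc_quant_wsz, intermediate_size),
+     # could be adapted right after weight loading along the lines of:
+     #   w1_scale = w1_scale.reshape(num_experts, 2,
+     #                               hidden_size // subc_quant_wsz, 1,
+     #                               intermediate_size)
+     #   w2_scale = w2_scale.reshape(num_experts,
+     #                               intermediate_size // subc_quant_wsz, 1,
+     #                               hidden_size)
+     # The source checkpoint layout here is an assumption for illustration.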
+     if b1 is not None:
+         expected_b1_shape = (num_experts, 2, 1, intermediate_size)
+         if b1.shape != expected_b1_shape:
+             raise ValueError(f"Expected {b1.shape=} to be {expected_b1_shape}.")
+         if b1.dtype != jnp.float32:
+             b1 = b1.astype(jnp.float32)
+
+     if b2 is not None:
+         expected_b2_shape = (num_experts, 1, hidden_size)
+         if b2.shape != expected_b2_shape:
+             raise ValueError(f"Expected {b2.shape=} to be {expected_b2_shape}.")
+         if b2.dtype != jnp.float32:
+             b2 = b2.astype(jnp.float32)
+
+     # Prepare inputs for the kernel.
+     if padded_num_experts != gating_output.shape[-1]:
+         gating_output = jnp.pad(
+             gating_output,
+             ((0, 0), (0, padded_num_experts - gating_output.shape[-1])),
+             constant_values=-jnp.inf,
+         )
+
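+     # Editorial note (not in the original source): padding the extra expert
+     # columns with -inf means the padded experts can never win the top-k
+     # routing inside the kernel.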
+     tokens = tokens.reshape(-1, t_packing, hidden_size // t_packing)
+
+     hbm_block_spec = pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM)
+     renorm_str = "-renorm_k" if renormalize_topk_logits else ""
+     scope_name = f"fused-moe-k_{top_k}{renorm_str}-bt_{bt}_{btc}-bf_{bf}_{bfc}-bd1_{bd1}_{bd1c}-bd2_{bd2}_{bd2c}"
+     fused_moe = pl.pallas_call(
+         functools.partial(
+             _fused_ep_moe_kernel,
+             top_k=top_k,
+             renormalize_topk_logits=renormalize_topk_logits,
+             ep_axis_name=ep_axis_name,
+             act_fn=act_fn,
+             subc_quant_wsz=subc_quant_wsz,
+             bt=bt,
+             bf=bf,
+             bd1=bd1,
+             bd2=bd2,
+             btc=btc,
+             bfc=bfc,
+             bd1c=bd1c,
+             bd2c=bd2c,
+         ),
+         out_shape=jax.ShapeDtypeStruct((local_num_tokens, hidden_size),
+                                        t_dtype),
+         grid_spec=pltpu.PrefetchScalarGridSpec(
+             num_scalar_prefetch=0,
+             in_specs=[
+                 hbm_block_spec,  # tokens_hbm
+                 hbm_block_spec,  # w1_hbm
+                 hbm_block_spec,  # w2_hbm
+                 None if w1_scale is None else hbm_block_spec,  # w1_scale_hbm
+                 None if w2_scale is None else hbm_block_spec,  # w2_scale_hbm
+                 None if b1 is None else hbm_block_spec,  # b1_hbm
+                 None if b2 is None else hbm_block_spec,  # b2_hbm
+                 hbm_block_spec,  # gating_output_hbm
+                 hbm_block_spec,  # a2a_g_hbm
+             ],
+             out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM),
+             scratch_shapes=([
+                 # t2e_routing_x2_smem
+                 pltpu.SMEM((2, bt, padded_top_k), jnp.int32),
+                 # d2e_count_x2_smem
+                 pltpu.SMEM((2, num_devices, 1, padded_num_experts), jnp.int32),
+                 # expert_offsets_x2_smem
+                 pltpu.SMEM((2, 2, padded_num_experts), jnp.int32),
+                 # expert_starts_x2_smem
+                 pltpu.SMEM((2, 1, padded_num_experts), jnp.int32),
+                 # expert_sizes_x2_smem
+                 pltpu.SMEM((2, 1, padded_num_experts), jnp.int32),
+                 # a2a_s_sends_x2_smem
+                 pltpu.SMEM((2, ), jnp.int32),
+                 # a2a_s_x2_vmem
+                 pltpu.VMEM(
+                     (2, bt * num_devices, t_packing, hidden_size // t_packing),
+                     t_dtype),
+                 # a2a_s_acc_x2_vmem
+                 pltpu.VMEM(
+                     (2, bt * num_devices, t_packing, hidden_size // t_packing),
+                     t_dtype),
+                 # a2a_g_acc_vmem
+                 pltpu.VMEM((top_k, bt, t_packing, hidden_size // t_packing),
+                            t_dtype),
+                 # b_gating_x2_vmem
+                 pltpu.VMEM((2, bt, padded_num_experts), t_dtype),
+                 # b_output_x2_vmem
+                 pltpu.VMEM((2, bt, hidden_size), t_dtype),
+                 # b_w1_x2_vmem
+                 pltpu.VMEM((2, t_packing, bd1 // t_packing, bf), w1.dtype),
+                 # b_w3_x2_vmem
+                 pltpu.VMEM((2, t_packing, bd1 // t_packing, bf), w1.dtype),
+                 # b_w2_x2_vmem
+                 pltpu.VMEM((2, t_packing, bf, bd2 // t_packing), w2.dtype),
+                 # b_w1_scale_x2_vmem
+                 (None if w1_scale is None else pltpu.VMEM(
+                     (2, t_packing, bd1 // t_packing // subc_quant_wsz, 1, bf),
+                     jnp.float32)),
+                 # b_w3_scale_x2_vmem
+                 (None if w1_scale is None else pltpu.VMEM(
+                     (2, t_packing, bd1 // t_packing // subc_quant_wsz, 1, bf),
+                     jnp.float32)),
+                 # b_w2_scale_x2_vmem
+                 (None if w2_scale is None else pltpu.VMEM(
+                     (2, t_packing, bf // subc_quant_wsz, 1, bd2 // t_packing),
+                     jnp.float32)),
+                 # b_b1_x2_vmem
+                 (None if b1 is None else pltpu.VMEM((2, 1, bf), jnp.float32)),
+                 # b_b3_x2_vmem
+                 (None if b1 is None else pltpu.VMEM((2, 1, bf), jnp.float32)),
+                 # b_b2_x2_vmem
+                 (None if b2 is None else pltpu.VMEM(
+                     (2, t_packing, 1, bd2 // t_packing), jnp.float32)),
+                 # b_acc_vmem
+                 pltpu.VMEM((bt * num_devices, 1, bf * 2), jnp.float32),
+                 # local_sems
+                 pltpu.SemaphoreType.DMA((2, 5)),
+                 # send_sems
+                 pltpu.SemaphoreType.DMA((2, )),
+                 # recv_sems
+                 pltpu.SemaphoreType.DMA((2, )),
+                 # a2a_gather_sem
+                 pltpu.SemaphoreType.DMA,
+                 # a2a_acc_sem
+                 pltpu.SemaphoreType.DMA,
+             ]),
+         ),
+         compiler_params=pltpu.CompilerParams(
+             collective_id=0,
+             vmem_limit_bytes=100 * 1024 * 1024,
+         ),
+         name=scope_name,
+     )
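+     # Editorial note (not in the original source): the `_x2_` scratch buffers
+     # above carry a leading dimension of 2 and appear to be the two halves of
+     # a double buffer (indexed by the alternating semaphore ids in the kernel),
+     # while vmem_limit_bytes sets the kernel's VMEM scratch budget to 100 MiB.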
+
+     @jax.jit
+     @jax.shard_map(
+         mesh=mesh,
+         in_specs=(
+             P(ep_axis_name),  # tokens_hbm
+             P(ep_axis_name),  # w1_hbm
+             P(ep_axis_name),  # w2_hbm
+             None if w1_scale is None else P(ep_axis_name),  # w1_scale_hbm
+             None if w2_scale is None else P(ep_axis_name),  # w2_scale_hbm
+             None if b1 is None else P(ep_axis_name),  # b1_hbm
+             None if b2 is None else P(ep_axis_name),  # b2_hbm
+             P(ep_axis_name),  # gating_output_hbm
+             P(),  # a2a_g_hbm
+         ),
+         out_specs=P(ep_axis_name),
+         check_vma=False,
+     )
+     def kernel(
+         tokens,
+         w1,
+         w2,
+         w1_scale,
+         w2_scale,
+         b1,
+         b2,
+         gating_output,
+         a2a_g_hbm_scratch,
+     ):
+         return fused_moe(
+             pltpu.with_memory_space_constraint(tokens, pltpu.HBM),  # tokens_hbm
+             pltpu.with_memory_space_constraint(w1, pltpu.HBM),  # w1_hbm
+             pltpu.with_memory_space_constraint(w2, pltpu.HBM),  # w2_hbm
+             (None if w1_scale is None else pltpu.with_memory_space_constraint(
+                 w1_scale, pltpu.HBM)),  # w1_scale_hbm
+             (None if w2_scale is None else pltpu.with_memory_space_constraint(
+                 w2_scale, pltpu.HBM)),  # w2_scale_hbm
+             (None if b1 is None else pltpu.with_memory_space_constraint(
+                 b1, pltpu.HBM)),  # b1_hbm
+             (None if b2 is None else pltpu.with_memory_space_constraint(
+                 b2, pltpu.HBM)),  # b2_hbm
+             pltpu.with_memory_space_constraint(gating_output,
+                                                pltpu.HBM),  # gating_output_hbm
+             pltpu.with_memory_space_constraint(a2a_g_hbm_scratch,
+                                                pltpu.HBM),  # a2a_g_hbm
+         )
+
+     a2a_g_hbm_scratch = pl.empty(
+         (num_experts, bt, t_packing, hidden_size // t_packing), t_dtype)
+     return kernel(
+         tokens,
+         w1,
+         w2,
+         w1_scale,
+         w2_scale,
+         b1,
+         b2,
+         gating_output,
+         a2a_g_hbm_scratch,
+     )
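
A minimal usage sketch for fused_ep_moe (editorial illustration, not part of this file: the mesh layout, shapes, dtypes, and every tuning value below are assumptions chosen only to satisfy the alignment checks above, and assume the device count divides num_tokens and num_experts):

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.sharding import Mesh
    # fused_ep_moe is the function defined above.

    # Assumed 2D mesh with all devices on the expert-parallel ("model") axis.
    devices = np.array(jax.devices()).reshape(1, -1)
    mesh = Mesh(devices, axis_names=("data", "model"))

    num_tokens, hidden_size, intermediate_size = 1024, 1024, 2048
    num_experts, top_k = 16, 2

    tokens = jnp.zeros((num_tokens, hidden_size), jnp.bfloat16)
    w1 = jnp.zeros((num_experts, 2, hidden_size, intermediate_size), jnp.bfloat16)
    w2 = jnp.zeros((num_experts, intermediate_size, hidden_size), jnp.bfloat16)
    gating_output = jnp.zeros((num_tokens, num_experts), jnp.bfloat16)

    out = fused_ep_moe(
        mesh, tokens, w1, w2, gating_output, top_k,
        bt=128, bf=512, bd1=512, bd2=512,   # placeholder tuning values
        btc=128, bfc=512, bd1c=512, bd2c=512,
    )
    # out: (num_tokens, hidden_size) MoE output, sharded over the EP axis.

The tuning values are placeholders for illustration; in practice they would be swept per model shape and TPU generation.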