tpu_inference-0.0.1rc1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (174)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +53 -0
  6. tests/core/test_dp_scheduler.py +899 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/fused_moe_v1_test.py +374 -0
  10. tests/kernels/mla_v1_test.py +396 -0
  11. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  12. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  13. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  14. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
  15. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  16. tests/lora/__init__.py +0 -0
  17. tests/lora/conftest.py +32 -0
  18. tests/lora/test_bgmv.py +43 -0
  19. tests/lora/test_layers.py +648 -0
  20. tests/lora/test_lora.py +133 -0
  21. tests/lora/utils.py +88 -0
  22. tests/test_base.py +201 -0
  23. tests/test_envs.py +203 -0
  24. tests/test_quantization.py +836 -0
  25. tests/test_tpu_info.py +120 -0
  26. tests/test_utils.py +235 -0
  27. tpu_inference/__init__.py +53 -0
  28. tpu_inference/core/__init__.py +0 -0
  29. tpu_inference/core/core_tpu.py +786 -0
  30. tpu_inference/core/disagg_executor.py +118 -0
  31. tpu_inference/core/disagg_utils.py +49 -0
  32. tpu_inference/core/sched/__init__.py +0 -0
  33. tpu_inference/core/sched/dp_scheduler.py +523 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/jax_parallel_state.py +67 -0
  36. tpu_inference/distributed/tpu_connector.py +727 -0
  37. tpu_inference/distributed/utils.py +60 -0
  38. tpu_inference/env_override.py +9 -0
  39. tpu_inference/envs.py +160 -0
  40. tpu_inference/executors/__init__.py +0 -0
  41. tpu_inference/executors/ray_distributed_executor.py +382 -0
  42. tpu_inference/experimental/__init__.py +0 -0
  43. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  44. tpu_inference/kernels/__init__.py +0 -0
  45. tpu_inference/kernels/collectives/__init__.py +0 -0
  46. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  47. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  48. tpu_inference/kernels/collectives/util.py +47 -0
  49. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  50. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  51. tpu_inference/kernels/fused_moe/__init__.py +0 -0
  52. tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
  53. tpu_inference/kernels/fused_moe/v1/kernel.py +1566 -0
  54. tpu_inference/kernels/mla/__init__.py +0 -0
  55. tpu_inference/kernels/mla/v1/__init__.py +0 -0
  56. tpu_inference/kernels/mla/v1/kernel.py +1349 -0
  57. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  58. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  59. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  60. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  61. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  62. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  66. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1501 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1603 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
  71. tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
  72. tpu_inference/layers/__init__.py +0 -0
  73. tpu_inference/layers/common/__init__.py +0 -0
  74. tpu_inference/layers/common/attention_interface.py +396 -0
  75. tpu_inference/layers/common/attention_metadata.py +34 -0
  76. tpu_inference/layers/common/binary_search.py +295 -0
  77. tpu_inference/layers/common/quant_methods.py +8 -0
  78. tpu_inference/layers/common/sharding.py +582 -0
  79. tpu_inference/layers/jax/__init__.py +0 -0
  80. tpu_inference/layers/jax/attention/__init__.py +0 -0
  81. tpu_inference/layers/jax/attention/attention.py +255 -0
  82. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  83. tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
  84. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  85. tpu_inference/layers/jax/base.py +151 -0
  86. tpu_inference/layers/jax/constants.py +88 -0
  87. tpu_inference/layers/jax/layers.py +301 -0
  88. tpu_inference/layers/jax/misc.py +16 -0
  89. tpu_inference/layers/jax/moe/__init__.py +0 -0
  90. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  91. tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
  92. tpu_inference/layers/jax/moe/moe.py +209 -0
  93. tpu_inference/layers/jax/rope.py +280 -0
  94. tpu_inference/layers/jax/rope_interface.py +214 -0
  95. tpu_inference/layers/jax/sample/__init__.py +0 -0
  96. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  97. tpu_inference/layers/jax/sample/sampling.py +96 -0
  98. tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
  99. tpu_inference/layers/jax/transformer_block.py +107 -0
  100. tpu_inference/layers/vllm/__init__.py +0 -0
  101. tpu_inference/layers/vllm/attention.py +221 -0
  102. tpu_inference/layers/vllm/fused_moe.py +469 -0
  103. tpu_inference/layers/vllm/linear_common.py +186 -0
  104. tpu_inference/layers/vllm/quantization/__init__.py +39 -0
  105. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  106. tpu_inference/layers/vllm/quantization/common.py +110 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  108. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
  109. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
  110. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  111. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  112. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  113. tpu_inference/layers/vllm/quantization/mxfp4.py +331 -0
  114. tpu_inference/layers/vllm/quantization/unquantized.py +368 -0
  115. tpu_inference/layers/vllm/sharding.py +230 -0
  116. tpu_inference/logger.py +10 -0
  117. tpu_inference/lora/__init__.py +0 -0
  118. tpu_inference/lora/torch_lora_ops.py +103 -0
  119. tpu_inference/lora/torch_punica_tpu.py +310 -0
  120. tpu_inference/models/__init__.py +0 -0
  121. tpu_inference/models/common/__init__.py +0 -0
  122. tpu_inference/models/common/model_loader.py +478 -0
  123. tpu_inference/models/jax/__init__.py +0 -0
  124. tpu_inference/models/jax/deepseek_v3.py +868 -0
  125. tpu_inference/models/jax/gpt_oss.py +492 -0
  126. tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
  127. tpu_inference/models/jax/llama3.py +376 -0
  128. tpu_inference/models/jax/llama4.py +629 -0
  129. tpu_inference/models/jax/llama_eagle3.py +336 -0
  130. tpu_inference/models/jax/llama_guard_4.py +361 -0
  131. tpu_inference/models/jax/qwen2.py +376 -0
  132. tpu_inference/models/jax/qwen2_5_vl.py +1218 -0
  133. tpu_inference/models/jax/qwen3.py +303 -0
  134. tpu_inference/models/jax/utils/__init__.py +0 -0
  135. tpu_inference/models/jax/utils/file_utils.py +96 -0
  136. tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
  137. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  138. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
  139. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
  140. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
  141. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
  142. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
  143. tpu_inference/models/jax/utils/quantization/quantization_utils.py +650 -0
  144. tpu_inference/models/jax/utils/weight_utils.py +584 -0
  145. tpu_inference/models/vllm/__init__.py +0 -0
  146. tpu_inference/models/vllm/vllm_model_wrapper.py +293 -0
  147. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  148. tpu_inference/platforms/__init__.py +2 -0
  149. tpu_inference/platforms/tpu_platform.py +275 -0
  150. tpu_inference/runner/__init__.py +0 -0
  151. tpu_inference/runner/block_table.py +122 -0
  152. tpu_inference/runner/compilation_manager.py +865 -0
  153. tpu_inference/runner/input_batch.py +435 -0
  154. tpu_inference/runner/kv_cache.py +132 -0
  155. tpu_inference/runner/kv_cache_manager.py +478 -0
  156. tpu_inference/runner/lora_utils.py +92 -0
  157. tpu_inference/runner/multimodal_manager.py +217 -0
  158. tpu_inference/runner/persistent_batch_manager.py +282 -0
  159. tpu_inference/runner/speculative_decoding_manager.py +248 -0
  160. tpu_inference/runner/structured_decoding_manager.py +87 -0
  161. tpu_inference/runner/tpu_runner.py +1744 -0
  162. tpu_inference/runner/utils.py +426 -0
  163. tpu_inference/spec_decode/__init__.py +0 -0
  164. tpu_inference/spec_decode/jax/__init__.py +0 -0
  165. tpu_inference/spec_decode/jax/eagle3.py +417 -0
  166. tpu_inference/tpu_info.py +78 -0
  167. tpu_inference/utils.py +340 -0
  168. tpu_inference/worker/__init__.py +0 -0
  169. tpu_inference/worker/tpu_worker.py +458 -0
  170. tpu_inference-0.0.1rc1.dist-info/METADATA +108 -0
  171. tpu_inference-0.0.1rc1.dist-info/RECORD +174 -0
  172. tpu_inference-0.0.1rc1.dist-info/WHEEL +5 -0
  173. tpu_inference-0.0.1rc1.dist-info/licenses/LICENSE +201 -0
  174. tpu_inference-0.0.1rc1.dist-info/top_level.txt +2 -0
tpu_inference/kernels/fused_moe/v1/kernel.py
@@ -0,0 +1,1566 @@
1
+ """TPU-Friendly Fused Mixture of Experts (MoE) kernel."""
2
+
3
+ import functools
4
+
5
+ import jax
6
+ import jax.numpy as jnp
7
+ from jax import lax
8
+ from jax._src import dtypes
9
+ from jax.experimental import pallas as pl
10
+ from jax.experimental.pallas import tpu as pltpu
11
+
12
+ P = jax.sharding.PartitionSpec
13
+
14
+ cdiv = pl.cdiv
15
+
16
+
17
+ def align_to(x, a):
18
+ return cdiv(x, a) * a
19
+
20
+
21
+ def get_dtype_packing(dtype):
22
+ bits = dtypes.bit_width(dtype)
23
+ return 32 // bits
24
+
25
+
26
+ def broadcast_minor(src, shape):
27
+ if src.shape == shape:
28
+ return src
29
+ assert src.shape[:-1] == shape[:-1]
30
+ assert src.shape[-1] % 128 == 0
31
+ target_minor = align_to(shape[-1], src.shape[-1])
32
+ # no-op concatenation.
33
+ return jnp.concatenate([src for _ in range(target_minor // src.shape[-1])],
34
+ axis=-1)[..., :shape[-1]]
35
+
36
+
37
+ def swigluoai(gate: jax.Array,
38
+ up: jax.Array,
39
+ *,
40
+ alpha: float = 1.702,
41
+ limit: float = 7.0) -> jax.Array:
42
+ """Activation used in some models such as GPT-OSS."""
43
+ gate = jnp.clip(gate, a_max=limit)
44
+ up = jnp.clip(up, a_min=-limit, a_max=limit)
45
+ glu = gate * jax.nn.sigmoid(alpha * gate)
46
+ return (up + 1.0) * glu
47
+
48
+
49
+ def activation_fn(acc1, acc3, act_fn):
50
+ if act_fn == "silu":
51
+ return jax.nn.silu(acc1) * acc3
52
+ elif act_fn == "gelu":
53
+ return jax.nn.gelu(acc1) * acc3
54
+ elif act_fn == "swigluoai":
55
+ return swigluoai(acc1, acc3)
56
+ else:
57
+ raise RuntimeError(f"Unsupported activation function: {act_fn}")
58
+
59
+
60
+ def ref_moe(
61
+ tokens: jax.Array, # (num_tokens, hidden_size)
62
+ w1: jax.Array, # (num_experts, 2, hidden_size, intermediate_size)
63
+ w2: jax.Array, # (num_experts, intermediate_size, hidden_size)
64
+ gating_output: jax.Array, # (num_tokens, num_experts)
65
+ top_k: int,
66
+ *,
67
+ renormalize_topk_logits: bool = False,
68
+ activation="silu",
69
+ subc_quant_wsz: int | None = None,
70
+ w1_scale:
71
+ (
72
+ jax.Array | None
73
+ ) = None, # (num_experts, 2, cdiv(hidden_size, subc_quant_wsz), intermediate_size)
74
+ w2_scale:
75
+ (
76
+ jax.Array | None
77
+ ) = None, # (num_experts, cdiv(intermediate_size, subc_quant_wsz), hidden_size)
78
+ b1: jax.Array | None = None, # (num_experts, 2, intermediate_size)
79
+ b2: jax.Array | None = None, # (num_experts, hidden_size)
80
+ ):
81
+ n_tokens = tokens.shape[0] # num_tokens
82
+
83
+ # Compute gating scores for all experts
84
+ gating_logits = jax.nn.softmax(gating_output,
85
+ axis=-1) # [num_tokens, n_experts]
86
+
87
+ # Select top-k experts per token
88
+ top_k_logits, top_k_indices = lax.top_k(
89
+ gating_logits, top_k) # [num_tokens, top_k], [num_tokens, top_k]
90
+
91
+ if renormalize_topk_logits:
92
+ top_k_logits = top_k_logits / jnp.sum(
93
+ top_k_logits, axis=-1, keepdims=True)
94
+
95
+ t_outputs = []
96
+ hidden_size, intermediate_size = w1.shape[-2:]
97
+
98
+ # Process each token individually
99
+ for i in range(n_tokens):
100
+ curr_token = jnp.expand_dims(tokens[i], axis=0) # [1, d_model]
101
+ assigned_expert_ids = top_k_indices[
102
+ i] # [top_k] - indices of selected experts for token i
103
+ tok_expert_act = []
104
+
105
+ # Process each selected expert for the current token
106
+ for expert_id in assigned_expert_ids:
107
+ # Get expert weights
108
+ expert_w1 = w1[expert_id, 0].astype(jnp.float32)
109
+ expert_w3 = w1[expert_id, 1].astype(jnp.float32)
110
+ if w1_scale is not None:
111
+ expert_w1 *= jnp.repeat(w1_scale[expert_id, 0],
112
+ subc_quant_wsz,
113
+ axis=0)[:hidden_size]
114
+ expert_w3 *= jnp.repeat(w1_scale[expert_id, 1],
115
+ subc_quant_wsz,
116
+ axis=0)[:hidden_size]
117
+ expert_weight_1 = jnp.concat(
118
+ [expert_w1, expert_w3],
119
+ axis=-1) # [d_model, 2 * intermediate_size]
120
+ expert_weight_2 = w2[expert_id].astype(
121
+ jnp.float32) # [intermediate_size, d_model]
122
+ if w2_scale is not None:
123
+ expert_weight_2 *= jnp.repeat(w2_scale[expert_id],
124
+ subc_quant_wsz,
125
+ axis=0)[:intermediate_size]
126
+
127
+ # First linear layer with SwiGLU activation
128
+ gmm_1_out = curr_token @ expert_weight_1 # [1, 2 * intermediate_size]
129
+
130
+ # Split into gate and up projections for SwiGLU
131
+ gmm1_w1_proj, gmm1_w3_proj = jnp.split(
132
+ gmm_1_out, 2,
133
+ axis=-1) # [1, intermediate_size], [1, intermediate_size]
134
+ if b1 is not None:
135
+ gmm1_w1_proj += b1[expert_id:expert_id + 1, 0]
136
+ gmm1_w3_proj += b1[expert_id:expert_id + 1, 1]
137
+
138
+ # Apply gated activation: activation(gate) * up
139
+ act = activation_fn(gmm1_w1_proj, gmm1_w3_proj, activation)
140
+
141
+ # Second linear layer (down projection)
142
+ gmm_2_out = act @ expert_weight_2 # [1, d_model]
143
+ if b2 is not None:
144
+ gmm_2_out += b2[expert_id:expert_id + 1]
145
+ tok_expert_act.append(gmm_2_out)
146
+
147
+ # Combine outputs from all selected experts
148
+ experts_act = jnp.concatenate(tok_expert_act,
149
+ axis=0) # [top_k, d_model]
150
+
151
+ # Weighted sum using top-k gating weights
152
+ top_k_weights = top_k_logits[i] # [top_k]
153
+ top_k_weights = jnp.expand_dims(top_k_weights, axis=1) # [top_k, 1]
154
+ weighted_output = jnp.sum(experts_act * top_k_weights,
155
+ axis=0,
156
+ keepdims=True) # [1, d_model]
157
+
158
+ t_outputs.append(weighted_output.astype(tokens.dtype))
159
+
160
+ return jnp.concatenate(t_outputs, axis=0) # [num_tokens, d_model]
161
+
162
+
163
+ def _fused_ep_moe_kernel(
164
+ # Input
165
+ tokens_hbm, # (local_num_tokens, t_packing, hidden_size // t_packing)
166
+ w1_hbm, # (local_num_experts, 2, hidden_size, intermediate_size)
167
+ w2_hbm, # (local_num_experts, intermediate_size, hidden_size)
168
+ # TODO(jevinjiang): We choose F32 scale for easier slicing. The extra
169
+ # latency should be hidden in the pipeline overlapping. But is there a better
170
+ # way to do this?
171
+ w1_scale_hbm, # None | F32(local_num_experts, 2, cdiv(hidden_size, subc_quant_wsz), 1, intermediate_size)
172
+ w2_scale_hbm, # None | F32(local_num_experts, cdiv(intermediate_size, subc_quant_wsz), 1, hidden_size)
173
+ b1_hbm, # None | F32(local_num_experts, 2, 1, intermediate_size)
174
+ b2_hbm, # None | F32(local_num_experts, 1, hidden_size)
175
+ gating_hbm, # (local_num_tokens, padded_num_experts)
176
+ a2a_g_hbm, # (num_experts, bt, t_packing, hidden_size // t_packing)
177
+ # Output
178
+ output_hbm, # (local_num_tokens, hidden_size)
179
+ # Scratch
180
+ t2e_routing_x2_smem, # <bt_sem_id> (2, bt, padded_num_experts)
181
+ d2e_count_x2_smem, # <bt_sem_id> (2, num_devices, 1, padded_num_experts)
182
+ expert_offsets_x2_smem, # <bt_sem_id> (2, 2, padded_num_experts): for a2a_s and a2a_g
183
+ expert_starts_x2_smem, # <bt_sem_id> (2, 1, padded_num_experts)
184
+ expert_sizes_x2_smem, # <bt_sem_id> (2, 1, padded_num_experts)
185
+ a2a_s_sends_x2_smem, # <e_sem_id> (2,)
186
+ a2a_s_x2_vmem, # <e_sem_id> (2, bt * num_devices, t_packing, hidden_size // t_packing)
187
+ a2a_s_acc_x2_vmem, # <e_sem_id> (2, bt * num_devices, t_packing, hidden_size // t_packing)
188
+ ### Accumulation for gathered tokens:
189
+ a2a_g_acc_vmem, # (top_k, bt, t_packing, hidden_size // t_packing)
190
+ ### Expert weight double buffering:
191
+ b_gating_x2_vmem, # <bt_sem_id> (2, bt, padded_num_experts)
192
+ b_output_x2_vmem, # <bt_sem_id> (2, bt, hidden_size)
193
+ b_w1_x2_vmem, # <bw_sem_id> (2, t_packing, bd1 // t_packing, bf)
194
+ b_w3_x2_vmem, # <bw_sem_id> (2, t_packing, bd1 // t_packing, bf)
195
+ b_w2_x2_vmem, # <bw_sem_id> (2, t_packing, bf, bd2 // t_packing)
196
+ b_w1_scale_x2_vmem, # None | <bw_sem_id> (2, t_packing, bd1 // t_packing // subc_quant_wsz, 1, bf)
197
+ b_w3_scale_x2_vmem, # None | <bw_sem_id> (2, t_packing, bd1 // t_packing // subc_quant_wsz, 1, bf)
198
+ b_w2_scale_x2_vmem, # None | <bw_sem_id> (2, t_packing, bf // subc_quant_wsz, 1, bd2 // t_packing)
199
+ b_b1_x2_vmem, # None | <bw_sem_id> (2, 1, bf)
200
+ b_b3_x2_vmem, # None | <bw_sem_id> (2, 1, bf)
201
+ b_b2_x2_vmem, # None | <bw_sem_id> (2, t_packing, 1, bd2 // t_packing)
202
+ b_acc_vmem, # F32(bt * num_devices, 1, bf * 2)
203
+ ### Semaphores:
204
+ local_sems, # (2, 5): 2 x [b_gating_sem, b_w1_sem, b_w2_sem, b_w3_sem, b_output_sem]
205
+ send_sems, # <e_sem_id> (2,)
206
+ recv_sems, # <e_sem_id> (2,)
207
+ a2a_gather_sem,
208
+ a2a_acc_sem,
209
+ *,
210
+ top_k: int,
211
+ renormalize_topk_logits: bool,
212
+ ep_axis_name: str,
213
+ act_fn: str,
214
+ subc_quant_wsz: int | None = None,
215
+ # Kernel tuning params.
216
+ bt: int, # Block size of local_num_tokens.
217
+ bf: int, # Block size of intermediate_size.
218
+ bd1: int, # Block size of hidden_size in w1.
219
+ bd2: int, # Block size of hidden_size in w2.
220
+ btc: int, # Compute size of block tokens for active expert.
221
+ bfc: int, # Compute size of block intermediate_size.
222
+ bd1c: int, # Compute size of block hidden_size.
223
+ bd2c: int, # Compute size of block hidden_size.
224
+ ):
225
+ my_id = lax.axis_index(ep_axis_name)
226
+ num_devices = lax.axis_size(ep_axis_name)
227
+ local_num_tokens = tokens_hbm.shape[0]
228
+ local_num_experts, intermediate_size, hidden_size = w2_hbm.shape
229
+ right_id = (my_id + 1) % num_devices
230
+
231
+ t_dtype = tokens_hbm.dtype
232
+ t_packing = get_dtype_packing(t_dtype)
233
+ t_bitwidth = 32 // t_packing
234
+ assert a2a_g_hbm.dtype == t_dtype
235
+ assert w1_hbm.dtype == w2_hbm.dtype
236
+
237
+ assert bd1 % bd1c == 0
238
+ assert bd2 % bd2c == 0
239
+ assert bf % bfc == 0
240
+ assert hidden_size % t_packing == 0
241
+ assert bd1 % t_packing == 0
242
+ assert bd2 % t_packing == 0
243
+ assert bd1c % t_packing == 0
244
+ assert bd2c % t_packing == 0
245
+
246
+ h_per_t_packing = hidden_size // t_packing
247
+ assert tokens_hbm.shape[-1] == h_per_t_packing
248
+ bd1_per_t_packing = bd1 // t_packing
249
+ bd2_per_t_packing = bd2 // t_packing
250
+ bd1c_per_t_packing = bd1c // t_packing
251
+ bd2c_per_t_packing = bd2c // t_packing
252
+
253
+ if subc_quant_wsz is not None:
254
+ assert subc_quant_wsz % 256 == 0
255
+ assert bd1c_per_t_packing == subc_quant_wsz
256
+ assert bfc == subc_quant_wsz
257
+ assert bd1 % subc_quant_wsz == 0
258
+ assert bf % subc_quant_wsz == 0
259
+ assert bd1_per_t_packing % subc_quant_wsz == 0
260
+ assert h_per_t_packing % subc_quant_wsz == 0
261
+
262
+ num_bt = cdiv(local_num_tokens, bt)
263
+ num_bf = cdiv(intermediate_size, bf)
264
+ num_bd1 = cdiv(hidden_size, bd1)
265
+ num_bd2 = cdiv(hidden_size, bd2)
266
+
267
+ def get_mesh_device_id(ep_rank):
268
+ dp_rank = jax.lax.axis_index("data")
269
+ return (dp_rank, ep_rank)
270
+
271
+ def sync_barrier():
272
+ barrier_sem = pltpu.get_barrier_semaphore()
273
+ pltpu.semaphore_signal(
274
+ barrier_sem,
275
+ device_id=get_mesh_device_id(right_id),
276
+ device_id_type=pltpu.DeviceIdType.MESH,
277
+ )
278
+ pltpu.semaphore_wait(barrier_sem, 1)
279
+
280
+ def start_fetch_b_gating(bt_id, priority=0):
281
+ is_valid = jnp.logical_and(0 <= bt_id, bt_id < num_bt)
282
+ sz = pl.multiple_of(lax.select(is_valid, bt, 0), bt)
283
+ bt_sem_id = (bt_id + 2) % 2
284
+ b_gating_sem = local_sems.at[bt_sem_id, 0]
285
+ pltpu.make_async_copy(
286
+ src_ref=gating_hbm.at[pl.ds(bt_id * bt, sz)],
287
+ dst_ref=b_gating_x2_vmem.at[bt_sem_id, pl.ds(0, sz)],
288
+ sem=b_gating_sem,
289
+ ).start(priority=priority)
290
+
291
+ def wait_fetch_b_gating(bt_id):
292
+ bt_sem_id = bt_id % 2
293
+ b_gating_sem = local_sems.at[bt_sem_id, 0]
294
+ pltpu.make_async_copy(
295
+ src_ref=b_gating_x2_vmem.at[bt_sem_id],
296
+ dst_ref=b_gating_x2_vmem.at[bt_sem_id],
297
+ sem=b_gating_sem,
298
+ ).wait()
299
+
300
+ def get_top_k(input, top_k, renormalize_topk_logits):
301
+ assert len(input.shape) == 2, input.shape
302
+ input = input.astype(jnp.float32)
303
+ top_k_logits_lst = []
304
+ top_k_indices_lst = []
305
+ t2e = jnp.zeros(input.shape, dtype=jnp.int32)
306
+ t2e_routing = jnp.zeros(input.shape, dtype=jnp.int32)
307
+ iota = jax.lax.broadcasted_iota(jnp.int32, input.shape, 1)
308
+ top_k_logits_sum = jnp.zeros((input.shape[0], 128), jnp.float32)
309
+
310
+ for k_id in range(top_k):
311
+ # TODO(jevinjiang): return both top_k values and indices in Mosaic
312
+ top_k_logits = jnp.broadcast_to(
313
+ jnp.max(input, axis=1, keepdims=True),
314
+ (input.shape[0], 128)).astype(input.dtype)
315
+ if renormalize_topk_logits:
316
+ top_k_logits_sum += top_k_logits
317
+ top_k_logits_lst.append(top_k_logits)
318
+ # TODO(jevinjiang): support bf16 argmax in Mosaic
319
+ top_k_indices = jnp.broadcast_to(
320
+ jnp.argmax(input, axis=1, keepdims=True), input.shape)
321
+ top_k_indices_lst.append(top_k_indices)
322
+ t2e_routing = jnp.where(iota == k_id, top_k_indices, t2e_routing)
323
+ mask = iota == top_k_indices
324
+ t2e += mask.astype(jnp.int32)
325
+ if k_id != top_k - 1:
326
+ input = jnp.where(mask, -jnp.inf, input)
327
+
328
+ if renormalize_topk_logits:
329
+ for k_id in range(top_k):
330
+ top_k_logits_lst[
331
+ k_id] = top_k_logits_lst[k_id] / top_k_logits_sum
332
+
333
+ expert_sizes = jnp.sum(t2e, axis=0, keepdims=True)
334
+ expert_starts = jnp.zeros_like(expert_sizes)
335
+ return top_k_logits_lst, t2e_routing, expert_sizes, expert_starts
336
+
337
+ def all_reduce_metadata(bt_sem_id, t2e_routing, starts, sizes):
338
+ send_sem = send_sems.at[0]
339
+ recv_sem = recv_sems.at[0]
340
+
341
+ # All-reduce to accumulate starts and sizes and transfer to SMEM.
342
+ def _all_reduce_metadata(
343
+ t2e_routing_vmem,
344
+ d2e_count_vmem,
345
+ offsets_vmem,
346
+ starts_vmem,
347
+ sizes_vmem,
348
+ ):
349
+ offsets_vmem[...] = jnp.zeros_like(offsets_vmem)
350
+ # TODO(jevinjiang): check how slow is VMEM -> SMEM.
351
+ offsets_copy = pltpu.async_copy(
352
+ src_ref=offsets_vmem,
353
+ dst_ref=expert_offsets_x2_smem.at[bt_sem_id],
354
+ sem=send_sem,
355
+ )
356
+ t2e_routing_vmem[...] = t2e_routing
357
+ t2e_routing_copy = pltpu.async_copy(
358
+ src_ref=t2e_routing_vmem,
359
+ dst_ref=t2e_routing_x2_smem.at[bt_sem_id],
360
+ sem=send_sem,
361
+ )
362
+ reduced_sizes = sizes
363
+ reduced_starts = starts
364
+ row_id = my_id
365
+ d2e_count_vmem[row_id] = sizes
366
+ for i in range(num_devices - 1):
367
+ sync_barrier()
368
+ # TODO(jevinjiang): we can use double buffering to improve AR if needed.
369
+ pltpu.async_remote_copy(
370
+ src_ref=d2e_count_vmem.at[row_id],
371
+ dst_ref=d2e_count_vmem.at[row_id],
372
+ send_sem=send_sem,
373
+ recv_sem=recv_sem,
374
+ device_id=get_mesh_device_id(right_id),
375
+ device_id_type=pltpu.DeviceIdType.MESH,
376
+ ).wait()
377
+ row_id = (row_id + num_devices - 1) % num_devices
378
+ new_sizes = d2e_count_vmem[row_id]
379
+ reduced_sizes += new_sizes
380
+ reduced_starts += lax.select(my_id > i, new_sizes,
381
+ jnp.zeros_like(new_sizes))
382
+ starts_vmem[...] = reduced_starts
383
+ sizes_vmem[...] = reduced_sizes
384
+
385
+ starts_copy = pltpu.async_copy(
386
+ src_ref=starts_vmem,
387
+ dst_ref=expert_starts_x2_smem.at[bt_sem_id],
388
+ sem=send_sem,
389
+ )
390
+ sizes_copy = pltpu.async_copy(
391
+ src_ref=sizes_vmem,
392
+ dst_ref=expert_sizes_x2_smem.at[bt_sem_id],
393
+ sem=send_sem,
394
+ )
395
+
396
+ # TODO(jevinjiang): if d2e_count is too big, we can store in HBM and fetch
397
+ # to SMEM partially.
398
+ d2e_count_copy = pltpu.async_copy(
399
+ src_ref=d2e_count_vmem,
400
+ dst_ref=d2e_count_x2_smem.at[bt_sem_id],
401
+ sem=send_sem,
402
+ )
403
+
404
+ t2e_routing_copy.wait()
405
+ d2e_count_copy.wait()
406
+ offsets_copy.wait()
407
+ starts_copy.wait()
408
+ sizes_copy.wait()
409
+
410
+ pl.run_scoped(
411
+ _all_reduce_metadata,
412
+ pltpu.VMEM(t2e_routing_x2_smem.shape[1:],
413
+ t2e_routing_x2_smem.dtype),
414
+ pltpu.VMEM(d2e_count_x2_smem.shape[1:], d2e_count_x2_smem.dtype),
415
+ pltpu.VMEM(expert_offsets_x2_smem.shape[1:],
416
+ expert_offsets_x2_smem.dtype),
417
+ pltpu.VMEM(expert_starts_x2_smem.shape[1:],
418
+ expert_starts_x2_smem.dtype),
419
+ pltpu.VMEM(expert_sizes_x2_smem.shape[1:],
420
+ expert_sizes_x2_smem.dtype),
421
+ )
422
+
423
+ def start_a2a_scatter(bt_id, e_sem_id, local_e_id):
424
+ bt_sem_id = bt_id % 2
425
+
426
+ # Counting the number of remote sends from the current device.
427
+ send_sz = 0
428
+ for bt_t_id in range(bt):
429
+ for k_id in range(top_k):
430
+ e_id = t2e_routing_x2_smem[bt_sem_id, bt_t_id, k_id]
431
+ is_active_expert = e_id % local_num_experts == local_e_id
432
+ recv_id = e_id // local_num_experts
433
+ offset = expert_offsets_x2_smem[bt_sem_id, 0, e_id]
434
+ sz = lax.select(is_active_expert, 1, 0)
435
+ is_local = recv_id == my_id
436
+ local_sz = lax.select(is_local, sz, 0)
437
+ remote_sz = lax.select(is_local, 0, sz)
438
+ send_sz += remote_sz
439
+ expert_offsets_x2_smem[bt_sem_id, 0,
440
+ e_id] = (offset + local_sz + remote_sz)
441
+ start = expert_starts_x2_smem[bt_sem_id, 0, e_id] + offset
442
+ t_id = bt * bt_id + bt_t_id
443
+ # TODO(jevinjiang): compare the perf when using branches.
444
+ pltpu.make_async_copy(
445
+ src_ref=tokens_hbm.at[pl.ds(t_id, local_sz)],
446
+ dst_ref=a2a_s_x2_vmem.at[e_sem_id,
447
+ pl.ds(start, local_sz)],
448
+ sem=recv_sems.at[e_sem_id],
449
+ ).start()
450
+ pltpu.make_async_remote_copy(
451
+ src_ref=tokens_hbm.at[pl.ds(t_id, remote_sz)],
452
+ dst_ref=a2a_s_x2_vmem.at[e_sem_id,
453
+ pl.ds(start, remote_sz)],
454
+ send_sem=send_sems.at[e_sem_id],
455
+ recv_sem=recv_sems.at[e_sem_id],
456
+ device_id=get_mesh_device_id(recv_id),
457
+ device_id_type=pltpu.DeviceIdType.MESH,
458
+ ).start()
459
+ a2a_s_sends_x2_smem[e_sem_id] = send_sz
460
+
461
+ def wait_a2a_scatter_recv(bt_id, e_sem_id, local_e_id):
462
+ bt_sem_id = bt_id % 2
463
+ e_id = my_id * local_num_experts + local_e_id
464
+ sz = expert_sizes_x2_smem[bt_sem_id, 0, e_id]
465
+ pltpu.make_async_copy(
466
+ src_ref=a2a_s_x2_vmem.at[e_sem_id, pl.ds(0, sz)],
467
+ dst_ref=a2a_s_x2_vmem.at[e_sem_id, pl.ds(0, sz)],
468
+ sem=recv_sems.at[e_sem_id],
469
+ ).wait()
470
+
471
+ def wait_a2a_scatter_send(bt_id, e_sem_id, local_e_id):
472
+ del bt_id, local_e_id
473
+ sz = a2a_s_sends_x2_smem[e_sem_id]
474
+ pltpu.make_async_copy(
475
+ src_ref=a2a_s_x2_vmem.at[e_sem_id, pl.ds(0, sz)],
476
+ dst_ref=a2a_s_x2_vmem.at[e_sem_id, pl.ds(0, sz)],
477
+ sem=send_sems.at[e_sem_id],
478
+ ).wait()
479
+
480
+ def start_a2a_gather(bt_id, e_sem_id, local_e_id):
481
+ my_e_id = my_id * local_num_experts + local_e_id
482
+ bt_sem_id = bt_id % 2
483
+ start = 0
484
+ for recv_id in range(num_devices):
485
+ sz = d2e_count_x2_smem[bt_sem_id, recv_id, 0, my_e_id]
486
+ is_local = recv_id == my_id
487
+ local_sz = lax.select(is_local, sz, 0)
488
+ remote_sz = lax.select(is_local, 0, sz)
489
+ pltpu.make_async_copy(
490
+ src_ref=a2a_s_acc_x2_vmem.at[e_sem_id,
491
+ pl.ds(start, local_sz)],
492
+ dst_ref=a2a_g_hbm.at[my_e_id, pl.ds(0, local_sz)],
493
+ sem=a2a_gather_sem,
494
+ ).start()
495
+ pltpu.make_async_remote_copy(
496
+ src_ref=a2a_s_acc_x2_vmem.at[e_sem_id,
497
+ pl.ds(start, remote_sz)],
498
+ dst_ref=a2a_g_hbm.at[my_e_id, pl.ds(0, remote_sz)],
499
+ send_sem=send_sems.at[e_sem_id],
500
+ recv_sem=a2a_gather_sem,
501
+ device_id=get_mesh_device_id(recv_id),
502
+ device_id_type=pltpu.DeviceIdType.MESH,
503
+ ).start()
504
+ start += sz
505
+
506
+ def wait_a2a_gather_send(bt_id, e_sem_id, local_e_id):
507
+ my_e_id = my_id * local_num_experts + local_e_id
508
+ bt_sem_id = bt_id % 2
509
+ sz = expert_sizes_x2_smem[bt_sem_id, 0, my_e_id]
510
+ local_sz = d2e_count_x2_smem[bt_sem_id, my_id, 0, my_e_id]
511
+ remote_sz = sz - local_sz
512
+ is_valid = jnp.logical_and(0 <= local_e_id, local_e_id
513
+ < local_num_experts)
514
+ remote_sz = lax.select(is_valid, remote_sz, 0)
515
+ pltpu.make_async_copy(
516
+ src_ref=a2a_g_hbm.at[0, pl.ds(0, remote_sz)],
517
+ dst_ref=a2a_g_hbm.at[0, pl.ds(0, remote_sz)],
518
+ sem=send_sems.at[e_sem_id],
519
+ ).wait()
520
+
521
+ def wait_a2a_gather_recv_all():
522
+ sz = top_k * bt
523
+ pltpu.make_async_copy(
524
+ src_ref=a2a_g_hbm.at[0, pl.ds(0, sz)],
525
+ dst_ref=a2a_g_hbm.at[0, pl.ds(0, sz)],
526
+ sem=a2a_gather_sem,
527
+ ).wait()
528
+
529
+ def start_fetch_bw1(local_e_id, bw1_sem_id, bf_id, bd1_id):
530
+ for p in range(t_packing):
531
+ offset = p * h_per_t_packing + bd1_id * bd1_per_t_packing
532
+ pltpu.make_async_copy(
533
+ src_ref=w1_hbm.at[
534
+ local_e_id,
535
+ 0,
536
+ pl.ds(offset, bd1_per_t_packing),
537
+ pl.ds(bf_id * bf, bf),
538
+ ],
539
+ dst_ref=b_w1_x2_vmem.at[bw1_sem_id, p],
540
+ sem=local_sems.at[bw1_sem_id, 1],
541
+ ).start()
542
+ if w1_scale_hbm is not None:
543
+ assert subc_quant_wsz is not None
544
+ pltpu.make_async_copy(
545
+ src_ref=w1_scale_hbm.at[
546
+ local_e_id,
547
+ 0,
548
+ pl.ds(
549
+ offset // subc_quant_wsz,
550
+ bd1_per_t_packing // subc_quant_wsz,
551
+ ),
552
+ pl.ds(0, 1),
553
+ pl.ds(bf_id * bf, bf),
554
+ ],
555
+ dst_ref=b_w1_scale_x2_vmem.at[bw1_sem_id, p],
556
+ sem=local_sems.at[bw1_sem_id, 1],
557
+ ).start()
558
+ if b1_hbm is not None and bd1_id == 0:
559
+ pltpu.make_async_copy(
560
+ src_ref=b1_hbm.at[local_e_id, 0,
561
+ pl.ds(0, 1),
562
+ pl.ds(bf_id * bf, bf)],
563
+ dst_ref=b_b1_x2_vmem.at[bf_id % 2],
564
+ sem=local_sems.at[bw1_sem_id, 1],
565
+ ).start()
566
+
567
+ def start_fetch_bw2(local_e_id, bw2_sem_id, bf_id, bd2_id):
568
+ for p in range(t_packing):
569
+ offset = p * h_per_t_packing + bd2_id * bd2_per_t_packing
570
+ pltpu.make_async_copy(
571
+ src_ref=w2_hbm.at[
572
+ local_e_id,
573
+ pl.ds(bf_id * bf, bf),
574
+ pl.ds(offset, bd2_per_t_packing),
575
+ ],
576
+ dst_ref=b_w2_x2_vmem.at[bw2_sem_id, p],
577
+ sem=local_sems.at[bw2_sem_id, 2],
578
+ ).start()
579
+ if w2_scale_hbm is not None:
580
+ assert subc_quant_wsz is not None
581
+ pltpu.make_async_copy(
582
+ src_ref=w2_scale_hbm.at[
583
+ local_e_id,
584
+ pl.ds(bf_id * bf // subc_quant_wsz, bf //
585
+ subc_quant_wsz),
586
+ pl.ds(0, 1),
587
+ pl.ds(offset, bd2_per_t_packing),
588
+ ],
589
+ dst_ref=b_w2_scale_x2_vmem.at[bw2_sem_id, p],
590
+ sem=local_sems.at[bw2_sem_id, 2],
591
+ ).start()
592
+ if b2_hbm is not None and bf_id == 0:
593
+ pltpu.make_async_copy(
594
+ src_ref=b2_hbm.at[local_e_id,
595
+ pl.ds(0, 1),
596
+ pl.ds(offset, bd2_per_t_packing)],
597
+ dst_ref=b_b2_x2_vmem.at[bd2_id % 2, p],
598
+ sem=local_sems.at[bw2_sem_id, 2],
599
+ ).start()
600
+
601
+ def start_fetch_bw3(local_e_id, bw3_sem_id, bf_id, bd3_id):
602
+ for p in range(t_packing):
603
+ offset = p * h_per_t_packing + bd3_id * bd1_per_t_packing
604
+ pltpu.make_async_copy(
605
+ src_ref=w1_hbm.at[
606
+ local_e_id,
607
+ 1,
608
+ pl.ds(offset, bd1_per_t_packing),
609
+ pl.ds(bf_id * bf, bf),
610
+ ],
611
+ dst_ref=b_w3_x2_vmem.at[bw3_sem_id, p],
612
+ sem=local_sems.at[bw3_sem_id, 3],
613
+ ).start()
614
+ if w1_scale_hbm is not None:
615
+ assert subc_quant_wsz is not None
616
+ pltpu.make_async_copy(
617
+ src_ref=w1_scale_hbm.at[
618
+ local_e_id,
619
+ 1,
620
+ pl.ds(
621
+ offset // subc_quant_wsz,
622
+ bd1_per_t_packing // subc_quant_wsz,
623
+ ),
624
+ pl.ds(0, 1),
625
+ pl.ds(bf_id * bf, bf),
626
+ ],
627
+ dst_ref=b_w3_scale_x2_vmem.at[bw3_sem_id, p],
628
+ sem=local_sems.at[bw3_sem_id, 3],
629
+ ).start()
630
+ if b1_hbm is not None and bd3_id == 0:
631
+ pltpu.make_async_copy(
632
+ src_ref=b1_hbm.at[local_e_id, 1,
633
+ pl.ds(0, 1),
634
+ pl.ds(bf_id * bf, bf)],
635
+ dst_ref=b_b3_x2_vmem.at[bf_id % 2],
636
+ sem=local_sems.at[bw3_sem_id, 3],
637
+ ).start()
638
+
639
+ def wait_fetch_bw1(local_e_id, bw1_sem_id, bf_id, bd1_id):
640
+ del local_e_id
641
+ pltpu.make_async_copy(
642
+ src_ref=b_w1_x2_vmem.at[bw1_sem_id],
643
+ dst_ref=b_w1_x2_vmem.at[bw1_sem_id],
644
+ sem=local_sems.at[bw1_sem_id, 1],
645
+ ).wait()
646
+ if w1_scale_hbm is not None:
647
+ pltpu.make_async_copy(
648
+ src_ref=b_w1_scale_x2_vmem.at[bw1_sem_id],
649
+ dst_ref=b_w1_scale_x2_vmem.at[bw1_sem_id],
650
+ sem=local_sems.at[bw1_sem_id, 1],
651
+ ).wait()
652
+ if b1_hbm is not None and bd1_id == 0:
653
+ pltpu.make_async_copy(
654
+ src_ref=b_b1_x2_vmem.at[bf_id % 2],
655
+ dst_ref=b_b1_x2_vmem.at[bf_id % 2],
656
+ sem=local_sems.at[bw1_sem_id, 1],
657
+ ).wait()
658
+
659
+ def wait_fetch_bw2(local_e_id, bw2_sem_id, bf_id, bd2_id):
660
+ del local_e_id
661
+ pltpu.make_async_copy(
662
+ src_ref=b_w2_x2_vmem.at[bw2_sem_id],
663
+ dst_ref=b_w2_x2_vmem.at[bw2_sem_id],
664
+ sem=local_sems.at[bw2_sem_id, 2],
665
+ ).wait()
666
+ if w2_scale_hbm is not None:
667
+ pltpu.make_async_copy(
668
+ src_ref=b_w2_scale_x2_vmem.at[bw2_sem_id],
669
+ dst_ref=b_w2_scale_x2_vmem.at[bw2_sem_id],
670
+ sem=local_sems.at[bw2_sem_id, 2],
671
+ ).wait()
672
+ if b2_hbm is not None and bf_id == 0:
673
+ pltpu.make_async_copy(
674
+ src_ref=b_b2_x2_vmem.at[bd2_id % 2],
675
+ dst_ref=b_b2_x2_vmem.at[bd2_id % 2],
676
+ sem=local_sems.at[bw2_sem_id, 2],
677
+ ).wait()
678
+
679
+ def wait_fetch_bw3(local_e_id, bw3_sem_id, bf_id, bd3_id):
680
+ del local_e_id
681
+ pltpu.make_async_copy(
682
+ src_ref=b_w3_x2_vmem.at[bw3_sem_id],
683
+ dst_ref=b_w3_x2_vmem.at[bw3_sem_id],
684
+ sem=local_sems.at[bw3_sem_id, 3],
685
+ ).wait()
686
+ if w1_scale_hbm is not None:
687
+ pltpu.make_async_copy(
688
+ src_ref=b_w3_scale_x2_vmem.at[bw3_sem_id],
689
+ dst_ref=b_w3_scale_x2_vmem.at[bw3_sem_id],
690
+ sem=local_sems.at[bw3_sem_id, 3],
691
+ ).wait()
692
+ if b1_hbm is not None and bd3_id == 0:
693
+ pltpu.make_async_copy(
694
+ src_ref=b_b3_x2_vmem.at[bf_id % 2],
695
+ dst_ref=b_b3_x2_vmem.at[bf_id % 2],
696
+ sem=local_sems.at[bw3_sem_id, 3],
697
+ ).wait()
698
+
699
+ def start_fetch_next_bw(local_e_id, bw_sem_id, bf_id, bd1_id, bd2_id):
700
+ next_bd1_id = bd1_id + 1
701
+ next_bd2_id = bd2_id + 1
702
+ next_sem_id = (bw_sem_id + 1) % 2
703
+
704
+ if bf_id >= num_bf:
705
+ return
706
+ if next_bd1_id < num_bd1:
707
+ start_fetch_bw1(local_e_id, next_sem_id, bf_id, next_bd1_id)
708
+ start_fetch_bw3(local_e_id, next_sem_id, bf_id, next_bd1_id)
709
+ elif next_bd1_id == num_bd1:
710
+ start_fetch_bw2(local_e_id, next_sem_id, bf_id, 0)
711
+ elif next_bd2_id < num_bd2:
712
+ start_fetch_bw2(local_e_id, next_sem_id, bf_id, next_bd2_id)
713
+ elif next_bd2_id == num_bd2:
714
+ start_fetch_next_bw(local_e_id, bw_sem_id, bf_id + 1, -1, -1)
715
+ else:
716
+ raise RuntimeError("Unreachable")
717
+
718
+ def dynamic_ffn1(
719
+ t_b32_vmem,
720
+ w1_vmem,
721
+ w1_scale_vmem,
722
+ b1_vmem,
723
+ w3_vmem,
724
+ w3_scale_vmem,
725
+ b3_vmem,
726
+ acc1_vmem,
727
+ acc3_vmem,
728
+ dyn_sz,
729
+ should_init,
730
+ ):
731
+ assert t_b32_vmem.shape == (bt * num_devices, bd1 // t_packing)
732
+ assert w1_vmem.shape == w3_vmem.shape == (t_packing, bd1_per_t_packing,
733
+ bf)
734
+ assert acc1_vmem.shape == acc3_vmem.shape == (bt * num_devices, bf)
735
+ assert bd1 % (t_packing * 128) == 0, (bd1, t_packing)
736
+ assert bd1c % (t_packing * 128) == 0, (bd1c, t_packing)
737
+ if w1_scale_vmem is not None:
738
+ assert w1_scale_vmem.shape == (
739
+ t_packing,
740
+ bd1_per_t_packing // subc_quant_wsz,
741
+ 1,
742
+ bf,
743
+ )
744
+ assert bd1c_per_t_packing == subc_quant_wsz
745
+ if w3_scale_vmem is not None:
746
+ assert w3_scale_vmem.shape == (
747
+ t_packing,
748
+ bd1_per_t_packing // subc_quant_wsz,
749
+ 1,
750
+ bf,
751
+ )
752
+ assert bd1c_per_t_packing == subc_quant_wsz
753
+
754
+ num_loops = cdiv(dyn_sz, btc)
755
+ repack_ty = jnp.dtype(f"int{t_bitwidth}")
756
+
757
+ def body(btc_id, _):
758
+ for bd1c_id in range(cdiv(bd1, bd1c)):
759
+ t_b32 = t_b32_vmem[
760
+ pl.ds(btc_id * btc, btc),
761
+ pl.ds(bd1c_id * bd1c_per_t_packing, bd1c_per_t_packing),
762
+ ]
763
+ for p_id in range(t_packing):
764
+ t = pltpu.bitcast(t_b32.astype(repack_ty), t_dtype)
765
+ t_b32 = t_b32 >> t_bitwidth
766
+ for bfc_id in range(cdiv(bf, bfc)):
767
+ w_slices = (
768
+ p_id,
769
+ pl.ds(bd1c_id * bd1c_per_t_packing,
770
+ bd1c_per_t_packing),
771
+ pl.ds(bfc_id * bfc, bfc),
772
+ )
773
+ w1 = w1_vmem[*w_slices]
774
+ acc1 = jnp.dot(t,
775
+ w1,
776
+ preferred_element_type=jnp.float32)
777
+
778
+ if w1_scale_vmem is not None:
779
+ w1_scale_slices = (
780
+ p_id,
781
+ bd1c_id,
782
+ pl.ds(0, 1),
783
+ pl.ds(bfc_id * bfc, bfc),
784
+ )
785
+ # TODO(jevinjiang): can use mosaic to load with stride 0.
786
+ w1_scale = jnp.broadcast_to(
787
+ w1_scale_vmem[*w1_scale_slices], acc1.shape)
788
+ acc1 *= w1_scale
789
+
790
+ w3 = w3_vmem[*w_slices]
791
+
792
+ acc3 = jnp.dot(t,
793
+ w3,
794
+ preferred_element_type=jnp.float32)
795
+
796
+ if w3_scale_vmem is not None:
797
+ w3_scale_slices = (
798
+ p_id,
799
+ bd1c_id,
800
+ pl.ds(0, 1),
801
+ pl.ds(bfc_id * bfc, bfc),
802
+ )
803
+ w3_scale = jnp.broadcast_to(
804
+ w3_scale_vmem[*w3_scale_slices], acc3.shape)
805
+ acc3 *= w3_scale
806
+
807
+ acc_slices = (pl.ds(btc_id * btc,
808
+ btc), pl.ds(bfc_id * bfc, bfc))
809
+ if should_init and p_id == bd1c_id == 0:
810
+ if b1_vmem is not None:
811
+ b1_scale_slices = (
812
+ pl.ds(0, 1),
813
+ pl.ds(bfc_id * bfc, bfc),
814
+ )
815
+ b1 = jnp.broadcast_to(
816
+ b1_vmem[*b1_scale_slices], acc1.shape)
817
+ acc1 += b1
818
+ if b3_vmem is not None:
819
+ b3_scale_slices = (
820
+ pl.ds(0, 1),
821
+ pl.ds(bfc_id * bfc, bfc),
822
+ )
823
+ b3 = jnp.broadcast_to(
824
+ b3_vmem[*b3_scale_slices], acc1.shape)
825
+ acc3 += b3
826
+
827
+ acc1_vmem[*acc_slices] = acc1
828
+ acc3_vmem[*acc_slices] = acc3
829
+ else:
830
+ acc1_vmem[*acc_slices] += acc1
831
+ acc3_vmem[*acc_slices] += acc3
832
+
833
+ lax.fori_loop(0, num_loops, body, None)
834
+
835
+ def dynamic_ffn2(
836
+ acc1_vmem,
837
+ acc3_vmem,
838
+ w2_vmem,
839
+ w2_scale_vmem,
840
+ b2_vmem,
841
+ res_b32_vmem,
842
+ dyn_sz,
843
+ should_init,
844
+ ):
845
+ assert res_b32_vmem.shape == (bt * num_devices, bd2_per_t_packing)
846
+ assert w2_vmem.shape == (t_packing, bf, bd2_per_t_packing)
847
+ assert acc1_vmem.shape == acc3_vmem.shape == (bt * num_devices, bf)
848
+ assert bd2 % (t_packing * 128) == 0, (bd2, t_packing)
849
+ assert bd2c % (t_packing * 128) == 0, (bd2c, t_packing)
850
+ assert t_dtype in (jnp.float32, jnp.bfloat16)
851
+
852
+ if w2_scale_vmem is not None:
853
+ assert w2_scale_vmem.shape == (
854
+ t_packing,
855
+ bf // subc_quant_wsz,
856
+ 1,
857
+ bd2_per_t_packing,
858
+ )
859
+ assert bfc == subc_quant_wsz
860
+
861
+ num_loops = cdiv(dyn_sz, btc)
862
+ assert bd2c % (t_packing * 128) == 0, (bd2c, t_packing)
863
+
864
+ def body(btc_id, _):
865
+ for bd2c_id in range(cdiv(bd2, bd2c)):
866
+ res_lst = []
867
+ for p_id in range(t_packing):
868
+ res = jnp.zeros((btc, bd2c_per_t_packing),
869
+ dtype=jnp.float32)
870
+
871
+ if b2_vmem is not None and should_init:
872
+ b2_scale_slices = (
873
+ p_id,
874
+ pl.ds(0, 1),
875
+ pl.ds(bd2c_id * bd2c_per_t_packing,
876
+ bd2c_per_t_packing),
877
+ )
878
+ b2 = jnp.broadcast_to(b2_vmem[*b2_scale_slices],
879
+ res.shape)
880
+ res += b2
881
+
882
+ for bfc_id in range(cdiv(bf, bfc)):
883
+ acc_slices = (pl.ds(btc_id * btc,
884
+ btc), pl.ds(bfc_id * bfc, bfc))
885
+ acc1 = acc1_vmem[*acc_slices]
886
+ acc3 = acc3_vmem[*acc_slices]
887
+ act = activation_fn(acc1, acc3, act_fn)
888
+ w2 = w2_vmem[
889
+ p_id,
890
+ pl.ds(bfc_id * bfc, bfc),
891
+ pl.ds(bd2c_id *
892
+ bd2c_per_t_packing, bd2c_per_t_packing),
893
+ ]
894
+ acc = jnp.dot(act,
895
+ w2,
896
+ preferred_element_type=jnp.float32)
897
+ if w2_scale_vmem is not None:
898
+ w2_scale_slices = (
899
+ p_id,
900
+ bfc_id,
901
+ pl.ds(0, 1),
902
+ pl.ds(bd2c_id * bd2c_per_t_packing,
903
+ bd2c_per_t_packing),
904
+ )
905
+ w2_scale = jnp.broadcast_to(
906
+ w2_scale_vmem[*w2_scale_slices], acc.shape)
907
+ acc *= w2_scale
908
+ res += acc
909
+ res = pltpu.bitcast(res, jnp.uint32)
910
+ if t_packing == 2:
911
+ res = res >> 16 << (16 * p_id)
912
+ else:
913
+ assert t_packing == 1
914
+ res_lst.append(res)
915
+ res = res_lst[0]
916
+ # TODO(jevinjiang): use interleaved packing when it is exposed to Pallas
917
+ for i in range(1, t_packing):
918
+ res |= res_lst[i]
919
+ sliced_res_vmem = res_b32_vmem.at[
920
+ pl.ds(btc_id * btc, btc),
921
+ pl.ds(bd2c_id * bd2c_per_t_packing, bd2c_per_t_packing),
922
+ ]
923
+ if should_init:
924
+ sliced_res_vmem[...] = res
925
+ else:
926
+ sliced_res_vmem[...] = pltpu.bitcast(
927
+ sliced_res_vmem.bitcast(t_dtype)[...] +
928
+ pltpu.bitcast(res, t_dtype),
929
+ sliced_res_vmem.dtype,
930
+ )
931
+
932
+ lax.fori_loop(0, num_loops, body, None)
933
+
934
+ def expert_ffn(bt_id, e_sem_id, local_e_id):
935
+ bt_sem_id = bt_id % 2
936
+ bw_sem_id = 0
937
+ # start_fetch_bw1(local_e_id, bw_sem_id, 0, 0)
938
+ # start_fetch_bw3(local_e_id, bw_sem_id, 0, 0)
939
+ a2a_s_b32_vmem = (a2a_s_x2_vmem.bitcast(jnp.uint32).reshape(
940
+ 2, bt * num_devices, hidden_size // t_packing).at[e_sem_id])
941
+ a2a_s_acc_b32_vmem = (a2a_s_acc_x2_vmem.bitcast(jnp.uint32).reshape(
942
+ 2, bt * num_devices, hidden_size // t_packing).at[e_sem_id])
943
+ b_acc_vmem_2d = b_acc_vmem.reshape(bt * num_devices, bf * 2)
944
+ b_acc1_vmem = b_acc_vmem_2d.at[:, :bf]
945
+ b_acc3_vmem = b_acc_vmem_2d.at[:, bf:]
946
+
947
+ e_id = my_id * local_num_experts + local_e_id
948
+ dyn_sz = expert_sizes_x2_smem[bt_sem_id, 0, e_id]
949
+
950
+ bd1_per_t_packing = bd1 // t_packing
951
+ bd2_per_t_packing = bd2 // t_packing
952
+
953
+ for bf_id in range(num_bf):
954
+ for bd1_id in range(num_bd1):
955
+ start_fetch_next_bw(local_e_id, bw_sem_id, bf_id, bd1_id, 0)
956
+ w1_scale_vmem = (None if b_w1_scale_x2_vmem is None else
957
+ b_w1_scale_x2_vmem.at[bw_sem_id])
958
+ w3_scale_vmem = (None if b_w3_scale_x2_vmem is None else
959
+ b_w3_scale_x2_vmem.at[bw_sem_id])
960
+ b1_vmem = None if b_b1_x2_vmem is None else b_b1_x2_vmem.at[
961
+ bf_id % 2]
962
+ b3_vmem = None if b_b3_x2_vmem is None else b_b3_x2_vmem.at[
963
+ bf_id % 2]
964
+ wait_fetch_bw1(local_e_id, bw_sem_id, bf_id, bd1_id)
965
+ wait_fetch_bw3(local_e_id, bw_sem_id, bf_id, bd1_id)
966
+
967
+ dynamic_ffn1(
968
+ t_b32_vmem=a2a_s_b32_vmem.at[
969
+ ...,
970
+ pl.ds(bd1_id * bd1_per_t_packing, bd1_per_t_packing)],
971
+ w1_vmem=b_w1_x2_vmem.at[bw_sem_id],
972
+ w1_scale_vmem=w1_scale_vmem,
973
+ b1_vmem=b1_vmem,
974
+ w3_vmem=b_w3_x2_vmem.at[bw_sem_id],
975
+ w3_scale_vmem=w3_scale_vmem,
976
+ b3_vmem=b3_vmem,
977
+ acc1_vmem=b_acc1_vmem,
978
+ acc3_vmem=b_acc3_vmem,
979
+ dyn_sz=dyn_sz,
980
+ should_init=(bd1_id == 0),
981
+ )
982
+ bw_sem_id = (bw_sem_id + 1) % 2
983
+
984
+ for bd2_id in range(num_bd2):
985
+ start_fetch_next_bw(local_e_id, bw_sem_id, bf_id, num_bd1,
986
+ bd2_id)
987
+ wait_fetch_bw2(local_e_id, bw_sem_id, bf_id, bd2_id)
988
+ if bf_id == bd2_id == 0:
989
+ wait_a2a_gather_send(bt_id, e_sem_id, local_e_id - 2)
990
+
991
+ w2_scale_vmem = (None if b_w2_scale_x2_vmem is None else
992
+ b_w2_scale_x2_vmem.at[bw_sem_id])
993
+ b2_vmem = None if b_b2_x2_vmem is None else b_b2_x2_vmem.at[
994
+ bd2_id % 2]
995
+ dynamic_ffn2(
996
+ acc1_vmem=b_acc1_vmem,
997
+ acc3_vmem=b_acc3_vmem,
998
+ w2_vmem=b_w2_x2_vmem.at[bw_sem_id],
999
+ w2_scale_vmem=w2_scale_vmem,
1000
+ b2_vmem=b2_vmem,
1001
+ res_b32_vmem=a2a_s_acc_b32_vmem.at[
1002
+ ...,
1003
+ pl.ds(bd2_id * bd2_per_t_packing, bd2_per_t_packing)],
1004
+ dyn_sz=dyn_sz,
1005
+ should_init=(bf_id == 0),
1006
+ )
1007
+ bw_sem_id = (bw_sem_id + 1) % 2
1008
+
1009
+ def bt_acc(bt_id, top_k_logits_lst):
1010
+ bt_sem_id = bt_id % 2
1011
+ for bt_t_id in range(bt):
1012
+ for k_id in range(top_k):
1013
+ e_id = t2e_routing_x2_smem[bt_sem_id, bt_t_id, k_id]
1014
+ offset = expert_offsets_x2_smem[bt_sem_id, 1, e_id]
1015
+ expert_offsets_x2_smem[bt_sem_id, 1, e_id] = offset + 1
1016
+ pltpu.make_async_copy(
1017
+ src_ref=a2a_g_hbm.at[e_id, pl.ds(offset, 1)],
1018
+ dst_ref=a2a_g_acc_vmem.at[k_id, pl.ds(bt_t_id, 1)],
1019
+ sem=a2a_acc_sem,
1020
+ ).start()
1021
+ pltpu.make_async_copy(
1022
+ src_ref=a2a_g_acc_vmem,
1023
+ dst_ref=a2a_g_acc_vmem,
1024
+ sem=a2a_acc_sem,
1025
+ ).wait()
1026
+ output = None
1027
+ for k_id in range(top_k):
1028
+ acc = a2a_g_acc_vmem[k_id].reshape(bt, hidden_size)
1029
+ logits = broadcast_minor(top_k_logits_lst[k_id], acc.shape)
1030
+ acc *= logits
1031
+ if output is None:
1032
+ output = acc
1033
+ else:
1034
+ output += acc
1035
+ assert output is not None
1036
+ return output.astype(output_hbm.dtype)
1037
+
1038
+ def start_send_bo(bt_id, priority=0):
1039
+ bt_sem_id = bt_id % 2
1040
+ b_output_sem = local_sems.at[bt_sem_id, 4]
1041
+ pltpu.make_async_copy(
1042
+ src_ref=b_output_x2_vmem.at[bt_sem_id],
1043
+ dst_ref=output_hbm.at[pl.ds(bt_id * bt, bt)],
1044
+ sem=b_output_sem,
1045
+ ).start(priority=priority)
1046
+
1047
+ def wait_send_bo(bt_id):
1048
+ is_valid = jnp.logical_and(0 <= bt_id, bt_id < num_bt)
1049
+ sz = pl.multiple_of(lax.select(is_valid, bt, 0), bt)
1050
+ bt_sem_id = (bt_id + 2) % 2
1051
+ b_output_sem = local_sems.at[bt_sem_id, 4]
1052
+ pltpu.make_async_copy(
1053
+ src_ref=output_hbm.at[pl.ds(0, sz)],
1054
+ dst_ref=output_hbm.at[pl.ds(0, sz)],
1055
+ sem=b_output_sem,
1056
+ ).wait()
1057
+
1058
+ ### ------- Kernel start ------- ###
1059
+ start_fetch_b_gating(bt_id=0)
1060
+
1061
+ def run_per_bt(bt_id, e_sem_id):
1062
+ bt_sem_id = bt_id % 2
1063
+ next_bt_id = bt_id + 1
1064
+ start_fetch_b_gating(next_bt_id)
1065
+ wait_fetch_b_gating(bt_id)
1066
+
1067
+ b_gating = b_gating_x2_vmem[bt_sem_id]
1068
+ b_gating_score = jax.nn.softmax(b_gating, axis=-1)
1069
+ top_k_logits_lst, t2e_routing, expert_sizes, expert_starts = get_top_k(
1070
+ b_gating_score, top_k, renormalize_topk_logits)
1071
+
1072
+ all_reduce_metadata(bt_sem_id, t2e_routing, expert_starts,
1073
+ expert_sizes)
1074
+
1075
+ start_a2a_scatter(bt_id=bt_id, e_sem_id=e_sem_id, local_e_id=0)
1076
+
1077
+ def run_per_expert(local_e_id, e_sem_id):
1078
+ sync_barrier()
1079
+ next_e_sem_id = lax.select(e_sem_id == 0, 1, 0)
1080
+ next_local_e_id = local_e_id + 1
1081
+
1082
+ @pl.when(next_local_e_id < local_num_experts)
1083
+ def _():
1084
+ start_a2a_scatter(bt_id, next_e_sem_id, next_local_e_id)
1085
+
1086
+ # Prefetch weights for active expert.
1087
+ start_fetch_bw1(local_e_id, bw1_sem_id=0, bf_id=0, bd1_id=0)
1088
+ start_fetch_bw3(local_e_id, bw3_sem_id=0, bf_id=0, bd3_id=0)
1089
+
1090
+ # Wait for a2a scatter and perform FFN for active expert.
1091
+ wait_a2a_scatter_recv(bt_id, e_sem_id, local_e_id)
1092
+ expert_ffn(bt_id, e_sem_id, local_e_id)
1093
+
1094
+ # Wait for a2a gather to send back tokens for active expert.
1095
+ start_a2a_gather(bt_id, e_sem_id, local_e_id)
1096
+
1097
+ # A must-wait before next sync_barrier.
1098
+ wait_a2a_scatter_send(bt_id, e_sem_id, local_e_id)
1099
+ return next_e_sem_id
1100
+
1101
+ e_sem_id = lax.fori_loop(0,
1102
+ local_num_experts,
1103
+ run_per_expert,
1104
+ e_sem_id,
1105
+ unroll=False)
1106
+
1107
+ wait_a2a_gather_recv_all()
1108
+ output = bt_acc(bt_id, top_k_logits_lst)
1109
+
1110
+ # Make sure it is safe to overwrite output buffer.
1111
+ wait_send_bo(bt_id=bt_id - 2)
1112
+ b_output_x2_vmem[bt_sem_id] = output
1113
+
1114
+ start_send_bo(bt_id)
1115
+
1116
+ wait_a2a_gather_send(
1117
+ bt_id,
1118
+ e_sem_id=e_sem_id,
1119
+ local_e_id=local_num_experts - 2,
1120
+ )
1121
+ wait_a2a_gather_send(
1122
+ bt_id,
1123
+ e_sem_id=lax.select(e_sem_id == 0, 1, 0),
1124
+ local_e_id=local_num_experts - 1,
1125
+ )
1126
+ return e_sem_id
1127
+
1128
+ lax.fori_loop(0, num_bt, run_per_bt, 0, unroll=False)
1129
+ wait_send_bo(bt_id=num_bt - 2)
1130
+ wait_send_bo(bt_id=num_bt - 1)
1131
+
1132
+ ### ------- Kernel end ------- ###
1133
+
1134
+
1135
+ @functools.partial(
1136
+ jax.jit,
1137
+ static_argnames=[
1138
+ "mesh",
1139
+ "top_k",
1140
+ "renormalize_topk_logits",
1141
+ "act_fn",
1142
+ "subc_quant_wsz",
1143
+ "bt",
1144
+ "bf",
1145
+ "bd1",
1146
+ "bd2",
1147
+ "btc",
1148
+ "bfc",
1149
+ "bd1c",
1150
+ "bd2c",
1151
+ "ep_axis_name",
1152
+ ],
1153
+ )
1154
+ def fused_ep_moe(
1155
+ mesh: jax.sharding.Mesh,
1156
+ tokens: jax.Array, # (num_tokens, hidden_size)
1157
+ w1: jax.Array, # (num_experts, 2, hidden_size, intermediate_size)
1158
+ w2: jax.Array, # (num_experts, intermediate_size, hidden_size)
1159
+ gating_output: jax.Array, # (num_tokens, num_experts)
1160
+ top_k: int,
1161
+ renormalize_topk_logits: bool = False,
1162
+ act_fn: str = "silu",
1163
+ *,
1164
+ subc_quant_wsz: int | None = None,
1165
+ w1_scale: (
1166
+ jax.Array | None
1167
+ ) = None, # (num_experts, 2, cdiv(hidden_size, subc_quant_wsz), intermediate_size)
1168
+ w2_scale: (
1169
+ jax.Array | None
1170
+ ) = None, # (num_experts, cdiv(intermediate_size, subc_quant_wsz), hidden_size)
1171
+ b1: jax.Array | None = None, # (num_experts, 2, intermediate_size)
1172
+ b2: jax.Array | None = None, # (num_experts, hidden_size)
1173
+ # Kernel tuning parameters.
1174
+ bt: int,
1175
+ bf: int,
1176
+ bd1: int,
1177
+ bd2: int,
1178
+ btc: int,
1179
+ bfc: int,
1180
+ bd1c: int,
1181
+ bd2c: int,
1182
+ ep_axis_name: str = "model",
1183
+ ):
1184
+ # TODO(jevinjiang): move all these assertions to validation function.
1185
+ # Assert all other axes have length of 1
1186
+ assert len(mesh.shape) == 2, "Expect 2D mesh"
1187
+ assert ("data" in mesh.shape
1188
+ and mesh.shape["data"] == 1), "Expect data axis size of 1"
1189
+
1190
+ ep_size = mesh.shape[ep_axis_name]
1191
+ num_devices = ep_size
1192
+
1193
+ num_tokens, actual_hidden_size = tokens.shape
1194
+ num_experts, actual_intermediate_size, _ = w2.shape
1195
+
1196
+ assert num_tokens % ep_size == 0
1197
+ assert num_experts % ep_size == 0
1198
+
1199
+ local_num_tokens = num_tokens // ep_size
1200
+ # local_num_experts = num_experts // ep_size
1201
+ padded_num_experts = align_to(num_experts, 128)
1202
+ t_dtype = tokens.dtype
1203
+ t_packing = get_dtype_packing(t_dtype)
1204
+
1205
+ if subc_quant_wsz is not None:
1206
+ if subc_quant_wsz % 256 != 0:
1207
+ raise NotImplementedError(
1208
+ "Sub-quantized window is not aligned to 256.")
1209
+ # We force compute size of contracting dim to subc_quant_wsz. So we can
1210
+ # apply same scale after matmul and accumulation.
1211
+ bd1c = subc_quant_wsz * t_packing
1212
+ bfc = subc_quant_wsz
1213
+
1214
+ assert bfc % 128 == 0
1215
+ assert bd1c % (t_packing * 128) == 0
1216
+ assert bd2c % (t_packing * 128) == 0
1217
+ assert bf % bfc == 0
1218
+ assert bd1 % bd1c == 0
1219
+ assert bd2 % bd2c == 0
1220
+
1221
+ btc = min(btc, bt * num_devices)
1222
+ hidden_size = align_to(actual_hidden_size, 128 * t_packing)
1223
+ # TODO(jevinjiang): instead of padding outside the kernel, we can try dynamic
1224
+ # masking inside the kernel.
1225
+ hidden_size = align_to(hidden_size, bd1)
1226
+ hidden_size = align_to(hidden_size, bd2)
1227
+ intermediate_size = align_to(actual_intermediate_size, bf)
1228
+
1229
+ # TODO(jevinjiang): we should dump scale as the kernel expected shape in the
1230
+ # checkpoint offline or reshape right after weight loading.
1231
+ if w1_scale is not None:
1232
+ assert w1_scale.shape[0] == w1.shape[0]
1233
+ assert w1_scale.shape[1] == w1.shape[1] == 2
1234
+ assert w1_scale.shape[2] == cdiv(w1.shape[2], subc_quant_wsz)
1235
+ assert w1_scale.shape[3] == w1.shape[3]
1236
+ w1_scale = jnp.expand_dims(w1_scale.astype(jnp.float32), axis=-2)
1237
+
1238
+ if w2_scale is not None:
1239
+ assert w2_scale.shape[0] == w2.shape[0]
1240
+ assert w2_scale.shape[1] == cdiv(w2.shape[1], subc_quant_wsz)
1241
+ assert w2_scale.shape[2] == w2.shape[2]
1242
+ w2_scale = jnp.expand_dims(w2_scale.astype(jnp.float32), axis=-2)
1243
+
1244
+ if b1 is not None:
1245
+ assert b1.shape[0] == w1.shape[0]
1246
+ assert b1.shape[1] == w1.shape[1] == 2
1247
+ assert b1.shape[2] == w1.shape[3]
1248
+ b1 = jnp.expand_dims(b1.astype(jnp.float32), axis=-2)
1249
+
1250
+ if b2 is not None:
1251
+ assert b2.shape[0] == w2.shape[0]
1252
+ assert b2.shape[1] == w2.shape[2]
1253
+ b2 = jnp.expand_dims(b2.astype(jnp.float32), axis=-2)
1254
+
1255
+ # Prepare inputs for the kernel.
1256
+ if padded_num_experts != gating_output.shape[-1]:
1257
+ gating_output = jnp.pad(
1258
+ gating_output,
1259
+ ((0, 0), (0, padded_num_experts - gating_output.shape[-1])),
1260
+ constant_values=-jnp.inf,
1261
+ )
1262
+
1263
+ if (hidden_size != actual_hidden_size
1264
+ or intermediate_size != actual_intermediate_size):
1265
+ tokens = jnp.pad(
1266
+ tokens,
1267
+ ((0, 0), (0, hidden_size - actual_hidden_size)),
1268
+ constant_values=0,
1269
+ )
1270
+ w1 = jnp.pad(
1271
+ w1,
1272
+ (
1273
+ (0, 0),
1274
+ (0, 0),
1275
+ (0, hidden_size - actual_hidden_size),
1276
+ (0, intermediate_size - actual_intermediate_size),
1277
+ ),
1278
+ constant_values=0,
1279
+ )
1280
+ w2 = jnp.pad(
1281
+ w2,
1282
+ (
1283
+ (0, 0),
1284
+ (0, intermediate_size - actual_intermediate_size),
1285
+ (0, hidden_size - actual_hidden_size),
1286
+ ),
1287
+ constant_values=0,
1288
+ )
1289
+ if w1_scale is not None:
1290
+ w1_scale = jnp.pad(
1291
+ w1_scale,
1292
+ (
1293
+ (0, 0),
1294
+ (0, 0),
1295
+ (0,
1296
+ cdiv(hidden_size, subc_quant_wsz) - w1_scale.shape[-3]),
1297
+ (0, 0),
1298
+ (0, intermediate_size - w1_scale.shape[-1]),
1299
+ ),
1300
+ constant_values=0,
1301
+ )
1302
+ if w2_scale is not None:
1303
+ w2_scale = jnp.pad(
1304
+ w2_scale,
1305
+ (
1306
+ (0, 0),
1307
+ (0, cdiv(intermediate_size, subc_quant_wsz) -
1308
+ w2_scale.shape[-3]),
1309
+ (0, 0),
1310
+ (0, hidden_size - w2_scale.shape[-1]),
1311
+ ),
1312
+ constant_values=0,
1313
+ )
1314
+ if b1 is not None:
1315
+ b1 = jnp.pad(
1316
+ b1,
1317
+ (
1318
+ (0, 0),
1319
+ (0, 0),
1320
+ (0, 0),
1321
+ (0, intermediate_size - b1.shape[-1]),
1322
+ ),
1323
+ constant_values=0,
1324
+ )
1325
+ if b2 is not None:
1326
+ b2 = jnp.pad(
1327
+ b2,
1328
+ (
1329
+ (0, 0),
1330
+ (0, 0),
1331
+ (0, hidden_size - b2.shape[-1]),
1332
+ ),
1333
+ constant_values=0,
1334
+ )
1335
+
1336
+ tokens = tokens.reshape(-1, t_packing, hidden_size // t_packing)
1337
+
1338
+ hbm_block_spec = pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM)
1339
+ scope_name = f"fused_moe_k-{top_k}_renorm-{renormalize_topk_logits}_bt-{bt}-{btc}_bf-{bf}-{bfc}_bd1-{bd1}-{bd1c}_bd2-{bd2}-{bd2c}"
1340
+ fused_moe = jax.named_scope(scope_name)(
1341
+ pl.pallas_call(
1342
+ functools.partial(
1343
+ _fused_ep_moe_kernel,
1344
+ top_k=top_k,
1345
+ renormalize_topk_logits=renormalize_topk_logits,
1346
+ ep_axis_name=ep_axis_name,
1347
+ act_fn=act_fn,
1348
+ subc_quant_wsz=subc_quant_wsz,
1349
+ bt=bt,
1350
+ bf=bf,
1351
+ bd1=bd1,
1352
+ bd2=bd2,
1353
+ btc=btc,
1354
+ bfc=bfc,
1355
+ bd1c=bd1c,
1356
+ bd2c=bd2c,
1357
+ ),
1358
+ out_shape=jax.ShapeDtypeStruct((local_num_tokens, hidden_size),
1359
+ t_dtype),
1360
+ grid_spec=pltpu.PrefetchScalarGridSpec(
1361
+ num_scalar_prefetch=0,
1362
+ in_specs=[
1363
+ hbm_block_spec, # tokens_hbm
1364
+ hbm_block_spec, # w1_hbm
1365
+ hbm_block_spec, # w2_hbm
1366
+ None
1367
+ if w1_scale is None else hbm_block_spec, # w1_scale_hbm
1368
+ None
1369
+ if w2_scale is None else hbm_block_spec, # w2_scale_hbm
1370
+ None if b1 is None else hbm_block_spec, # b1_hbm
1371
+ None if b2 is None else hbm_block_spec, # b2_hbm
1372
+ hbm_block_spec, # gating_output_hbm
1373
+ hbm_block_spec, # a2a_g_hbm
1374
+ ],
1375
+ out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM),
1376
+ scratch_shapes=([
1377
+ # t2e_routing_x2_smem
1378
+ pltpu.SMEM((2, bt, padded_num_experts), jnp.int32),
1379
+ # d2e_count_x2_smem
1380
+ pltpu.SMEM((2, num_devices, 1, padded_num_experts),
1381
+ jnp.int32),
1382
+ # expert_offsets_x2_smem
1383
+ pltpu.SMEM((2, 2, padded_num_experts), jnp.int32),
1384
+ # expert_starts_x2_smem
1385
+ pltpu.SMEM((2, 1, padded_num_experts), jnp.int32),
1386
+ # expert_sizes_x2_smem
1387
+ pltpu.SMEM((2, 1, padded_num_experts), jnp.int32),
1388
+ # a2a_s_sends_x2_smem
1389
+ pltpu.SMEM((2, ), jnp.int32),
1390
+ # a2a_s_x2_vmem
1391
+ pltpu.VMEM(
1392
+ (
1393
+ 2,
1394
+ bt * num_devices,
1395
+ t_packing,
1396
+ hidden_size // t_packing,
1397
+ ),
1398
+ t_dtype,
1399
+ ),
1400
+ # a2a_s_acc_x2_vmem
1401
+ pltpu.VMEM(
1402
+ (
1403
+ 2,
1404
+ bt * num_devices,
1405
+ t_packing,
1406
+ hidden_size // t_packing,
1407
+ ),
1408
+ t_dtype,
1409
+ ),
1410
+ # a2a_g_acc_vmem
1411
+ pltpu.VMEM(
1412
+ (top_k, bt, t_packing, hidden_size // t_packing),
1413
+ t_dtype),
1414
+ # b_gating_x2_vmem
1415
+ pltpu.VMEM((2, bt, padded_num_experts), t_dtype),
1416
+ # b_output_x2_vmem
1417
+ pltpu.VMEM((2, bt, hidden_size), t_dtype),
1418
+ # b_w1_x2_vmem
1419
+ pltpu.VMEM((2, t_packing, bd1 // t_packing, bf), w1.dtype),
1420
+ # b_w3_x2_vmem
1421
+ pltpu.VMEM((2, t_packing, bd1 // t_packing, bf), w1.dtype),
1422
+ # b_w2_x2_vmem
1423
+ pltpu.VMEM((2, t_packing, bf, bd2 // t_packing), w2.dtype),
1424
+ # b_w1_scale_x2_vmem
1425
+ (None if w1_scale is None else pltpu.VMEM(
1426
+ (
1427
+ 2,
1428
+ t_packing,
1429
+ bd1 // t_packing // subc_quant_wsz,
1430
+ 1,
1431
+ bf,
1432
+ ),
1433
+ jnp.float32,
1434
+ )),
1435
+ # b_w3_scale_x2_vmem
1436
+ (None if w1_scale is None else pltpu.VMEM(
1437
+ (
1438
+ 2,
1439
+ t_packing,
1440
+ bd1 // t_packing // subc_quant_wsz,
1441
+ 1,
1442
+ bf,
1443
+ ),
1444
+ jnp.float32,
1445
+ )),
1446
+ # b_w2_scale_x2_vmem
1447
+ (None if w2_scale is None else pltpu.VMEM(
1448
+ (
1449
+ 2,
1450
+ t_packing,
1451
+ bf // subc_quant_wsz,
1452
+ 1,
1453
+ bd2 // t_packing,
1454
+ ),
1455
+ jnp.float32,
1456
+ )),
1457
+ # b_b1_x2_vmem
1458
+ (None if b1 is None else pltpu.VMEM(
1459
+ (
1460
+ 2,
1461
+ 1,
1462
+ bf,
1463
+ ),
1464
+ jnp.float32,
1465
+ )),
1466
+ # b_b3_x2_vmem
1467
+ (None if b1 is None else pltpu.VMEM(
1468
+ (
1469
+ 2,
1470
+ 1,
1471
+ bf,
1472
+ ),
1473
+ jnp.float32,
1474
+ )),
1475
+ # b_b2_x2_vmem
1476
+ (None if b2 is None else pltpu.VMEM(
1477
+ (
1478
+ 2,
1479
+ t_packing,
1480
+ 1,
1481
+ bd2 // t_packing,
1482
+ ),
1483
+ jnp.float32,
1484
+ )),
1485
+ # b_acc_vmem
1486
+ pltpu.VMEM((bt * num_devices, 1, bf * 2), jnp.float32),
1487
+ # local_sems
1488
+ pltpu.SemaphoreType.DMA((2, 5)),
1489
+ # send_sems
1490
+ pltpu.SemaphoreType.DMA((2, )),
1491
+ # recv_sems
1492
+ pltpu.SemaphoreType.DMA((2, )),
1493
+ # a2a_gather_sem
1494
+ pltpu.SemaphoreType.DMA,
1495
+ # a2a_acc_sem
1496
+ pltpu.SemaphoreType.DMA,
1497
+ ]),
1498
+ ),
1499
+ compiler_params=pltpu.CompilerParams(
1500
+ collective_id=0,
1501
+ vmem_limit_bytes=100 * 1024 * 1024,
1502
+ ),
1503
+ name=scope_name,
1504
+ ))
1505
+
1506
+ @jax.jit
1507
+ @jax.shard_map(
1508
+ mesh=mesh,
1509
+ in_specs=(
1510
+ P(ep_axis_name), # tokens_hbm
1511
+ P(ep_axis_name), # w1_hbm
1512
+ P(ep_axis_name), # w2_hbm
1513
+ None if w1_scale is None else P(ep_axis_name), # w1_scale_hbm
1514
+ None if w2_scale is None else P(ep_axis_name), # w2_scale_hbm
1515
+ None if b1 is None else P(ep_axis_name), # b1_hbm
1516
+ None if b2 is None else P(ep_axis_name), # b2_hbm
1517
+ P(ep_axis_name), # gating_output_hbm
1518
+ P(), # a2a_g_hbm
1519
+ ),
1520
+ out_specs=P(ep_axis_name),
1521
+ check_vma=False,
1522
+ )
1523
+ def kernel(
1524
+ tokens,
1525
+ w1,
1526
+ w2,
1527
+ w1_scale,
1528
+ w2_scale,
1529
+ b1,
1530
+ b2,
1531
+ gating_output,
1532
+ a2a_g_hbm_scratch,
1533
+ ):
1534
+ return fused_moe(
1535
+ pltpu.with_memory_space_constraint(tokens,
1536
+ pltpu.HBM), # tokens_hbm
1537
+ pltpu.with_memory_space_constraint(w1, pltpu.HBM), # w1_hbm
1538
+ pltpu.with_memory_space_constraint(w2, pltpu.HBM), # w2_hbm
1539
+ (None if w1_scale is None else pltpu.with_memory_space_constraint(
1540
+ w1_scale, pltpu.HBM)), # w1_scale_hbm
1541
+ (None if w2_scale is None else pltpu.with_memory_space_constraint(
1542
+ w2_scale, pltpu.HBM)), # w2_scale_hbm
1543
+ (None if b1 is None else pltpu.with_memory_space_constraint(
1544
+ b1, pltpu.HBM)), # b1_hbm
1545
+ (None if b2 is None else pltpu.with_memory_space_constraint(
1546
+ b2, pltpu.HBM)), # b2_hbm
1547
+ pltpu.with_memory_space_constraint(gating_output,
1548
+ pltpu.HBM), # gating_output_hbm
1549
+ pltpu.with_memory_space_constraint(a2a_g_hbm_scratch,
1550
+ pltpu.HBM), # a2a_g_hbm
1551
+ )
1552
+
1553
+ a2a_g_hbm_scratch = pl.empty(
1554
+ (num_experts, bt, t_packing, hidden_size // t_packing), t_dtype)
1555
+ results = kernel(
1556
+ tokens,
1557
+ w1,
1558
+ w2,
1559
+ w1_scale,
1560
+ w2_scale,
1561
+ b1,
1562
+ b2,
1563
+ gating_output,
1564
+ a2a_g_hbm_scratch,
1565
+ )
1566
+ return results[:, :actual_hidden_size]
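
The following is an illustrative sketch, not part of the packaged sources: it exercises the pure-JAX reference path ref_moe above with small, arbitrary shapes to show the expected argument layout. The fused kernel fused_ep_moe additionally requires a 2D ("data", ep_axis_name) device mesh plus tuned block sizes (bt, bf, bd1, bd2, btc, bfc, bd1c, bd2c), so it is not invoked here; all shapes and values below are made up for demonstration.

# Hypothetical usage sketch for the reference MoE path (not shipped in the wheel).
import jax
import jax.numpy as jnp

num_tokens, hidden_size, intermediate_size = 4, 128, 256
num_experts, top_k = 8, 2

key = jax.random.PRNGKey(0)
k1, k2, k3, k4 = jax.random.split(key, 4)
tokens = jax.random.normal(k1, (num_tokens, hidden_size), dtype=jnp.bfloat16)
w1 = jax.random.normal(k2, (num_experts, 2, hidden_size, intermediate_size), dtype=jnp.bfloat16)
w2 = jax.random.normal(k3, (num_experts, intermediate_size, hidden_size), dtype=jnp.bfloat16)
gating = jax.random.normal(k4, (num_tokens, num_experts), dtype=jnp.float32)

# ref_moe computes softmax gating, picks the top_k experts per token, applies the
# gated FFN of each selected expert, and combines the expert outputs with the
# (optionally renormalized) gating weights.
out = ref_moe(tokens, w1, w2, gating, top_k, renormalize_topk_logits=True)
print(out.shape)  # (num_tokens, hidden_size)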