tpu-inference 0.11.1.dev202511180814__py3-none-any.whl → 0.12.0.dev20251213__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.
Files changed (76)
  1. tests/kernels/fused_moe_v1_test.py +303 -34
  2. tests/kernels/mla_v1_test.py +129 -41
  3. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  4. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +3 -1
  5. tests/kernels/ragged_paged_attention_kernel_v3_test.py +3 -1
  6. tests/lora/test_layers.py +4 -7
  7. tests/lora/test_lora_perf.py +53 -0
  8. tests/lora/utils.py +0 -8
  9. tests/test_envs.py +110 -12
  10. tests/test_quantization.py +3 -0
  11. tests/test_utils.py +1 -2
  12. tpu_inference/__init__.py +22 -3
  13. tpu_inference/core/disagg_utils.py +6 -8
  14. tpu_inference/distributed/tpu_connector.py +3 -4
  15. tpu_inference/distributed/utils.py +3 -2
  16. tpu_inference/envs.py +93 -9
  17. tpu_inference/executors/ray_distributed_executor.py +9 -2
  18. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  19. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  20. tpu_inference/kernels/fused_moe/v1/kernel.py +712 -143
  21. tpu_inference/kernels/mla/v1/kernel.py +98 -120
  22. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  23. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  24. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  25. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +140 -67
  26. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +204 -120
  27. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +2 -1
  28. tpu_inference/kernels/ragged_paged_attention/v3/util.py +2 -1
  29. tpu_inference/layers/common/attention_interface.py +7 -1
  30. tpu_inference/layers/common/sharding.py +11 -7
  31. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +232 -64
  32. tpu_inference/layers/jax/attention/gpt_oss_attention.py +5 -5
  33. tpu_inference/layers/vllm/fused_moe.py +170 -208
  34. tpu_inference/layers/vllm/linear_common.py +43 -21
  35. tpu_inference/layers/vllm/quantization/common.py +11 -6
  36. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -3
  37. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +74 -65
  38. tpu_inference/layers/vllm/quantization/mxfp4.py +140 -94
  39. tpu_inference/layers/vllm/quantization/unquantized.py +103 -80
  40. tpu_inference/layers/vllm/sharding.py +2 -2
  41. tpu_inference/lora/torch_punica_tpu.py +1 -2
  42. tpu_inference/models/common/model_loader.py +84 -28
  43. tpu_inference/models/jax/deepseek_v3.py +185 -64
  44. tpu_inference/models/jax/gpt_oss.py +3 -3
  45. tpu_inference/models/jax/llama3.py +2 -1
  46. tpu_inference/models/jax/llama_eagle3.py +8 -5
  47. tpu_inference/models/jax/llama_guard_4.py +361 -0
  48. tpu_inference/models/jax/qwen2.py +2 -1
  49. tpu_inference/models/jax/qwen2_5_vl.py +163 -48
  50. tpu_inference/models/jax/qwen3.py +2 -1
  51. tpu_inference/models/jax/utils/quantization/quantization_utils.py +7 -8
  52. tpu_inference/models/jax/utils/weight_utils.py +205 -144
  53. tpu_inference/models/vllm/vllm_model_wrapper.py +14 -8
  54. tpu_inference/platforms/tpu_platform.py +34 -50
  55. tpu_inference/runner/compilation_manager.py +144 -60
  56. tpu_inference/runner/kv_cache.py +40 -20
  57. tpu_inference/runner/kv_cache_manager.py +48 -33
  58. tpu_inference/runner/persistent_batch_manager.py +40 -2
  59. tpu_inference/runner/structured_decoding_manager.py +2 -3
  60. tpu_inference/runner/tpu_runner.py +280 -149
  61. tpu_inference/runner/utils.py +2 -2
  62. tpu_inference/spec_decode/jax/eagle3.py +71 -21
  63. tpu_inference/tpu_info.py +4 -3
  64. tpu_inference/utils.py +46 -18
  65. tpu_inference/worker/tpu_worker.py +197 -63
  66. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/METADATA +9 -10
  67. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/RECORD +70 -74
  68. tpu_inference/mock/__init__.py +0 -0
  69. tpu_inference/mock/vllm_config_utils.py +0 -28
  70. tpu_inference/mock/vllm_envs.py +0 -1219
  71. tpu_inference/mock/vllm_logger.py +0 -212
  72. tpu_inference/mock/vllm_logging_utils.py +0 -15
  73. tpu_inference/models/jax/phi3.py +0 -376
  74. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/WHEEL +0 -0
  75. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/licenses/LICENSE +0 -0
  76. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/top_level.txt +0 -0
tpu_inference/kernels/fused_moe/v1/kernel.py
@@ -7,7 +7,6 @@ import jax.numpy as jnp
 from jax import lax
 from jax._src import dtypes
 from jax.experimental import pallas as pl
-from jax.experimental import shard_map
 from jax.experimental.pallas import tpu as pltpu
 
 P = jax.sharding.PartitionSpec
@@ -20,7 +19,8 @@ def align_to(x, a):
 
 
 def get_dtype_packing(dtype):
-    bits = dtypes.bit_width(dtype)
+    bits = (dtypes.bit_width(dtype)
+            if hasattr(dtypes, "bit_width") else dtypes.itemsize_bits(dtype))
     return 32 // bits
 
 
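
Note: get_dtype_packing returns how many values of a dtype fit in one 32-bit word, which drives the t_packing repacking logic throughout this kernel; the hunk above adds a fallback for JAX versions where jax._src.dtypes exposes itemsize_bits instead of bit_width. A minimal sketch of the helper's behavior (illustrative only; it relies on the same private jax._src.dtypes module the kernel uses):

    import jax.numpy as jnp
    from jax._src import dtypes

    def get_dtype_packing(dtype):
        # Values per 32-bit word: float32 -> 1, bfloat16 -> 2, int8 -> 4.
        bits = (dtypes.bit_width(dtype)
                if hasattr(dtypes, "bit_width") else dtypes.itemsize_bits(dtype))
        return 32 // bits

    assert get_dtype_packing(jnp.float32) == 1
    assert get_dtype_packing(jnp.bfloat16) == 2
    assert get_dtype_packing(jnp.int8) == 4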
 
@@ -35,13 +35,50 @@ def broadcast_minor(src, shape):
                      axis=-1)[..., :shape[-1]]
 
 
+def swigluoai(gate: jax.Array,
+              up: jax.Array,
+              *,
+              alpha: float = 1.702,
+              limit: float = 7.0) -> jax.Array:
+    """Activation used in some models such as GPT-OSS."""
+    gate = jnp.clip(gate, a_max=limit)
+    up = jnp.clip(up, a_min=-limit, a_max=limit)
+    glu = gate * jax.nn.sigmoid(alpha * gate)
+    return (up + 1.0) * glu
+
+
+def activation_fn(acc1, acc3, act_fn):
+    if act_fn == "silu":
+        return jax.nn.silu(acc1) * acc3
+    elif act_fn == "gelu":
+        return jax.nn.gelu(acc1) * acc3
+    elif act_fn == "swigluoai":
+        return swigluoai(acc1, acc3)
+    else:
+        raise RuntimeError(f"Unsupported activation function: {act_fn}")
+
+
 def ref_moe(
-    tokens: jax.Array,  # (num_tokens, hidden_size)
-    w1: jax.Array,  # (num_experts, 2, hidden_size, intermediate_size)
-    w2: jax.Array,  # (num_experts, intermediate_size, hidden_size)
-    gating_output: jax.Array,  # (num_tokens, num_experts)
-    top_k: int,
-    activation="silu",
+        tokens: jax.Array,  # (num_tokens, hidden_size)
+        w1: jax.Array,  # (num_experts, 2, hidden_size, intermediate_size)
+        w2: jax.Array,  # (num_experts, intermediate_size, hidden_size)
+        gating_output: jax.Array,  # (num_tokens, num_experts)
+        top_k: int,
+        *,
+        renormalize_topk_logits: bool = False,
+        act_fn: str = "silu",
+        subc_quant_wsz: int | None = None,
+        w1_scale: (
+            jax.Array | None
+        ) = None,  # F32(num_experts, 2, hidden_size // subc_quant_wsz, 1, intermediate_size)
+        w2_scale: (
+            jax.Array | None
+        ) = None,  # F32(num_experts, intermediate_size // subc_quant_wsz, 1, hidden_size)
+        b1: jax.Array
+        | None = None,  # F32(num_experts, 2, 1, intermediate_size)
+        b2: jax.Array | None = None,  # F32(num_experts, 1, hidden_size)
 ):
     n_tokens = tokens.shape[0]  # num_tokens
 
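
Note: the new swigluoai activation is clipped SwiGLU: the gate is clipped from above at limit, the up-projection is clipped to [-limit, limit], and the result is (up + 1) * gate * sigmoid(alpha * gate). A self-contained sketch (illustrative; it uses positional jnp.clip bounds, equivalent to the a_min/a_max keywords in the hunk above):

    import jax
    import jax.numpy as jnp

    def swigluoai_ref(gate, up, *, alpha=1.702, limit=7.0):
        gate = jnp.minimum(gate, limit)   # clip the gate from above only
        up = jnp.clip(up, -limit, limit)  # clip up on both sides
        return (up + 1.0) * gate * jax.nn.sigmoid(alpha * gate)

    gate = jnp.array([-1.0, 0.0, 10.0])   # 10.0 is treated as 7.0
    up = jnp.array([0.5, -0.5, 8.0])      # 8.0 is treated as 7.0
    print(swigluoai_ref(gate, up))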
 
@@ -53,11 +90,16 @@ def ref_moe(
     top_k_logits, top_k_indices = lax.top_k(
         gating_logits, top_k)  # [num_tokens, top_k], [num_tokens, top_k]
 
+    if renormalize_topk_logits:
+        top_k_logits = top_k_logits / jnp.sum(
+            top_k_logits, axis=-1, keepdims=True)
+
     t_outputs = []
+    hidden_size, intermediate_size = w1.shape[-2:]
 
     # Process each token individually
     for i in range(n_tokens):
-        curr_token = jnp.expand_dims(tokens[i], axis=0)  # [1, d_model]
+        curr_token = jnp.expand_dims(tokens[i], axis=0)  # [1, hidden_size]
         assigned_expert_ids = top_k_indices[
             i]  # [top_k] - indices of selected experts for token i
         tok_expert_act = []
@@ -65,10 +107,24 @@ def ref_moe(
         # Process each selected expert for the current token
         for expert_id in assigned_expert_ids:
             # Get expert weights
+            expert_w1 = w1[expert_id, 0].astype(jnp.float32)
+            expert_w3 = w1[expert_id, 1].astype(jnp.float32)
+            if w1_scale is not None:
+                expert_w1 *= jnp.repeat(w1_scale[expert_id, 0, :, 0],
+                                        subc_quant_wsz,
+                                        axis=0)[:hidden_size]
+                expert_w3 *= jnp.repeat(w1_scale[expert_id, 1, :, 0],
+                                        subc_quant_wsz,
+                                        axis=0)[:hidden_size]
             expert_weight_1 = jnp.concat(
-                [w1[expert_id, 0], w1[expert_id, 1]],
-                axis=-1)  # [d_model, 2 * intermediate_size]
-            expert_weight_2 = w2[expert_id]  # [intermediate_size, d_model]
+                [expert_w1, expert_w3],
+                axis=-1)  # [hidden_size, 2 * intermediate_size]
+            expert_weight_2 = w2[expert_id].astype(
+                jnp.float32)  # [intermediate_size, hidden_size]
+            if w2_scale is not None:
+                expert_weight_2 *= jnp.repeat(w2_scale[expert_id, :, 0],
+                                              subc_quant_wsz,
+                                              axis=0)[:intermediate_size]
 
             # First linear layer with SwiGLU activation
             gmm_1_out = curr_token @ expert_weight_1  # [1, 2 * intermediate_size]
@@ -77,37 +133,34 @@ def ref_moe(
             gmm1_w1_proj, gmm1_w3_proj = jnp.split(
                 gmm_1_out, 2,
                 axis=-1)  # [1, intermediate_size], [1, intermediate_size]
+            if b1 is not None:
+                gmm1_w1_proj += b1[expert_id:expert_id + 1, 0, 0]
+                gmm1_w3_proj += b1[expert_id:expert_id + 1, 1, 0]
 
             # Apply gated activation: activation(gate) * up
-            if activation == "silu":
-                act = jax.nn.silu(
-                    gmm1_w1_proj) * gmm1_w3_proj  # [1, intermediate_size]
-            elif activation == "gelu":
-                act = jax.nn.gelu(
-                    gmm1_w1_proj) * gmm1_w3_proj  # [1, intermediate_size]
-            else:
-                raise ValueError(
-                    f"Unsupported activation: {activation}. Use 'silu' or 'gelu'."
-                )
+            act = activation_fn(gmm1_w1_proj, gmm1_w3_proj, act_fn)
 
             # Second linear layer (down projection)
-            gmm_2_out = act @ expert_weight_2  # [1, d_model]
+            gmm_2_out = act @ expert_weight_2  # [1, hidden_size]
+            if b2 is not None:
+                gmm_2_out += b2[expert_id:expert_id + 1, 0]
             tok_expert_act.append(gmm_2_out)
 
         # Combine outputs from all selected experts
         experts_act = jnp.concatenate(tok_expert_act,
-                                      axis=0)  # [top_k, d_model]
+                                      axis=0)  # [top_k, hidden_size]
 
         # Weighted sum using top-k gating weights
         top_k_weights = top_k_logits[i]  # [top_k]
         top_k_weights = jnp.expand_dims(top_k_weights, axis=1)  # [top_k, 1]
         weighted_output = jnp.sum(experts_act * top_k_weights,
                                   axis=0,
-                                  keepdims=True)  # [1, d_model]
+                                  keepdims=True)  # [1, hidden_size]
 
-        t_outputs.append(weighted_output)
+        t_outputs.append(weighted_output.astype(tokens.dtype))
 
-    return jnp.concatenate(t_outputs, axis=0)  # [num_tokens, d_model]
+    return jnp.concatenate(t_outputs,
+                           axis=0)  # [actual_num_tokens, hidden_size]
 
 
 def _fused_ep_moe_kernel(
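
Note: ref_moe remains the pure-JAX reference implementation the kernel tests compare against; its new keyword arguments mirror the fused kernel's. A usage sketch with toy shapes (illustrative only; the shapes follow the signature comments above, not the test suite):

    import jax
    import jax.numpy as jnp

    num_tokens, hidden, inter, n_experts, k = 4, 256, 512, 8, 2
    k0, k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 4)
    tokens = jax.random.normal(k0, (num_tokens, hidden), jnp.bfloat16)
    w1 = jax.random.normal(k1, (n_experts, 2, hidden, inter), jnp.bfloat16)
    w2 = jax.random.normal(k2, (n_experts, inter, hidden), jnp.bfloat16)
    gating = jax.random.normal(k3, (num_tokens, n_experts), jnp.float32)

    out = ref_moe(tokens, w1, w2, gating, k,
                  renormalize_topk_logits=True, act_fn="silu")
    assert out.shape == (num_tokens, hidden)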
@@ -115,12 +168,19 @@ def _fused_ep_moe_kernel(
     tokens_hbm,  # (local_num_tokens, t_packing, hidden_size // t_packing)
     w1_hbm,  # (local_num_experts, 2, hidden_size, intermediate_size)
     w2_hbm,  # (local_num_experts, intermediate_size, hidden_size)
+    # TODO(jevinjiang): We choose F32 scale for easier slicing. The extra
+    # latency should be hidden in the pipeline overlapping. But is there a
+    # better way to do this?
+    w1_scale_hbm,  # None | F32(local_num_experts, 2, cdiv(hidden_size, subc_quant_wsz), 1, intermediate_size)
+    w2_scale_hbm,  # None | F32(local_num_experts, cdiv(intermediate_size, subc_quant_wsz), 1, hidden_size)
+    b1_hbm,  # None | F32(local_num_experts, 2, 1, intermediate_size)
+    b2_hbm,  # None | F32(local_num_experts, 1, hidden_size)
     gating_hbm,  # (local_num_tokens, padded_num_experts)
     a2a_g_hbm,  # (num_experts, bt, t_packing, hidden_size // t_packing)
     # Output
     output_hbm,  # (local_num_tokens, hidden_size)
     # Scratch
-    t2e_routing_x2_smem,  # <bt_sem_id> (2, bt, padded_num_experts)
+    t2e_routing_x2_smem,  # <bt_sem_id> (2, bt, padded_top_k)
     d2e_count_x2_smem,  # <bt_sem_id> (2, num_devices, 1, padded_num_experts)
     expert_offsets_x2_smem,  # <bt_sem_id> (2, 2, padded_num_experts): for a2a_s and a2a_g
     expert_starts_x2_smem,  # <bt_sem_id> (2, 1, padded_num_experts)
@@ -136,6 +196,12 @@ def _fused_ep_moe_kernel(
     b_w1_x2_vmem,  # <bw_sem_id> (2, t_packing, bd1 // t_packing, bf)
     b_w3_x2_vmem,  # <bw_sem_id> (2, t_packing, bd1 // t_packing, bf)
     b_w2_x2_vmem,  # <bw_sem_id> (2, t_packing, bf, bd2 // t_packing)
+    b_w1_scale_x2_vmem,  # None | <bw_sem_id> (2, t_packing, bd1 // t_packing // subc_quant_wsz, 1, bf)
+    b_w3_scale_x2_vmem,  # None | <bw_sem_id> (2, t_packing, bd1 // t_packing // subc_quant_wsz, 1, bf)
+    b_w2_scale_x2_vmem,  # None | <bw_sem_id> (2, t_packing, bf // subc_quant_wsz, 1, bd2 // t_packing)
+    b_b1_x2_vmem,  # None | <bw_sem_id> (2, 1, bf)
+    b_b3_x2_vmem,  # None | <bw_sem_id> (2, 1, bf)
+    b_b2_x2_vmem,  # None | <bw_sem_id> (2, t_packing, 1, bd2 // t_packing)
     b_acc_vmem,  # F32(bt * num_devices, 1, bf * 2)
     ### Semaphores:
     local_sems,  # (2, 5): 2 x [b_gating_sem, b_w1_sem, b_w2_sem, b_w3_sem, b_output_sem]
@@ -145,7 +211,10 @@ def _fused_ep_moe_kernel(
     a2a_acc_sem,
     *,
     top_k: int,
+    renormalize_topk_logits: bool,
     ep_axis_name: str,
+    act_fn: str,
+    subc_quant_wsz: int | None = None,
     # Kernel tuning params.
     bt: int,  # Block size of local_num_tokens.
     bf: int,  # Block size of intermediate_size.
@@ -160,34 +229,58 @@ def _fused_ep_moe_kernel(
     num_devices = lax.axis_size(ep_axis_name)
     local_num_tokens = tokens_hbm.shape[0]
     local_num_experts, intermediate_size, hidden_size = w2_hbm.shape
-    # num_experts = local_num_experts * num_devices
-    # padded_num_experts = expert_starts_x2_smem.shape[-1]
     right_id = (my_id + 1) % num_devices
+    num_experts = a2a_g_hbm.shape[0]
+    padded_num_experts = d2e_count_x2_smem.shape[-1]
+    padded_top_k = t2e_routing_x2_smem.shape[-1]
+    assert padded_num_experts == align_to(num_experts, 128)
+    assert padded_top_k == align_to(top_k, 128)
 
     t_dtype = tokens_hbm.dtype
     t_packing = get_dtype_packing(t_dtype)
     t_bitwidth = 32 // t_packing
     assert a2a_g_hbm.dtype == t_dtype
-    assert w1_hbm.dtype == t_dtype
-    assert w2_hbm.dtype == t_dtype
+    assert w1_hbm.dtype == w2_hbm.dtype
 
-    h_per_packing = hidden_size // t_packing
-    assert tokens_hbm.shape[-1] == h_per_packing
-    bd1_per_packing = bd1 // t_packing
-    bd2_per_packing = bd2 // t_packing
-    bd1c_per_packing = bd1c // t_packing
-    bd2c_per_packing = bd2c // t_packing
+    assert bd1 % bd1c == 0
+    assert bd2 % bd2c == 0
+    assert bf % bfc == 0
+    assert hidden_size % t_packing == 0
+    assert bd1 % t_packing == 0
+    assert bd2 % t_packing == 0
+    assert bd1c % t_packing == 0
+    assert bd2c % t_packing == 0
+
+    h_per_t_packing = hidden_size // t_packing
+    assert tokens_hbm.shape[-1] == h_per_t_packing
+    bd1_per_t_packing = bd1 // t_packing
+    bd2_per_t_packing = bd2 // t_packing
+    bd1c_per_t_packing = bd1c // t_packing
+    bd2c_per_t_packing = bd2c // t_packing
+
+    if subc_quant_wsz is not None:
+        assert subc_quant_wsz % 256 == 0
+        assert bd1c_per_t_packing == subc_quant_wsz
+        assert bfc == subc_quant_wsz
+        assert bd1 % subc_quant_wsz == 0
+        assert bf % subc_quant_wsz == 0
+        assert bd1_per_t_packing % subc_quant_wsz == 0
+        assert h_per_t_packing % subc_quant_wsz == 0
 
     num_bt = cdiv(local_num_tokens, bt)
     num_bf = cdiv(intermediate_size, bf)
     num_bd1 = cdiv(hidden_size, bd1)
     num_bd2 = cdiv(hidden_size, bd2)
 
+    def get_mesh_device_id(ep_rank):
+        dp_rank = jax.lax.axis_index("data")
+        return (dp_rank, ep_rank)
+
     def sync_barrier():
         barrier_sem = pltpu.get_barrier_semaphore()
         pltpu.semaphore_signal(
             barrier_sem,
-            device_id=(0, right_id),
+            device_id=get_mesh_device_id(right_id),
             device_id_type=pltpu.DeviceIdType.MESH,
         )
         pltpu.semaphore_wait(barrier_sem, 1)
@@ -212,30 +305,44 @@ def _fused_ep_moe_kernel(
             sem=b_gating_sem,
         ).wait()
 
-    def get_top_k(input, top_k):
+    def get_top_k(input, top_k, renormalize_topk_logits):
         assert len(input.shape) == 2, input.shape
         input = input.astype(jnp.float32)
+        padded_k_shape = (input.shape[0], padded_top_k)
         top_k_logits_lst = []
        top_k_indices_lst = []
         t2e = jnp.zeros(input.shape, dtype=jnp.int32)
-        t2e_routing = jnp.zeros(input.shape, dtype=jnp.int32)
+        t2e_routing = jnp.zeros(padded_k_shape, dtype=jnp.int32)
         iota = jax.lax.broadcasted_iota(jnp.int32, input.shape, 1)
+        padded_k_iota = jax.lax.broadcasted_iota(jnp.int32, padded_k_shape, 1)
+        top_k_logits_sum = jnp.zeros(padded_k_shape, jnp.float32)
+
         for k_id in range(top_k):
-            # TODO(jevinjiang): return both top_k values and indices in op in Mosaic
+            # TODO(jevinjiang): return both top_k values and indices in Mosaic
             top_k_logits = jnp.broadcast_to(
-                jnp.max(input, axis=1, keepdims=True),
-                (input.shape[0], 128)).astype(input.dtype)
+                jnp.max(input[:, :num_experts], axis=1, keepdims=True),
+                padded_k_shape,
+            ).astype(input.dtype)
             top_k_logits_lst.append(top_k_logits)
+            if renormalize_topk_logits:
+                top_k_logits_sum += top_k_logits
             # TODO(jevinjiang): support bf16 argmax in Mosaic
             top_k_indices = jnp.broadcast_to(
-                jnp.argmax(input, axis=1, keepdims=True), input.shape)
+                jnp.argmax(input[:, :num_experts], axis=1, keepdims=True),
+                padded_k_shape,
+            )
             top_k_indices_lst.append(top_k_indices)
-            t2e_routing = jnp.where(iota == k_id, top_k_indices, t2e_routing)
-            mask = iota == top_k_indices
+            t2e_routing = jnp.where(padded_k_iota == k_id, top_k_indices,
+                                    t2e_routing)
+            mask = iota == broadcast_minor(top_k_indices, input.shape)
             t2e += mask.astype(jnp.int32)
             if k_id != top_k - 1:
                 input = jnp.where(mask, -jnp.inf, input)
 
+        if renormalize_topk_logits:
+            for k_id in range(top_k):
+                top_k_logits_lst[k_id] /= top_k_logits_sum
+
         expert_sizes = jnp.sum(t2e, axis=0, keepdims=True)
         expert_starts = jnp.zeros_like(expert_sizes)
         return top_k_logits_lst, t2e_routing, expert_sizes, expert_starts
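
Note: with renormalize_topk_logits, get_top_k accumulates the k selected logits in top_k_logits_sum and divides each per-token routing weight by that sum, so the weights sum to 1 after selection. The dense equivalent on a plain gating matrix (a sketch, not the kernel's vectorized form):

    import jax
    import jax.numpy as jnp
    from jax import lax

    gating = jax.nn.softmax(jnp.array([[2.0, 0.5, 1.0, 0.1]]), axis=-1)
    vals, idx = lax.top_k(gating, 2)  # top-2 probabilities and expert ids
    vals = vals / jnp.sum(vals, axis=-1, keepdims=True)  # renormalize
    print(vals.sum(axis=-1))  # -> [1.]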
@@ -277,7 +384,7 @@ def _fused_ep_moe_kernel(
                 dst_ref=d2e_count_vmem.at[row_id],
                 send_sem=send_sem,
                 recv_sem=recv_sem,
-                device_id=(0, right_id),
+                device_id=get_mesh_device_id(right_id),
                 device_id_type=pltpu.DeviceIdType.MESH,
             ).wait()
             row_id = (row_id + num_devices - 1) % num_devices
@@ -359,10 +466,8 @@ def _fused_ep_moe_kernel(
                           pl.ds(start, remote_sz)],
                 send_sem=send_sems.at[e_sem_id],
                 recv_sem=recv_sems.at[e_sem_id],
-                device_id=(
-                    0,
-                    recv_id,
-                ),
+                device_id=get_mesh_device_id(recv_id),
+                device_id_type=pltpu.DeviceIdType.MESH,
             ).start()
             a2a_s_sends_x2_smem[e_sem_id] = send_sz
 
@@ -406,7 +511,8 @@ def _fused_ep_moe_kernel(
                 dst_ref=a2a_g_hbm.at[my_e_id, pl.ds(0, remote_sz)],
                 send_sem=send_sems.at[e_sem_id],
                 recv_sem=a2a_gather_sem,
-                device_id=(0, recv_id),
+                device_id=get_mesh_device_id(recv_id),
+                device_id_type=pltpu.DeviceIdType.MESH,
             ).start()
             start += sz
 
@@ -435,68 +541,173 @@ def _fused_ep_moe_kernel(
 
     def start_fetch_bw1(local_e_id, bw1_sem_id, bf_id, bd1_id):
         for p in range(t_packing):
-            offset = p * h_per_packing + bd1_id * bd1_per_packing
+            offset = p * h_per_t_packing + bd1_id * bd1_per_t_packing
             pltpu.make_async_copy(
                 src_ref=w1_hbm.at[
                     local_e_id,
                     0,
-                    pl.ds(offset, bd1_per_packing),
+                    pl.ds(offset, bd1_per_t_packing),
                     pl.ds(bf_id * bf, bf),
                 ],
                 dst_ref=b_w1_x2_vmem.at[bw1_sem_id, p],
                 sem=local_sems.at[bw1_sem_id, 1],
             ).start()
+            if w1_scale_hbm is not None:
+                assert subc_quant_wsz is not None
+                pltpu.make_async_copy(
+                    src_ref=w1_scale_hbm.at[
+                        local_e_id,
+                        0,
+                        pl.ds(
+                            offset // subc_quant_wsz,
+                            bd1_per_t_packing // subc_quant_wsz,
+                        ),
+                        pl.ds(0, 1),
+                        pl.ds(bf_id * bf, bf),
+                    ],
+                    dst_ref=b_w1_scale_x2_vmem.at[bw1_sem_id, p],
+                    sem=local_sems.at[bw1_sem_id, 1],
+                ).start()
+        if b1_hbm is not None and bd1_id == 0:
+            pltpu.make_async_copy(
+                src_ref=b1_hbm.at[local_e_id, 0,
+                                  pl.ds(0, 1),
+                                  pl.ds(bf_id * bf, bf)],
+                dst_ref=b_b1_x2_vmem.at[bf_id % 2],
+                sem=local_sems.at[bw1_sem_id, 1],
+            ).start()
 
     def start_fetch_bw2(local_e_id, bw2_sem_id, bf_id, bd2_id):
         for p in range(t_packing):
-            offset = p * h_per_packing + bd2_id * bd2_per_packing
+            offset = p * h_per_t_packing + bd2_id * bd2_per_t_packing
             pltpu.make_async_copy(
                 src_ref=w2_hbm.at[
                     local_e_id,
                     pl.ds(bf_id * bf, bf),
-                    pl.ds(offset, bd2_per_packing),
+                    pl.ds(offset, bd2_per_t_packing),
                 ],
                 dst_ref=b_w2_x2_vmem.at[bw2_sem_id, p],
                 sem=local_sems.at[bw2_sem_id, 2],
             ).start()
+            if w2_scale_hbm is not None:
+                assert subc_quant_wsz is not None
+                pltpu.make_async_copy(
+                    src_ref=w2_scale_hbm.at[
+                        local_e_id,
+                        pl.ds(bf_id * bf // subc_quant_wsz, bf //
+                              subc_quant_wsz),
+                        pl.ds(0, 1),
+                        pl.ds(offset, bd2_per_t_packing),
+                    ],
+                    dst_ref=b_w2_scale_x2_vmem.at[bw2_sem_id, p],
+                    sem=local_sems.at[bw2_sem_id, 2],
+                ).start()
+            if b2_hbm is not None and bf_id == 0:
+                pltpu.make_async_copy(
+                    src_ref=b2_hbm.at[local_e_id,
+                                      pl.ds(0, 1),
+                                      pl.ds(offset, bd2_per_t_packing)],
+                    dst_ref=b_b2_x2_vmem.at[bd2_id % 2, p],
+                    sem=local_sems.at[bw2_sem_id, 2],
+                ).start()
 
     def start_fetch_bw3(local_e_id, bw3_sem_id, bf_id, bd3_id):
         for p in range(t_packing):
-            offset = p * h_per_packing + bd3_id * bd1_per_packing
+            offset = p * h_per_t_packing + bd3_id * bd1_per_t_packing
             pltpu.make_async_copy(
                 src_ref=w1_hbm.at[
                     local_e_id,
                     1,
-                    pl.ds(offset, bd1_per_packing),
+                    pl.ds(offset, bd1_per_t_packing),
                     pl.ds(bf_id * bf, bf),
                 ],
                 dst_ref=b_w3_x2_vmem.at[bw3_sem_id, p],
                 sem=local_sems.at[bw3_sem_id, 3],
             ).start()
+            if w1_scale_hbm is not None:
+                assert subc_quant_wsz is not None
+                pltpu.make_async_copy(
+                    src_ref=w1_scale_hbm.at[
+                        local_e_id,
+                        1,
+                        pl.ds(
+                            offset // subc_quant_wsz,
+                            bd1_per_t_packing // subc_quant_wsz,
+                        ),
+                        pl.ds(0, 1),
+                        pl.ds(bf_id * bf, bf),
+                    ],
+                    dst_ref=b_w3_scale_x2_vmem.at[bw3_sem_id, p],
+                    sem=local_sems.at[bw3_sem_id, 3],
+                ).start()
+        if b1_hbm is not None and bd3_id == 0:
+            pltpu.make_async_copy(
+                src_ref=b1_hbm.at[local_e_id, 1,
+                                  pl.ds(0, 1),
+                                  pl.ds(bf_id * bf, bf)],
+                dst_ref=b_b3_x2_vmem.at[bf_id % 2],
+                sem=local_sems.at[bw3_sem_id, 3],
+            ).start()
 
     def wait_fetch_bw1(local_e_id, bw1_sem_id, bf_id, bd1_id):
-        del local_e_id, bf_id, bd1_id
+        del local_e_id
         pltpu.make_async_copy(
             src_ref=b_w1_x2_vmem.at[bw1_sem_id],
             dst_ref=b_w1_x2_vmem.at[bw1_sem_id],
             sem=local_sems.at[bw1_sem_id, 1],
         ).wait()
+        if w1_scale_hbm is not None:
+            pltpu.make_async_copy(
+                src_ref=b_w1_scale_x2_vmem.at[bw1_sem_id],
+                dst_ref=b_w1_scale_x2_vmem.at[bw1_sem_id],
+                sem=local_sems.at[bw1_sem_id, 1],
+            ).wait()
+        if b1_hbm is not None and bd1_id == 0:
+            pltpu.make_async_copy(
+                src_ref=b_b1_x2_vmem.at[bf_id % 2],
+                dst_ref=b_b1_x2_vmem.at[bf_id % 2],
+                sem=local_sems.at[bw1_sem_id, 1],
+            ).wait()
 
     def wait_fetch_bw2(local_e_id, bw2_sem_id, bf_id, bd2_id):
-        del local_e_id, bf_id, bd2_id
+        del local_e_id
         pltpu.make_async_copy(
             src_ref=b_w2_x2_vmem.at[bw2_sem_id],
             dst_ref=b_w2_x2_vmem.at[bw2_sem_id],
             sem=local_sems.at[bw2_sem_id, 2],
         ).wait()
+        if w2_scale_hbm is not None:
+            pltpu.make_async_copy(
+                src_ref=b_w2_scale_x2_vmem.at[bw2_sem_id],
+                dst_ref=b_w2_scale_x2_vmem.at[bw2_sem_id],
+                sem=local_sems.at[bw2_sem_id, 2],
+            ).wait()
+        if b2_hbm is not None and bf_id == 0:
+            pltpu.make_async_copy(
+                src_ref=b_b2_x2_vmem.at[bd2_id % 2],
+                dst_ref=b_b2_x2_vmem.at[bd2_id % 2],
+                sem=local_sems.at[bw2_sem_id, 2],
+            ).wait()
 
     def wait_fetch_bw3(local_e_id, bw3_sem_id, bf_id, bd3_id):
-        del local_e_id, bf_id, bd3_id
+        del local_e_id
         pltpu.make_async_copy(
             src_ref=b_w3_x2_vmem.at[bw3_sem_id],
             dst_ref=b_w3_x2_vmem.at[bw3_sem_id],
             sem=local_sems.at[bw3_sem_id, 3],
         ).wait()
+        if w1_scale_hbm is not None:
+            pltpu.make_async_copy(
+                src_ref=b_w3_scale_x2_vmem.at[bw3_sem_id],
+                dst_ref=b_w3_scale_x2_vmem.at[bw3_sem_id],
+                sem=local_sems.at[bw3_sem_id, 3],
+            ).wait()
+        if b1_hbm is not None and bd3_id == 0:
+            pltpu.make_async_copy(
+                src_ref=b_b3_x2_vmem.at[bf_id % 2],
+                dst_ref=b_b3_x2_vmem.at[bf_id % 2],
+                sem=local_sems.at[bw3_sem_id, 3],
+            ).wait()
 
     def start_fetch_next_bw(local_e_id, bw_sem_id, bf_id, bd1_id, bd2_id):
         next_bd1_id = bd1_id + 1
@@ -520,18 +731,38 @@ def _fused_ep_moe_kernel(
     def dynamic_ffn1(
         t_b32_vmem,
         w1_vmem,
+        w1_scale_vmem,
+        b1_vmem,
         w3_vmem,
+        w3_scale_vmem,
+        b3_vmem,
         acc1_vmem,
         acc3_vmem,
         dyn_sz,
         should_init,
     ):
         assert t_b32_vmem.shape == (bt * num_devices, bd1 // t_packing)
-        assert w1_vmem.shape == w3_vmem.shape == (t_packing, bd1_per_packing,
+        assert w1_vmem.shape == w3_vmem.shape == (t_packing, bd1_per_t_packing,
                                                   bf)
         assert acc1_vmem.shape == acc3_vmem.shape == (bt * num_devices, bf)
         assert bd1 % (t_packing * 128) == 0, (bd1, t_packing)
         assert bd1c % (t_packing * 128) == 0, (bd1c, t_packing)
+        if w1_scale_vmem is not None:
+            assert w1_scale_vmem.shape == (
+                t_packing,
+                bd1_per_t_packing // subc_quant_wsz,
+                1,
+                bf,
+            )
+            assert bd1c_per_t_packing == subc_quant_wsz
+        if w3_scale_vmem is not None:
+            assert w3_scale_vmem.shape == (
+                t_packing,
+                bd1_per_t_packing // subc_quant_wsz,
+                1,
+                bf,
+            )
+            assert bd1c_per_t_packing == subc_quant_wsz
 
         num_loops = cdiv(dyn_sz, btc)
         repack_ty = jnp.dtype(f"int{t_bitwidth}")
@@ -540,7 +771,7 @@ def _fused_ep_moe_kernel(
             for bd1c_id in range(cdiv(bd1, bd1c)):
                 t_b32 = t_b32_vmem[
                     pl.ds(btc_id * btc, btc),
-                    pl.ds(bd1c_id * bd1c_per_packing, bd1c_per_packing),
+                    pl.ds(bd1c_id * bd1c_per_t_packing, bd1c_per_t_packing),
                 ]
                 for p_id in range(t_packing):
                     t = pltpu.bitcast(t_b32.astype(repack_ty), t_dtype)
@@ -548,21 +779,64 @@ def _fused_ep_moe_kernel(
                 for bfc_id in range(cdiv(bf, bfc)):
                     w_slices = (
                         p_id,
-                        pl.ds(bd1c_id * bd1c_per_packing,
-                              bd1c_per_packing),
+                        pl.ds(bd1c_id * bd1c_per_t_packing,
+                              bd1c_per_t_packing),
                         pl.ds(bfc_id * bfc, bfc),
                     )
                     w1 = w1_vmem[*w_slices]
                     acc1 = jnp.dot(t,
                                    w1,
                                    preferred_element_type=jnp.float32)
+
+                    if w1_scale_vmem is not None:
+                        w1_scale_slices = (
+                            p_id,
+                            bd1c_id,
+                            pl.ds(0, 1),
+                            pl.ds(bfc_id * bfc, bfc),
+                        )
+                        # TODO(jevinjiang): can use mosaic to load with stride 0.
+                        w1_scale = jnp.broadcast_to(
+                            w1_scale_vmem[*w1_scale_slices], acc1.shape)
+                        acc1 *= w1_scale
+
                     w3 = w3_vmem[*w_slices]
+
                     acc3 = jnp.dot(t,
                                    w3,
                                    preferred_element_type=jnp.float32)
+
+                    if w3_scale_vmem is not None:
+                        w3_scale_slices = (
+                            p_id,
+                            bd1c_id,
+                            pl.ds(0, 1),
+                            pl.ds(bfc_id * bfc, bfc),
+                        )
+                        w3_scale = jnp.broadcast_to(
+                            w3_scale_vmem[*w3_scale_slices], acc3.shape)
+                        acc3 *= w3_scale
+
                     acc_slices = (pl.ds(btc_id * btc,
                                         btc), pl.ds(bfc_id * bfc, bfc))
                     if should_init and p_id == bd1c_id == 0:
+                        if b1_vmem is not None:
+                            b1_scale_slices = (
+                                pl.ds(0, 1),
+                                pl.ds(bfc_id * bfc, bfc),
+                            )
+                            b1 = jnp.broadcast_to(
+                                b1_vmem[*b1_scale_slices], acc1.shape)
+                            acc1 += b1
+                        if b3_vmem is not None:
+                            b3_scale_slices = (
+                                pl.ds(0, 1),
+                                pl.ds(bfc_id * bfc, bfc),
+                            )
+                            b3 = jnp.broadcast_to(
+                                b3_vmem[*b3_scale_slices], acc1.shape)
+                            acc3 += b3
+
                         acc1_vmem[*acc_slices] = acc1
                         acc3_vmem[*acc_slices] = acc3
                     else:
@@ -575,22 +849,28 @@ def _fused_ep_moe_kernel(
         acc1_vmem,
         acc3_vmem,
         w2_vmem,
+        w2_scale_vmem,
+        b2_vmem,
         res_b32_vmem,
         dyn_sz,
         should_init,
     ):
-        assert res_b32_vmem.shape == (bt * num_devices, bd2_per_packing)
-        assert w2_vmem.shape == (t_packing, bf, bd2_per_packing), (
-            w2_vmem.shape,
-            t_packing,
-            bf,
-            bd2_per_packing,
-        )
+        assert res_b32_vmem.shape == (bt * num_devices, bd2_per_t_packing)
+        assert w2_vmem.shape == (t_packing, bf, bd2_per_t_packing)
         assert acc1_vmem.shape == acc3_vmem.shape == (bt * num_devices, bf)
         assert bd2 % (t_packing * 128) == 0, (bd2, t_packing)
         assert bd2c % (t_packing * 128) == 0, (bd2c, t_packing)
         assert t_dtype in (jnp.float32, jnp.bfloat16)
 
+        if w2_scale_vmem is not None:
+            assert w2_scale_vmem.shape == (
+                t_packing,
+                bf // subc_quant_wsz,
+                1,
+                bd2_per_t_packing,
+            )
+            assert bfc == subc_quant_wsz
+
         num_loops = cdiv(dyn_sz, btc)
         assert bd2c % (t_packing * 128) == 0, (bd2c, t_packing)
 
@@ -598,22 +878,47 @@ def _fused_ep_moe_kernel(
             for bd2c_id in range(cdiv(bd2, bd2c)):
                 res_lst = []
                 for p_id in range(t_packing):
-                    res = jnp.zeros((btc, bd2c_per_packing), dtype=jnp.float32)
+                    res = jnp.zeros((btc, bd2c_per_t_packing),
+                                    dtype=jnp.float32)
+
+                    if b2_vmem is not None and should_init:
+                        b2_scale_slices = (
+                            p_id,
+                            pl.ds(0, 1),
+                            pl.ds(bd2c_id * bd2c_per_t_packing,
+                                  bd2c_per_t_packing),
+                        )
+                        b2 = jnp.broadcast_to(b2_vmem[*b2_scale_slices],
+                                              res.shape)
+                        res += b2
+
                     for bfc_id in range(cdiv(bf, bfc)):
                         acc_slices = (pl.ds(btc_id * btc,
                                             btc), pl.ds(bfc_id * bfc, bfc))
                         acc1 = acc1_vmem[*acc_slices]
                         acc3 = acc3_vmem[*acc_slices]
-                        act = jax.nn.silu(acc1) * acc3
+                        act = activation_fn(acc1, acc3, act_fn)
                         w2 = w2_vmem[
                             p_id,
                             pl.ds(bfc_id * bfc, bfc),
                             pl.ds(bd2c_id *
-                                  bd2c_per_packing, bd2c_per_packing),
+                                  bd2c_per_t_packing, bd2c_per_t_packing),
                         ]
-                        res += jnp.dot(act,
-                                       w2,
-                                       preferred_element_type=jnp.float32)
+                        acc = jnp.dot(act,
+                                      w2,
+                                      preferred_element_type=jnp.float32)
+                        if w2_scale_vmem is not None:
+                            w2_scale_slices = (
+                                p_id,
+                                bfc_id,
+                                pl.ds(0, 1),
+                                pl.ds(bd2c_id * bd2c_per_t_packing,
+                                      bd2c_per_t_packing),
+                            )
+                            w2_scale = jnp.broadcast_to(
+                                w2_scale_vmem[*w2_scale_slices], acc.shape)
+                            acc *= w2_scale
+                        res += acc
                     res = pltpu.bitcast(res, jnp.uint32)
                     if t_packing == 2:
                         res = res >> 16 << (16 * p_id)
@@ -626,7 +931,7 @@ def _fused_ep_moe_kernel(
                     res |= res_lst[i]
                 sliced_res_vmem = res_b32_vmem.at[
                     pl.ds(btc_id * btc, btc),
-                    pl.ds(bd2c_id * bd2c_per_packing, bd2c_per_packing),
+                    pl.ds(bd2c_id * bd2c_per_t_packing, bd2c_per_t_packing),
                 ]
                 if should_init:
                     sliced_res_vmem[...] = res
@@ -655,21 +960,33 @@ def _fused_ep_moe_kernel(
         e_id = my_id * local_num_experts + local_e_id
         dyn_sz = expert_sizes_x2_smem[bt_sem_id, 0, e_id]
 
-        bd1_per_packing = bd1 // t_packing
-        bd2_per_packing = bd2 // t_packing
+        bd1_per_t_packing = bd1 // t_packing
+        bd2_per_t_packing = bd2 // t_packing
 
         for bf_id in range(num_bf):
             for bd1_id in range(num_bd1):
                 start_fetch_next_bw(local_e_id, bw_sem_id, bf_id, bd1_id, 0)
+                w1_scale_vmem = (None if b_w1_scale_x2_vmem is None else
+                                 b_w1_scale_x2_vmem.at[bw_sem_id])
+                w3_scale_vmem = (None if b_w3_scale_x2_vmem is None else
+                                 b_w3_scale_x2_vmem.at[bw_sem_id])
+                b1_vmem = None if b_b1_x2_vmem is None else b_b1_x2_vmem.at[
+                    bf_id % 2]
+                b3_vmem = None if b_b3_x2_vmem is None else b_b3_x2_vmem.at[
+                    bf_id % 2]
                 wait_fetch_bw1(local_e_id, bw_sem_id, bf_id, bd1_id)
                 wait_fetch_bw3(local_e_id, bw_sem_id, bf_id, bd1_id)
 
                 dynamic_ffn1(
                     t_b32_vmem=a2a_s_b32_vmem.at[
                         ...,
-                        pl.ds(bd1_id * bd1_per_packing, bd1_per_packing)],
+                        pl.ds(bd1_id * bd1_per_t_packing, bd1_per_t_packing)],
                     w1_vmem=b_w1_x2_vmem.at[bw_sem_id],
+                    w1_scale_vmem=w1_scale_vmem,
+                    b1_vmem=b1_vmem,
                     w3_vmem=b_w3_x2_vmem.at[bw_sem_id],
+                    w3_scale_vmem=w3_scale_vmem,
+                    b3_vmem=b3_vmem,
                     acc1_vmem=b_acc1_vmem,
                     acc3_vmem=b_acc3_vmem,
                     dyn_sz=dyn_sz,
@@ -684,13 +1001,19 @@ def _fused_ep_moe_kernel(
                 if bf_id == bd2_id == 0:
                     wait_a2a_gather_send(bt_id, e_sem_id, local_e_id - 2)
 
+                w2_scale_vmem = (None if b_w2_scale_x2_vmem is None else
+                                 b_w2_scale_x2_vmem.at[bw_sem_id])
+                b2_vmem = None if b_b2_x2_vmem is None else b_b2_x2_vmem.at[
+                    bd2_id % 2]
                 dynamic_ffn2(
                     acc1_vmem=b_acc1_vmem,
                     acc3_vmem=b_acc3_vmem,
                     w2_vmem=b_w2_x2_vmem.at[bw_sem_id],
+                    w2_scale_vmem=w2_scale_vmem,
+                    b2_vmem=b2_vmem,
                     res_b32_vmem=a2a_s_acc_b32_vmem.at[
                         ...,
-                        pl.ds(bd2_id * bd2_per_packing, bd2_per_packing)],
+                        pl.ds(bd2_id * bd2_per_t_packing, bd2_per_t_packing)],
                     dyn_sz=dyn_sz,
                     should_init=(bf_id == 0),
                 )
@@ -757,31 +1080,42 @@ def _fused_ep_moe_kernel(
         b_gating = b_gating_x2_vmem[bt_sem_id]
         b_gating_score = jax.nn.softmax(b_gating, axis=-1)
         top_k_logits_lst, t2e_routing, expert_sizes, expert_starts = get_top_k(
-            b_gating_score, top_k)
+            b_gating_score, top_k, renormalize_topk_logits)
 
         all_reduce_metadata(bt_sem_id, t2e_routing, expert_starts,
                             expert_sizes)
+        sync_barrier()
 
+        # Start a2a scatter for first active expert.
         start_a2a_scatter(bt_id=bt_id, e_sem_id=e_sem_id, local_e_id=0)
 
         def run_per_expert(local_e_id, e_sem_id):
             sync_barrier()
+
+            # Prefetch weights for CURRENT active expert.
+            # TODO(jevinjiang): It is hard to prefetch weights in the previous
+            # iteration because expert_ffn keeps overwriting the buffers.
+            # Triple buffering could resolve this, but it takes more VMEM
+            # scratch; needs further experiment.
+            start_fetch_bw1(local_e_id, bw1_sem_id=0, bf_id=0, bd1_id=0)
+            start_fetch_bw3(local_e_id, bw3_sem_id=0, bf_id=0, bd3_id=0)
+
+            # Next ids.
             next_e_sem_id = lax.select(e_sem_id == 0, 1, 0)
             next_local_e_id = local_e_id + 1
 
+            # Start a2a scatter for NEXT active expert.
             @pl.when(next_local_e_id < local_num_experts)
             def _():
                 start_a2a_scatter(bt_id, next_e_sem_id, next_local_e_id)
 
-            # Prefetch weights for active expert.
-            start_fetch_bw1(local_e_id, bw1_sem_id=0, bf_id=0, bd1_id=0)
-            start_fetch_bw3(local_e_id, bw3_sem_id=0, bf_id=0, bd3_id=0)
-
-            # Wait for a2a scatter and perform FFN for active expert.
+            # Wait a2a scatter for CURRENT active expert.
             wait_a2a_scatter_recv(bt_id, e_sem_id, local_e_id)
+
+            # Perform FFN for CURRENT active expert.
             expert_ffn(bt_id, e_sem_id, local_e_id)
 
-            # Wait for a2a gather to send back tokens for active expert.
+            # Start a2a gather to send back tokens for CURRENT active expert.
             start_a2a_gather(bt_id, e_sem_id, local_e_id)
 
             # A must-wait before next sync_barrier.
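
Note: the reordered run_per_expert is a double-buffering pipeline: kick off the all-to-all scatter for the next expert, then wait on and compute the current one, with buffers and semaphores indexed by e_sem_id. Schematically (a sketch with hypothetical start_transfer/wait_transfer/compute helpers, not the kernel's API):

    # Two buffers indexed by e % 2; the transfer for expert e + 1
    # overlaps with compute for expert e.
    for e in range(num_local_experts):
        buf = e % 2
        if e + 1 < num_local_experts:
            start_transfer(e + 1, buffer=1 - buf)  # async; overlaps compute(e)
        wait_transfer(e, buffer=buf)
        compute(e, buffer=buf)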
@@ -794,7 +1128,10 @@ def _fused_ep_moe_kernel(
                 e_sem_id,
                 unroll=False)
 
+        # Wait to receive a2a gather for ALL experts.
         wait_a2a_gather_recv_all()
+
+        # Accumulate results for current batch.
         output = bt_acc(bt_id, top_k_logits_lst)
 
         # Make sure it is safe to overwrite output buffer.
@@ -827,6 +1164,9 @@ def _fused_ep_moe_kernel(
     static_argnames=[
         "mesh",
         "top_k",
+        "renormalize_topk_logits",
+        "act_fn",
+        "subc_quant_wsz",
         "bt",
         "bf",
         "bd1",
@@ -846,6 +1186,17 @@ def fused_ep_moe(
     gating_output: jax.Array,  # (num_tokens, num_experts)
     top_k: int,
     *,
+    renormalize_topk_logits: bool = False,
+    act_fn: str = "silu",
+    subc_quant_wsz: int | None = None,
+    w1_scale: (
+        jax.Array | None
+    ) = None,  # F32(num_experts, 2, hidden_size // subc_quant_wsz, 1, intermediate_size)
+    w2_scale: (
+        jax.Array | None
+    ) = None,  # F32(num_experts, intermediate_size // subc_quant_wsz, 1, hidden_size)
+    b1: jax.Array | None = None,  # F32(num_experts, 2, 1, intermediate_size)
+    b2: jax.Array | None = None,  # F32(num_experts, 1, hidden_size)
     # Kernel tuning parameters.
     bt: int,
     bf: int,
@@ -855,52 +1206,164 @@ def fused_ep_moe(
     bfc: int,
     bd1c: int,
     bd2c: int,
-    ep_axis_name: str = 'model',
+    ep_axis_name: str = "model",
 ):
-    # Assert all other axes have length of 1
-    assert len(mesh.shape) == 2, "Expect 2D mesh in tpu-inference"
-    assert 'data' in mesh.shape and mesh.shape['data'] == 1, \
-        "Expect data axis size of 1 in tpu-inference"
+    # TODO(jevinjiang): move all these assertions to a validation function.
+    if len(mesh.shape) != 2:
+        raise NotImplementedError("Only 2D mesh is supported.")
+
+    for axis_name in mesh.axis_names:
+        if axis_name == ep_axis_name:
+            continue
+        if mesh.shape[axis_name] != 1:
+            raise NotImplementedError(
+                f"Expected all non-ep axes to have size 1 in {mesh.shape=}")
 
     ep_size = mesh.shape[ep_axis_name]
     num_devices = ep_size
 
-    num_tokens, actual_hidden_size = tokens.shape
+    num_tokens, hidden_size = tokens.shape
     num_experts, intermediate_size, _ = w2.shape
 
-    assert num_tokens % ep_size == 0
-    assert num_experts % ep_size == 0
+    if w1.shape != (num_experts, 2, hidden_size, intermediate_size):
+        raise ValueError(
+            f"Expected {w1.shape=} to be"
+            f" {(num_experts, 2, hidden_size, intermediate_size)}.")
+
+    if w2.shape != (num_experts, intermediate_size, hidden_size):
+        raise ValueError(f"Expected {w2.shape=} to be"
+                         f" {(num_experts, intermediate_size, hidden_size)}.")
+
+    if gating_output.shape != (num_tokens, num_experts):
+        raise ValueError(
+            f"Expected {gating_output.shape=} to be {(num_tokens, num_experts)}."
+        )
+
+    if not (0 < top_k <= num_experts):
+        raise ValueError(
+            f"Expected {top_k=} to be in range (0, {num_experts=}].")
+
+    if hidden_size % 128 != 0 or intermediate_size % 128 != 0:
+        raise ValueError(
+            f"Expected {hidden_size=} and {intermediate_size=} to be aligned to"
+            " 128. Did you pad them with zeros outside the kernel?")
+    if num_tokens % ep_size != 0:
+        raise ValueError(
+            f"Expected {num_tokens=} to be aligned to {ep_size=}.")
+    if num_experts % ep_size != 0:
+        raise ValueError(
+            f"Expected {num_experts=} to be aligned to {ep_size=}.")
 
     local_num_tokens = num_tokens // ep_size
     # local_num_experts = num_experts // ep_size
     padded_num_experts = align_to(num_experts, 128)
-
+    padded_top_k = align_to(top_k, 128)
     t_dtype = tokens.dtype
     t_packing = get_dtype_packing(t_dtype)
-    hidden_size = align_to(actual_hidden_size, 128 * t_packing)
-    if hidden_size != actual_hidden_size:
-        tokens = jnp.pad(
-            tokens,
-            ((0, 0), (0, hidden_size - actual_hidden_size)),
-            constant_values=0,
-        )
-    tokens = tokens.reshape(-1, t_packing, hidden_size // t_packing)
-    bt = min(bt, local_num_tokens)
-    bf = min(bf, intermediate_size)
-    bd1 = min(bd1, hidden_size)
-    bd2 = min(bd2, hidden_size)
-
-    btc = min(btc, bt * num_devices)
-    bfc = min(bfc, bf)
-    bd1c = min(bd1c, bd1)
-    bd2c = min(bd2c, bd2)
-    assert bfc % 128 == 0
-    assert bd1c % (t_packing * 128) == 0
-    assert bd2c % (t_packing * 128) == 0
-    assert bf % bfc == 0
-    assert bd1 % bd1c == 0
-    assert bd2 % bd2c == 0
 
+    # Override bt.
+    if local_num_tokens <= t_packing * 8:
+        bt = local_num_tokens
+        btc = bt
+    bt = min(local_num_tokens, bt)
+    # The worst case is that all devices send bt to one device.
+    btc = min(bt, btc, bt * num_devices)
+
+    if local_num_tokens % t_packing != 0:
+        raise ValueError(
+            f"Expected {local_num_tokens=} to be aligned to {t_packing=}.")
+
+    if bt % t_packing != 0:
+        raise ValueError(f"Expected {bt=} to be aligned to {t_packing=}.")
+    if local_num_tokens % bt != 0:
+        raise ValueError(
+            f"Expected {local_num_tokens=} to be aligned to {bt=}.")
+
+    if subc_quant_wsz is not None:
+        if subc_quant_wsz <= 0:
+            raise ValueError(f"Expected {subc_quant_wsz=} to be positive.")
+        if subc_quant_wsz % 256 != 0:
+            raise ValueError(
+                f"Expected {subc_quant_wsz=} to be aligned to 256.")
+        if hidden_size % subc_quant_wsz != 0:
+            raise ValueError(
+                f"Expected {hidden_size=} to be aligned to {subc_quant_wsz=}.")
+        if intermediate_size % subc_quant_wsz != 0:
+            raise ValueError(
+                f"Expected {intermediate_size=} to be aligned to {subc_quant_wsz=}."
+            )
+        # We force the compute size of the contracting dim to be
+        # subc_quant_wsz, so we can apply the same scale after matmul and
+        # accumulation.
+        bd1c = subc_quant_wsz * t_packing
+        bfc = subc_quant_wsz
+
+    if bfc % 128 != 0:
+        raise ValueError(f"Expected {bfc=} to be aligned to 128.")
+    if bd1c % (t_packing * 128) != 0:
+        raise ValueError(
+            f"Expected {bd1c=} to be aligned to {t_packing * 128}.")
+    if bd2c % (t_packing * 128) != 0:
+        raise ValueError(
+            f"Expected {bd2c=} to be aligned to {t_packing * 128}.")
+    if bf % bfc != 0:
+        raise ValueError(f"Expected {bf=} to be aligned to {bfc=}.")
+    if bd1 % bd1c != 0:
+        raise ValueError(f"Expected {bd1=} to be aligned to {bd1c=}.")
+    if bd2 % bd2c != 0:
+        raise ValueError(f"Expected {bd2=} to be aligned to {bd2c=}.")
+    if hidden_size % bd1 != 0 or hidden_size % bd2 != 0:
+        raise ValueError(
+            f"Expected {hidden_size=} to be aligned to {bd1=} and {bd2=}.")
+    if intermediate_size % bf != 0:
+        raise ValueError(
+            f"Expected {intermediate_size=} to be aligned to {bf=}.")
+
+    # Note: we should dump the scale in the kernel-expected shape in the
+    # checkpoint offline, or reshape right after weight loading.
+    if w1_scale is not None:
+        expected_w1_scale_shape = (
+            num_experts,
+            2,
+            hidden_size // subc_quant_wsz,
+            1,
+            intermediate_size,
+        )
+        if w1_scale.shape != expected_w1_scale_shape:
+            raise ValueError(
+                f"Expected {w1_scale.shape=} to be {expected_w1_scale_shape}.")
+        if w1_scale.dtype != jnp.float32:
+            w1_scale = w1_scale.astype(jnp.float32)
+
+    if w2_scale is not None:
+        expected_w2_scale_shape = (
+            num_experts,
+            intermediate_size // subc_quant_wsz,
+            1,
+            hidden_size,
+        )
+        if w2_scale.shape != expected_w2_scale_shape:
+            raise ValueError(
+                f"Expected {w2_scale.shape=} to be {expected_w2_scale_shape}.")
+        if w2_scale.dtype != jnp.float32:
+            w2_scale = w2_scale.astype(jnp.float32)
+
+    if b1 is not None:
+        expected_b1_shape = (num_experts, 2, 1, intermediate_size)
+        if b1.shape != expected_b1_shape:
+            raise ValueError(
+                f"Expected {b1.shape=} to be {expected_b1_shape}.")
+        if b1.dtype != jnp.float32:
+            b1 = b1.astype(jnp.float32)
+
+    if b2 is not None:
+        expected_b2_shape = (num_experts, 1, hidden_size)
+        if b2.shape != expected_b2_shape:
+            raise ValueError(
+                f"Expected {b2.shape=} to be {expected_b2_shape}.")
+        if b2.dtype != jnp.float32:
+            b2 = b2.astype(jnp.float32)
+
+    # Prepare inputs for the kernel.
     if padded_num_experts != gating_output.shape[-1]:
         gating_output = jnp.pad(
             gating_output,
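
Note: sub-channel quantization stores one float32 scale per subc_quant_wsz-wide block of the contracting dimension; because bd1c and bfc are forced to the window size, every matmul block accumulates within a single scale window, so one multiply per block is exact. A sketch of producing w2 plus w2_scale in the expected (num_experts, intermediate_size // subc_quant_wsz, 1, hidden_size) layout (an illustrative int8 recipe, not the project's checkpoint loader):

    import jax.numpy as jnp

    def subchannel_quantize_w2(w2_f32, subc_quant_wsz):
        e, inter, hidden = w2_f32.shape
        blocks = w2_f32.reshape(e, inter // subc_quant_wsz, subc_quant_wsz,
                                hidden)
        # One scale per (expert, sub-channel block, output column).
        scale = jnp.max(jnp.abs(blocks), axis=2, keepdims=True) / 127.0
        q = jnp.clip(jnp.round(blocks / scale), -127, 127).astype(jnp.int8)
        return q.reshape(e, inter, hidden), scale.astype(jnp.float32)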
@@ -908,13 +1371,20 @@ def fused_ep_moe(
             constant_values=-jnp.inf,
         )
 
-    scope_name = f"fused_moe_k-{top_k}_bt-{bt}-{btc}_bf-{bf}-{bfc}_bd1-{bd1}-{bd1c}_bd2-{bd2}-{bd2c}"
+    tokens = tokens.reshape(-1, t_packing, hidden_size // t_packing)
+
+    hbm_block_spec = pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM)
+    renorm_str = "-renorm_k" if renormalize_topk_logits else ""
+    scope_name = f"fused-moe-k_{top_k}{renorm_str}-bt_{bt}_{btc}-bf_{bf}_{bfc}-bd1_{bd1}_{bd1c}-bd2_{bd2}_{bd2c}"
     fused_moe = jax.named_scope(scope_name)(
         pl.pallas_call(
             functools.partial(
                 _fused_ep_moe_kernel,
                 top_k=top_k,
+                renormalize_topk_logits=renormalize_topk_logits,
                 ep_axis_name=ep_axis_name,
+                act_fn=act_fn,
+                subc_quant_wsz=subc_quant_wsz,
                 bt=bt,
                 bf=bf,
                 bd1=bd1,
@@ -929,16 +1399,22 @@ def fused_ep_moe(
             grid_spec=pltpu.PrefetchScalarGridSpec(
                 num_scalar_prefetch=0,
                 in_specs=[
-                    pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM),
-                    pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM),
-                    pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM),
-                    pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM),
-                    pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM),
+                    hbm_block_spec,  # tokens_hbm
+                    hbm_block_spec,  # w1_hbm
+                    hbm_block_spec,  # w2_hbm
+                    None
+                    if w1_scale is None else hbm_block_spec,  # w1_scale_hbm
+                    None
+                    if w2_scale is None else hbm_block_spec,  # w2_scale_hbm
+                    None if b1 is None else hbm_block_spec,  # b1_hbm
+                    None if b2 is None else hbm_block_spec,  # b2_hbm
+                    hbm_block_spec,  # gating_output_hbm
+                    hbm_block_spec,  # a2a_g_hbm
                 ],
                 out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM),
                 scratch_shapes=([
                     # t2e_routing_x2_smem
-                    pltpu.SMEM((2, bt, padded_num_experts), jnp.int32),
+                    pltpu.SMEM((2, bt, padded_top_k), jnp.int32),
                     # d2e_count_x2_smem
                     pltpu.SMEM((2, num_devices, 1, padded_num_experts),
                                jnp.int32),
@@ -984,6 +1460,67 @@ def fused_ep_moe(
                     pltpu.VMEM((2, t_packing, bd1 // t_packing, bf), w1.dtype),
                     # b_w2_x2_vmem
                     pltpu.VMEM((2, t_packing, bf, bd2 // t_packing), w2.dtype),
+                    # b_w1_scale_x2_vmem
+                    (None if w1_scale is None else pltpu.VMEM(
+                        (
+                            2,
+                            t_packing,
+                            bd1 // t_packing // subc_quant_wsz,
+                            1,
+                            bf,
+                        ),
+                        jnp.float32,
+                    )),
+                    # b_w3_scale_x2_vmem
+                    (None if w1_scale is None else pltpu.VMEM(
+                        (
+                            2,
+                            t_packing,
+                            bd1 // t_packing // subc_quant_wsz,
+                            1,
+                            bf,
+                        ),
+                        jnp.float32,
+                    )),
+                    # b_w2_scale_x2_vmem
+                    (None if w2_scale is None else pltpu.VMEM(
+                        (
+                            2,
+                            t_packing,
+                            bf // subc_quant_wsz,
+                            1,
+                            bd2 // t_packing,
+                        ),
+                        jnp.float32,
+                    )),
+                    # b_b1_x2_vmem
+                    (None if b1 is None else pltpu.VMEM(
+                        (
+                            2,
+                            1,
+                            bf,
+                        ),
+                        jnp.float32,
+                    )),
+                    # b_b3_x2_vmem
+                    (None if b1 is None else pltpu.VMEM(
+                        (
+                            2,
+                            1,
+                            bf,
+                        ),
+                        jnp.float32,
+                    )),
+                    # b_b2_x2_vmem
+                    (None if b2 is None else pltpu.VMEM(
+                        (
+                            2,
+                            t_packing,
+                            1,
+                            bd2 // t_packing,
+                        ),
+                        jnp.float32,
+                    )),
                     # b_acc_vmem
                     pltpu.VMEM((bt * num_devices, 1, bf * 2), jnp.float32),
                     # local_sems
@@ -1006,30 +1543,62 @@ def fused_ep_moe(
         ))
 
     @jax.jit
-    @functools.partial(
-        shard_map.shard_map,
+    @jax.shard_map(
         mesh=mesh,
-        in_specs=(P(ep_axis_name), P(ep_axis_name), P(ep_axis_name),
-                  P(ep_axis_name), P()),
+        in_specs=(
+            P(ep_axis_name),  # tokens_hbm
+            P(ep_axis_name),  # w1_hbm
+            P(ep_axis_name),  # w2_hbm
+            None if w1_scale is None else P(ep_axis_name),  # w1_scale_hbm
+            None if w2_scale is None else P(ep_axis_name),  # w2_scale_hbm
+            None if b1 is None else P(ep_axis_name),  # b1_hbm
+            None if b2 is None else P(ep_axis_name),  # b2_hbm
+            P(ep_axis_name),  # gating_output_hbm
+            P(),  # a2a_g_hbm
+        ),
         out_specs=P(ep_axis_name),
-        check_rep=False,
+        check_vma=False,
     )
-    def kernel(tokens, w1, w2, gating_output, a2a_g_hbm_scratch):
+    def kernel(
+        tokens,
+        w1,
+        w2,
+        w1_scale,
+        w2_scale,
+        b1,
+        b2,
+        gating_output,
+        a2a_g_hbm_scratch,
+    ):
         return fused_moe(
-            pltpu.with_memory_space_constraint(tokens, pltpu.HBM),
-            pltpu.with_memory_space_constraint(w1, pltpu.HBM),
-            pltpu.with_memory_space_constraint(w2, pltpu.HBM),
-            pltpu.with_memory_space_constraint(gating_output, pltpu.HBM),
-            pltpu.with_memory_space_constraint(a2a_g_hbm_scratch, pltpu.HBM),
+            pltpu.with_memory_space_constraint(tokens,
+                                               pltpu.HBM),  # tokens_hbm
+            pltpu.with_memory_space_constraint(w1, pltpu.HBM),  # w1_hbm
+            pltpu.with_memory_space_constraint(w2, pltpu.HBM),  # w2_hbm
+            (None if w1_scale is None else pltpu.with_memory_space_constraint(
+                w1_scale, pltpu.HBM)),  # w1_scale_hbm
+            (None if w2_scale is None else pltpu.with_memory_space_constraint(
+                w2_scale, pltpu.HBM)),  # w2_scale_hbm
+            (None if b1 is None else pltpu.with_memory_space_constraint(
+                b1, pltpu.HBM)),  # b1_hbm
+            (None if b2 is None else pltpu.with_memory_space_constraint(
+                b2, pltpu.HBM)),  # b2_hbm
+            pltpu.with_memory_space_constraint(gating_output,
+                                               pltpu.HBM),  # gating_output_hbm
+            pltpu.with_memory_space_constraint(a2a_g_hbm_scratch,
+                                               pltpu.HBM),  # a2a_g_hbm
        )
 
     a2a_g_hbm_scratch = pl.empty(
         (num_experts, bt, t_packing, hidden_size // t_packing), t_dtype)
-    results = kernel(
+    return kernel(
         tokens,
         w1,
         w2,
+        w1_scale,
+        w2_scale,
+        b1,
+        b2,
         gating_output,
         a2a_g_hbm_scratch,
     )
-    return results[:, :actual_hidden_size]
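
Note: an end-to-end invocation sketch, assuming the leading positional parameters of fused_ep_moe are (mesh, tokens, w1, w2, gating_output, top_k) as the surrounding hunks suggest. The block sizes below are placeholders, not tuned values, and must satisfy the alignment checks added above (for bf16 tokens, t_packing == 2, so bd1c and bd2c must be multiples of 256):

    import jax
    import numpy as np
    from jax.sharding import Mesh

    # tokens, w1, w2, gating_output as in the ref_moe sketch earlier,
    # with hidden_size and intermediate_size divisible by bd1/bd2 and bf.
    devices = jax.devices()
    mesh = Mesh(np.array(devices).reshape(1, len(devices)), ("data", "model"))
    out = fused_ep_moe(
        mesh, tokens, w1, w2, gating_output, top_k=2,
        renormalize_topk_logits=True, act_fn="silu",
        bt=64, bf=512, bd1=512, bd2=512,
        btc=64, bfc=256, bd1c=512, bd2c=256,
        ep_axis_name="model",
    )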