tpu-inference 0.11.1.dev202511150811__py3-none-any.whl → 0.11.1.dev202512030818__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic. Click here for more details.

Files changed (54):
  1. tests/kernels/fused_moe_v1_test.py +303 -34
  2. tests/lora/test_layers.py +0 -6
  3. tests/lora/utils.py +0 -8
  4. tests/test_envs.py +32 -11
  5. tests/test_utils.py +1 -2
  6. tpu_inference/__init__.py +22 -3
  7. tpu_inference/core/disagg_utils.py +6 -8
  8. tpu_inference/distributed/tpu_connector.py +3 -4
  9. tpu_inference/distributed/utils.py +3 -2
  10. tpu_inference/envs.py +61 -8
  11. tpu_inference/executors/ray_distributed_executor.py +31 -11
  12. tpu_inference/kernels/fused_moe/v1/kernel.py +641 -110
  13. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +77 -54
  14. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +213 -126
  15. tpu_inference/layers/common/attention_interface.py +7 -1
  16. tpu_inference/layers/common/sharding.py +5 -5
  17. tpu_inference/layers/vllm/fused_moe.py +74 -25
  18. tpu_inference/layers/vllm/quantization/common.py +6 -1
  19. tpu_inference/layers/vllm/quantization/mxfp4.py +137 -62
  20. tpu_inference/layers/vllm/quantization/unquantized.py +107 -113
  21. tpu_inference/layers/vllm/sharding.py +2 -2
  22. tpu_inference/lora/torch_punica_tpu.py +1 -2
  23. tpu_inference/models/common/model_loader.py +45 -11
  24. tpu_inference/models/jax/llama3.py +2 -1
  25. tpu_inference/models/jax/llama_eagle3.py +8 -5
  26. tpu_inference/models/jax/llama_guard_4.py +361 -0
  27. tpu_inference/models/jax/qwen2.py +2 -1
  28. tpu_inference/models/jax/qwen2_5_vl.py +163 -48
  29. tpu_inference/models/jax/qwen3.py +2 -1
  30. tpu_inference/models/jax/utils/quantization/quantization_utils.py +3 -6
  31. tpu_inference/models/jax/utils/weight_utils.py +198 -143
  32. tpu_inference/models/vllm/vllm_model_wrapper.py +14 -7
  33. tpu_inference/platforms/tpu_platform.py +28 -22
  34. tpu_inference/runner/compilation_manager.py +144 -59
  35. tpu_inference/runner/kv_cache_manager.py +17 -18
  36. tpu_inference/runner/persistent_batch_manager.py +40 -2
  37. tpu_inference/runner/structured_decoding_manager.py +2 -3
  38. tpu_inference/runner/tpu_runner.py +271 -147
  39. tpu_inference/runner/utils.py +2 -2
  40. tpu_inference/spec_decode/jax/eagle3.py +71 -21
  41. tpu_inference/tpu_info.py +4 -3
  42. tpu_inference/utils.py +36 -13
  43. tpu_inference/worker/tpu_worker.py +162 -25
  44. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/METADATA +3 -2
  45. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/RECORD +48 -53
  46. tpu_inference/mock/__init__.py +0 -0
  47. tpu_inference/mock/vllm_config_utils.py +0 -28
  48. tpu_inference/mock/vllm_envs.py +0 -1219
  49. tpu_inference/mock/vllm_logger.py +0 -212
  50. tpu_inference/mock/vllm_logging_utils.py +0 -15
  51. tpu_inference/models/jax/phi3.py +0 -376
  52. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/WHEEL +0 -0
  53. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/licenses/LICENSE +0 -0
  54. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/top_level.txt +0 -0
@@ -7,7 +7,6 @@ import jax.numpy as jnp
7
7
  from jax import lax
8
8
  from jax._src import dtypes
9
9
  from jax.experimental import pallas as pl
10
- from jax.experimental import shard_map
11
10
  from jax.experimental.pallas import tpu as pltpu
12
11
 
13
12
  P = jax.sharding.PartitionSpec
@@ -35,13 +34,49 @@ def broadcast_minor(src, shape):
35
34
  axis=-1)[..., :shape[-1]]
36
35
 
37
36
 
37
def swigluoai(gate: jax.Array,
              up: jax.Array,
              *,
              alpha: float = 1.702,
              limit: float = 7.0) -> jax.Array:
    """Clamped SwiGLU activation used in some models such as GPT-OSS.

    Computes ``(clip(up, -limit, limit) + 1) * g * sigmoid(alpha * g)`` with
    ``g = min(gate, limit)``. Note the asymmetry: the gate is clamped only
    from above, while the up projection is clamped on both sides.

    Args:
      gate: Gate projection (in this file, the w1 projection of the MoE FFN).
      up: Up projection (the w3 projection); must broadcast with ``gate``.
      alpha: Scale applied to the gate inside the sigmoid.
      limit: Clamp bound used for both inputs.

    Returns:
      The activated array, with the broadcast shape of ``gate`` and ``up``.
    """
    # Use jnp.minimum / positional jnp.clip bounds rather than the deprecated
    # ``a_min``/``a_max`` keywords so this works across JAX versions without
    # DeprecationWarnings. Python-float bounds are weakly typed, so the input
    # dtype (e.g. bfloat16) is preserved.
    gate = jnp.minimum(gate, limit)
    up = jnp.clip(up, -limit, limit)
    glu = gate * jax.nn.sigmoid(alpha * gate)
    return (up + 1.0) * glu
47
+
48
+
49
+ def activation_fn(acc1, acc3, act_fn):
50
+ if act_fn == "silu":
51
+ return jax.nn.silu(acc1) * acc3
52
+ elif act_fn == "gelu":
53
+ return jax.nn.gelu(acc1) * acc3
54
+ elif act_fn == "swigluoai":
55
+ return swigluoai(acc1, acc3)
56
+ else:
57
+ raise RuntimeError(f"Unsupported activation function: {act_fn}")
58
+
59
+
38
60
  def ref_moe(
39
- tokens: jax.Array, # (num_tokens, hidden_size)
40
- w1: jax.Array, # (num_experts, 2, hidden_size, intermediate_size)
41
- w2: jax.Array, # (num_experts, intermediate_size, hidden_size)
42
- gating_output: jax.Array, # (num_tokens, num_experts)
43
- top_k: int,
44
- activation="silu",
61
+ tokens: jax.Array, # (num_tokens, hidden_size)
62
+ w1: jax.Array, # (num_experts, 2, hidden_size, intermediate_size)
63
+ w2: jax.Array, # (num_experts, intermediate_size, hidden_size)
64
+ gating_output: jax.Array, # (num_tokens, num_experts)
65
+ top_k: int,
66
+ *,
67
+ renormalize_topk_logits: bool = False,
68
+ activation="silu",
69
+ subc_quant_wsz: int | None = None,
70
+ w1_scale:
71
+ (
72
+ jax.Array | None
73
+ ) = None, # (num_experts, 2, cdiv(hidden_size, subc_quant_wsz), intermediate_size)
74
+ w2_scale:
75
+ (
76
+ jax.Array | None
77
+ ) = None, # (num_experts, cdiv(intermediate_size, subc_quant_wsz), hidden_size)
78
+ b1: jax.Array | None = None, # (num_experts, 2, intermediate_size)
79
+ b2: jax.Array | None = None, # (num_experts, hidden_size)
45
80
  ):
46
81
  n_tokens = tokens.shape[0] # num_tokens
47
82
 
@@ -53,7 +88,12 @@ def ref_moe(
53
88
  top_k_logits, top_k_indices = lax.top_k(
54
89
  gating_logits, top_k) # [num_tokens, top_k], [num_tokens, top_k]
55
90
 
91
+ if renormalize_topk_logits:
92
+ top_k_logits = top_k_logits / jnp.sum(
93
+ top_k_logits, axis=-1, keepdims=True)
94
+
56
95
  t_outputs = []
96
+ hidden_size, intermediate_size = w1.shape[-2:]
57
97
 
58
98
  # Process each token individually
59
99
  for i in range(n_tokens):
@@ -65,10 +105,24 @@ def ref_moe(
65
105
  # Process each selected expert for the current token
66
106
  for expert_id in assigned_expert_ids:
67
107
  # Get expert weights
108
+ expert_w1 = w1[expert_id, 0].astype(jnp.float32)
109
+ expert_w3 = w1[expert_id, 1].astype(jnp.float32)
110
+ if w1_scale is not None:
111
+ expert_w1 *= jnp.repeat(w1_scale[expert_id, 0],
112
+ subc_quant_wsz,
113
+ axis=0)[:hidden_size]
114
+ expert_w3 *= jnp.repeat(w1_scale[expert_id, 1],
115
+ subc_quant_wsz,
116
+ axis=0)[:hidden_size]
68
117
  expert_weight_1 = jnp.concat(
69
- [w1[expert_id, 0], w1[expert_id, 1]],
118
+ [expert_w1, expert_w3],
70
119
  axis=-1) # [d_model, 2 * intermediate_size]
71
- expert_weight_2 = w2[expert_id] # [intermediate_size, d_model]
120
+ expert_weight_2 = w2[expert_id].astype(
121
+ jnp.float32) # [intermediate_size, d_model]
122
+ if w2_scale is not None:
123
+ expert_weight_2 *= jnp.repeat(w2_scale[expert_id],
124
+ subc_quant_wsz,
125
+ axis=0)[:intermediate_size]
72
126
 
73
127
  # First linear layer with SwiGLU activation
74
128
  gmm_1_out = curr_token @ expert_weight_1 # [1, 2 * intermediate_size]
@@ -77,21 +131,17 @@ def ref_moe(
77
131
  gmm1_w1_proj, gmm1_w3_proj = jnp.split(
78
132
  gmm_1_out, 2,
79
133
  axis=-1) # [1, intermediate_size], [1, intermediate_size]
134
+ if b1 is not None:
135
+ gmm1_w1_proj += b1[expert_id:expert_id + 1, 0]
136
+ gmm1_w3_proj += b1[expert_id:expert_id + 1, 1]
80
137
 
81
138
  # Apply gated activation: activation(gate) * up
82
- if activation == "silu":
83
- act = jax.nn.silu(
84
- gmm1_w1_proj) * gmm1_w3_proj # [1, intermediate_size]
85
- elif activation == "gelu":
86
- act = jax.nn.gelu(
87
- gmm1_w1_proj) * gmm1_w3_proj # [1, intermediate_size]
88
- else:
89
- raise ValueError(
90
- f"Unsupported activation: {activation}. Use 'silu' or 'gelu'."
91
- )
139
+ act = activation_fn(gmm1_w1_proj, gmm1_w3_proj, activation)
92
140
 
93
141
  # Second linear layer (down projection)
94
142
  gmm_2_out = act @ expert_weight_2 # [1, d_model]
143
+ if b2 is not None:
144
+ gmm_2_out += b2[expert_id:expert_id + 1]
95
145
  tok_expert_act.append(gmm_2_out)
96
146
 
97
147
  # Combine outputs from all selected experts
@@ -105,7 +155,7 @@ def ref_moe(
105
155
  axis=0,
106
156
  keepdims=True) # [1, d_model]
107
157
 
108
- t_outputs.append(weighted_output)
158
+ t_outputs.append(weighted_output.astype(tokens.dtype))
109
159
 
110
160
  return jnp.concatenate(t_outputs, axis=0) # [num_tokens, d_model]
111
161
 
@@ -115,6 +165,13 @@ def _fused_ep_moe_kernel(
115
165
  tokens_hbm, # (local_num_tokens, t_packing, hidden_size // t_packing)
116
166
  w1_hbm, # (local_num_experts, 2, hidden_size, intermediate_size)
117
167
  w2_hbm, # (local_num_experts, intermediate_size, hidden_size)
168
+ # TODO(jevinjiang): We choose F32 scale for easier slicing. The extra
169
+ # latency should be hidden in the pipeline overlaping. But is there a better
170
+ # way to do this?
171
+ w1_scale_hbm, # None | F32(local_num_experts, 2, cdiv(hidden_size, subc_quant_wsz), 1, intermediate_size)
172
+ w2_scale_hbm, # None | F32(local_num_experts, cdiv(intermediate_size, subc_quant_wsz), 1, hidden_size)
173
+ b1_hbm, # None | F32(local_num_experts, 2, 1, intermediate_size)
174
+ b2_hbm, # None | F32(local_num_experts, 1, hidden_size)
118
175
  gating_hbm, # (local_num_tokens, padded_num_experts)
119
176
  a2a_g_hbm, # (num_experts, bt, t_packing, hidden_size // t_packing)
120
177
  # Output
@@ -136,6 +193,12 @@ def _fused_ep_moe_kernel(
136
193
  b_w1_x2_vmem, # <bw_sem_id> (2, t_packing, bd1 // t_packing, bf)
137
194
  b_w3_x2_vmem, # <bw_sem_id> (2, t_packing, bd1 // t_packing, bf)
138
195
  b_w2_x2_vmem, # <bw_sem_id> (2, t_packing, bf, bd2 // t_packing)
196
+ b_w1_scale_x2_vmem, # None | <bw_sem_id> (2, t_packing, bd1 // t_packing // subc_quant_wsz, 1, bf)
197
+ b_w3_scale_x2_vmem, # None | <bw_sem_id> (2, t_packing, bd1 // t_packing // subc_quant_wsz, 1, bf)
198
+ b_w2_scale_x2_vmem, # None | <bw_sem_id> (2, t_packing, bf // subc_quant_wsz, 1, bd2 // t_packing)
199
+ b_b1_x2_vmem, # None | <bw_sem_id> (2, 1, bf)
200
+ b_b3_x2_vmem, # None | <bw_sem_id> (2, 1, bf)
201
+ b_b2_x2_vmem, # None | <bw_sem_id> (2, t_packing, 1, bd2 // t_packing)
139
202
  b_acc_vmem, # F32(bt * num_devices, 1, bf * 2)
140
203
  ### Semaphores:
141
204
  local_sems, # (2, 5): 2 x [b_gating_sem, b_w1_sem, b_w2_sem, b_w3_sem, b_output_sem]
@@ -145,7 +208,10 @@ def _fused_ep_moe_kernel(
145
208
  a2a_acc_sem,
146
209
  *,
147
210
  top_k: int,
211
+ renormalize_topk_logits: bool,
148
212
  ep_axis_name: str,
213
+ act_fn: str,
214
+ subc_quant_wsz: int | None = None,
149
215
  # Kernel tuning params.
150
216
  bt: int, # Block size of local_num_tokens.
151
217
  bf: int, # Block size of intermediate_size.
@@ -160,34 +226,53 @@ def _fused_ep_moe_kernel(
160
226
  num_devices = lax.axis_size(ep_axis_name)
161
227
  local_num_tokens = tokens_hbm.shape[0]
162
228
  local_num_experts, intermediate_size, hidden_size = w2_hbm.shape
163
- # num_experts = local_num_experts * num_devices
164
- # padded_num_experts = expert_starts_x2_smem.shape[-1]
165
229
  right_id = (my_id + 1) % num_devices
166
230
 
167
231
  t_dtype = tokens_hbm.dtype
168
232
  t_packing = get_dtype_packing(t_dtype)
169
233
  t_bitwidth = 32 // t_packing
170
234
  assert a2a_g_hbm.dtype == t_dtype
171
- assert w1_hbm.dtype == t_dtype
172
- assert w2_hbm.dtype == t_dtype
235
+ assert w1_hbm.dtype == w2_hbm.dtype
173
236
 
174
- h_per_packing = hidden_size // t_packing
175
- assert tokens_hbm.shape[-1] == h_per_packing
176
- bd1_per_packing = bd1 // t_packing
177
- bd2_per_packing = bd2 // t_packing
178
- bd1c_per_packing = bd1c // t_packing
179
- bd2c_per_packing = bd2c // t_packing
237
+ assert bd1 % bd1c == 0
238
+ assert bd2 % bd2c == 0
239
+ assert bf % bfc == 0
240
+ assert hidden_size % t_packing == 0
241
+ assert bd1 % t_packing == 0
242
+ assert bd2 % t_packing == 0
243
+ assert bd1c % t_packing == 0
244
+ assert bd2c % t_packing == 0
245
+
246
+ h_per_t_packing = hidden_size // t_packing
247
+ assert tokens_hbm.shape[-1] == h_per_t_packing
248
+ bd1_per_t_packing = bd1 // t_packing
249
+ bd2_per_t_packing = bd2 // t_packing
250
+ bd1c_per_t_packing = bd1c // t_packing
251
+ bd2c_per_t_packing = bd2c // t_packing
252
+
253
+ if subc_quant_wsz is not None:
254
+ assert subc_quant_wsz % 256 == 0
255
+ assert bd1c_per_t_packing == subc_quant_wsz
256
+ assert bfc == subc_quant_wsz
257
+ assert bd1 % subc_quant_wsz == 0
258
+ assert bf % subc_quant_wsz == 0
259
+ assert bd1_per_t_packing % subc_quant_wsz == 0
260
+ assert h_per_t_packing % subc_quant_wsz == 0
180
261
 
181
262
  num_bt = cdiv(local_num_tokens, bt)
182
263
  num_bf = cdiv(intermediate_size, bf)
183
264
  num_bd1 = cdiv(hidden_size, bd1)
184
265
  num_bd2 = cdiv(hidden_size, bd2)
185
266
 
267
+ def get_mesh_device_id(ep_rank):
268
+ dp_rank = jax.lax.axis_index("data")
269
+ return (dp_rank, ep_rank)
270
+
186
271
  def sync_barrier():
187
272
  barrier_sem = pltpu.get_barrier_semaphore()
188
273
  pltpu.semaphore_signal(
189
274
  barrier_sem,
190
- device_id=(0, right_id),
275
+ device_id=get_mesh_device_id(right_id),
191
276
  device_id_type=pltpu.DeviceIdType.MESH,
192
277
  )
193
278
  pltpu.semaphore_wait(barrier_sem, 1)
@@ -212,7 +297,7 @@ def _fused_ep_moe_kernel(
212
297
  sem=b_gating_sem,
213
298
  ).wait()
214
299
 
215
- def get_top_k(input, top_k):
300
+ def get_top_k(input, top_k, renormalize_topk_logits):
216
301
  assert len(input.shape) == 2, input.shape
217
302
  input = input.astype(jnp.float32)
218
303
  top_k_logits_lst = []
@@ -220,11 +305,15 @@ def _fused_ep_moe_kernel(
220
305
  t2e = jnp.zeros(input.shape, dtype=jnp.int32)
221
306
  t2e_routing = jnp.zeros(input.shape, dtype=jnp.int32)
222
307
  iota = jax.lax.broadcasted_iota(jnp.int32, input.shape, 1)
308
+ top_k_logits_sum = jnp.zeros((input.shape[0], 128), jnp.float32)
309
+
223
310
  for k_id in range(top_k):
224
- # TODO(jevinjiang): return both top_k values and indices in op in Mosaic
311
+ # TODO(jevinjiang): return both top_k values and indices in Mosaic
225
312
  top_k_logits = jnp.broadcast_to(
226
313
  jnp.max(input, axis=1, keepdims=True),
227
314
  (input.shape[0], 128)).astype(input.dtype)
315
+ if renormalize_topk_logits:
316
+ top_k_logits_sum += top_k_logits
228
317
  top_k_logits_lst.append(top_k_logits)
229
318
  # TODO(jevinjiang): support bf16 argmax in Mosaic
230
319
  top_k_indices = jnp.broadcast_to(
@@ -236,6 +325,11 @@ def _fused_ep_moe_kernel(
236
325
  if k_id != top_k - 1:
237
326
  input = jnp.where(mask, -jnp.inf, input)
238
327
 
328
+ if renormalize_topk_logits:
329
+ for k_id in range(top_k):
330
+ top_k_logits_lst[
331
+ k_id] = top_k_logits_lst[k_id] / top_k_logits_sum
332
+
239
333
  expert_sizes = jnp.sum(t2e, axis=0, keepdims=True)
240
334
  expert_starts = jnp.zeros_like(expert_sizes)
241
335
  return top_k_logits_lst, t2e_routing, expert_sizes, expert_starts
@@ -277,7 +371,7 @@ def _fused_ep_moe_kernel(
277
371
  dst_ref=d2e_count_vmem.at[row_id],
278
372
  send_sem=send_sem,
279
373
  recv_sem=recv_sem,
280
- device_id=(0, right_id),
374
+ device_id=get_mesh_device_id(right_id),
281
375
  device_id_type=pltpu.DeviceIdType.MESH,
282
376
  ).wait()
283
377
  row_id = (row_id + num_devices - 1) % num_devices
@@ -359,10 +453,8 @@ def _fused_ep_moe_kernel(
359
453
  pl.ds(start, remote_sz)],
360
454
  send_sem=send_sems.at[e_sem_id],
361
455
  recv_sem=recv_sems.at[e_sem_id],
362
- device_id=(
363
- 0,
364
- recv_id,
365
- ),
456
+ device_id=get_mesh_device_id(recv_id),
457
+ device_id_type=pltpu.DeviceIdType.MESH,
366
458
  ).start()
367
459
  a2a_s_sends_x2_smem[e_sem_id] = send_sz
368
460
 
@@ -406,7 +498,8 @@ def _fused_ep_moe_kernel(
406
498
  dst_ref=a2a_g_hbm.at[my_e_id, pl.ds(0, remote_sz)],
407
499
  send_sem=send_sems.at[e_sem_id],
408
500
  recv_sem=a2a_gather_sem,
409
- device_id=(0, recv_id),
501
+ device_id=get_mesh_device_id(recv_id),
502
+ device_id_type=pltpu.DeviceIdType.MESH,
410
503
  ).start()
411
504
  start += sz
412
505
 
@@ -435,68 +528,173 @@ def _fused_ep_moe_kernel(
435
528
 
436
529
  def start_fetch_bw1(local_e_id, bw1_sem_id, bf_id, bd1_id):
437
530
  for p in range(t_packing):
438
- offset = p * h_per_packing + bd1_id * bd1_per_packing
531
+ offset = p * h_per_t_packing + bd1_id * bd1_per_t_packing
439
532
  pltpu.make_async_copy(
440
533
  src_ref=w1_hbm.at[
441
534
  local_e_id,
442
535
  0,
443
- pl.ds(offset, bd1_per_packing),
536
+ pl.ds(offset, bd1_per_t_packing),
444
537
  pl.ds(bf_id * bf, bf),
445
538
  ],
446
539
  dst_ref=b_w1_x2_vmem.at[bw1_sem_id, p],
447
540
  sem=local_sems.at[bw1_sem_id, 1],
448
541
  ).start()
542
+ if w1_scale_hbm is not None:
543
+ assert subc_quant_wsz is not None
544
+ pltpu.make_async_copy(
545
+ src_ref=w1_scale_hbm.at[
546
+ local_e_id,
547
+ 0,
548
+ pl.ds(
549
+ offset // subc_quant_wsz,
550
+ bd1_per_t_packing // subc_quant_wsz,
551
+ ),
552
+ pl.ds(0, 1),
553
+ pl.ds(bf_id * bf, bf),
554
+ ],
555
+ dst_ref=b_w1_scale_x2_vmem.at[bw1_sem_id, p],
556
+ sem=local_sems.at[bw1_sem_id, 1],
557
+ ).start()
558
+ if b1_hbm is not None and bd1_id == 0:
559
+ pltpu.make_async_copy(
560
+ src_ref=b1_hbm.at[local_e_id, 0,
561
+ pl.ds(0, 1),
562
+ pl.ds(bf_id * bf, bf)],
563
+ dst_ref=b_b1_x2_vmem.at[bf_id % 2],
564
+ sem=local_sems.at[bw1_sem_id, 1],
565
+ ).start()
449
566
 
450
567
  def start_fetch_bw2(local_e_id, bw2_sem_id, bf_id, bd2_id):
451
568
  for p in range(t_packing):
452
- offset = p * h_per_packing + bd2_id * bd2_per_packing
569
+ offset = p * h_per_t_packing + bd2_id * bd2_per_t_packing
453
570
  pltpu.make_async_copy(
454
571
  src_ref=w2_hbm.at[
455
572
  local_e_id,
456
573
  pl.ds(bf_id * bf, bf),
457
- pl.ds(offset, bd2_per_packing),
574
+ pl.ds(offset, bd2_per_t_packing),
458
575
  ],
459
576
  dst_ref=b_w2_x2_vmem.at[bw2_sem_id, p],
460
577
  sem=local_sems.at[bw2_sem_id, 2],
461
578
  ).start()
579
+ if w2_scale_hbm is not None:
580
+ assert subc_quant_wsz is not None
581
+ pltpu.make_async_copy(
582
+ src_ref=w2_scale_hbm.at[
583
+ local_e_id,
584
+ pl.ds(bf_id * bf // subc_quant_wsz, bf //
585
+ subc_quant_wsz),
586
+ pl.ds(0, 1),
587
+ pl.ds(offset, bd2_per_t_packing),
588
+ ],
589
+ dst_ref=b_w2_scale_x2_vmem.at[bw2_sem_id, p],
590
+ sem=local_sems.at[bw2_sem_id, 2],
591
+ ).start()
592
+ if b2_hbm is not None and bf_id == 0:
593
+ pltpu.make_async_copy(
594
+ src_ref=b2_hbm.at[local_e_id,
595
+ pl.ds(0, 1),
596
+ pl.ds(offset, bd2_per_t_packing)],
597
+ dst_ref=b_b2_x2_vmem.at[bd2_id % 2, p],
598
+ sem=local_sems.at[bw2_sem_id, 2],
599
+ ).start()
462
600
 
463
601
  def start_fetch_bw3(local_e_id, bw3_sem_id, bf_id, bd3_id):
464
602
  for p in range(t_packing):
465
- offset = p * h_per_packing + bd3_id * bd1_per_packing
603
+ offset = p * h_per_t_packing + bd3_id * bd1_per_t_packing
466
604
  pltpu.make_async_copy(
467
605
  src_ref=w1_hbm.at[
468
606
  local_e_id,
469
607
  1,
470
- pl.ds(offset, bd1_per_packing),
608
+ pl.ds(offset, bd1_per_t_packing),
471
609
  pl.ds(bf_id * bf, bf),
472
610
  ],
473
611
  dst_ref=b_w3_x2_vmem.at[bw3_sem_id, p],
474
612
  sem=local_sems.at[bw3_sem_id, 3],
475
613
  ).start()
614
+ if w1_scale_hbm is not None:
615
+ assert subc_quant_wsz is not None
616
+ pltpu.make_async_copy(
617
+ src_ref=w1_scale_hbm.at[
618
+ local_e_id,
619
+ 1,
620
+ pl.ds(
621
+ offset // subc_quant_wsz,
622
+ bd1_per_t_packing // subc_quant_wsz,
623
+ ),
624
+ pl.ds(0, 1),
625
+ pl.ds(bf_id * bf, bf),
626
+ ],
627
+ dst_ref=b_w3_scale_x2_vmem.at[bw3_sem_id, p],
628
+ sem=local_sems.at[bw3_sem_id, 3],
629
+ ).start()
630
+ if b1_hbm is not None and bd3_id == 0:
631
+ pltpu.make_async_copy(
632
+ src_ref=b1_hbm.at[local_e_id, 1,
633
+ pl.ds(0, 1),
634
+ pl.ds(bf_id * bf, bf)],
635
+ dst_ref=b_b3_x2_vmem.at[bf_id % 2],
636
+ sem=local_sems.at[bw3_sem_id, 3],
637
+ ).start()
476
638
 
477
639
  def wait_fetch_bw1(local_e_id, bw1_sem_id, bf_id, bd1_id):
478
- del local_e_id, bf_id, bd1_id
640
+ del local_e_id
479
641
  pltpu.make_async_copy(
480
642
  src_ref=b_w1_x2_vmem.at[bw1_sem_id],
481
643
  dst_ref=b_w1_x2_vmem.at[bw1_sem_id],
482
644
  sem=local_sems.at[bw1_sem_id, 1],
483
645
  ).wait()
646
+ if w1_scale_hbm is not None:
647
+ pltpu.make_async_copy(
648
+ src_ref=b_w1_scale_x2_vmem.at[bw1_sem_id],
649
+ dst_ref=b_w1_scale_x2_vmem.at[bw1_sem_id],
650
+ sem=local_sems.at[bw1_sem_id, 1],
651
+ ).wait()
652
+ if b1_hbm is not None and bd1_id == 0:
653
+ pltpu.make_async_copy(
654
+ src_ref=b_b1_x2_vmem.at[bf_id % 2],
655
+ dst_ref=b_b1_x2_vmem.at[bf_id % 2],
656
+ sem=local_sems.at[bw1_sem_id, 1],
657
+ ).wait()
484
658
 
485
659
  def wait_fetch_bw2(local_e_id, bw2_sem_id, bf_id, bd2_id):
486
- del local_e_id, bf_id, bd2_id
660
+ del local_e_id
487
661
  pltpu.make_async_copy(
488
662
  src_ref=b_w2_x2_vmem.at[bw2_sem_id],
489
663
  dst_ref=b_w2_x2_vmem.at[bw2_sem_id],
490
664
  sem=local_sems.at[bw2_sem_id, 2],
491
665
  ).wait()
666
+ if w2_scale_hbm is not None:
667
+ pltpu.make_async_copy(
668
+ src_ref=b_w2_scale_x2_vmem.at[bw2_sem_id],
669
+ dst_ref=b_w2_scale_x2_vmem.at[bw2_sem_id],
670
+ sem=local_sems.at[bw2_sem_id, 2],
671
+ ).wait()
672
+ if b2_hbm is not None and bf_id == 0:
673
+ pltpu.make_async_copy(
674
+ src_ref=b_b2_x2_vmem.at[bd2_id % 2],
675
+ dst_ref=b_b2_x2_vmem.at[bd2_id % 2],
676
+ sem=local_sems.at[bw2_sem_id, 2],
677
+ ).wait()
492
678
 
493
679
  def wait_fetch_bw3(local_e_id, bw3_sem_id, bf_id, bd3_id):
494
- del local_e_id, bf_id, bd3_id
680
+ del local_e_id
495
681
  pltpu.make_async_copy(
496
682
  src_ref=b_w3_x2_vmem.at[bw3_sem_id],
497
683
  dst_ref=b_w3_x2_vmem.at[bw3_sem_id],
498
684
  sem=local_sems.at[bw3_sem_id, 3],
499
685
  ).wait()
686
+ if w1_scale_hbm is not None:
687
+ pltpu.make_async_copy(
688
+ src_ref=b_w3_scale_x2_vmem.at[bw3_sem_id],
689
+ dst_ref=b_w3_scale_x2_vmem.at[bw3_sem_id],
690
+ sem=local_sems.at[bw3_sem_id, 3],
691
+ ).wait()
692
+ if b1_hbm is not None and bd3_id == 0:
693
+ pltpu.make_async_copy(
694
+ src_ref=b_b3_x2_vmem.at[bf_id % 2],
695
+ dst_ref=b_b3_x2_vmem.at[bf_id % 2],
696
+ sem=local_sems.at[bw3_sem_id, 3],
697
+ ).wait()
500
698
 
501
699
  def start_fetch_next_bw(local_e_id, bw_sem_id, bf_id, bd1_id, bd2_id):
502
700
  next_bd1_id = bd1_id + 1
@@ -520,18 +718,38 @@ def _fused_ep_moe_kernel(
520
718
  def dynamic_ffn1(
521
719
  t_b32_vmem,
522
720
  w1_vmem,
721
+ w1_scale_vmem,
722
+ b1_vmem,
523
723
  w3_vmem,
724
+ w3_scale_vmem,
725
+ b3_vmem,
524
726
  acc1_vmem,
525
727
  acc3_vmem,
526
728
  dyn_sz,
527
729
  should_init,
528
730
  ):
529
731
  assert t_b32_vmem.shape == (bt * num_devices, bd1 // t_packing)
530
- assert w1_vmem.shape == w3_vmem.shape == (t_packing, bd1_per_packing,
732
+ assert w1_vmem.shape == w3_vmem.shape == (t_packing, bd1_per_t_packing,
531
733
  bf)
532
734
  assert acc1_vmem.shape == acc3_vmem.shape == (bt * num_devices, bf)
533
735
  assert bd1 % (t_packing * 128) == 0, (bd1, t_packing)
534
736
  assert bd1c % (t_packing * 128) == 0, (bd1c, t_packing)
737
+ if w1_scale_vmem is not None:
738
+ assert w1_scale_vmem.shape == (
739
+ t_packing,
740
+ bd1_per_t_packing // subc_quant_wsz,
741
+ 1,
742
+ bf,
743
+ )
744
+ assert bd1c_per_t_packing == subc_quant_wsz
745
+ if w3_scale_vmem is not None:
746
+ assert w3_scale_vmem.shape == (
747
+ t_packing,
748
+ bd1_per_t_packing // subc_quant_wsz,
749
+ 1,
750
+ bf,
751
+ )
752
+ assert bd1c_per_t_packing == subc_quant_wsz
535
753
 
536
754
  num_loops = cdiv(dyn_sz, btc)
537
755
  repack_ty = jnp.dtype(f"int{t_bitwidth}")
@@ -540,7 +758,7 @@ def _fused_ep_moe_kernel(
540
758
  for bd1c_id in range(cdiv(bd1, bd1c)):
541
759
  t_b32 = t_b32_vmem[
542
760
  pl.ds(btc_id * btc, btc),
543
- pl.ds(bd1c_id * bd1c_per_packing, bd1c_per_packing),
761
+ pl.ds(bd1c_id * bd1c_per_t_packing, bd1c_per_t_packing),
544
762
  ]
545
763
  for p_id in range(t_packing):
546
764
  t = pltpu.bitcast(t_b32.astype(repack_ty), t_dtype)
@@ -548,21 +766,64 @@ def _fused_ep_moe_kernel(
548
766
  for bfc_id in range(cdiv(bf, bfc)):
549
767
  w_slices = (
550
768
  p_id,
551
- pl.ds(bd1c_id * bd1c_per_packing,
552
- bd1c_per_packing),
769
+ pl.ds(bd1c_id * bd1c_per_t_packing,
770
+ bd1c_per_t_packing),
553
771
  pl.ds(bfc_id * bfc, bfc),
554
772
  )
555
773
  w1 = w1_vmem[*w_slices]
556
774
  acc1 = jnp.dot(t,
557
775
  w1,
558
776
  preferred_element_type=jnp.float32)
777
+
778
+ if w1_scale_vmem is not None:
779
+ w1_scale_slices = (
780
+ p_id,
781
+ bd1c_id,
782
+ pl.ds(0, 1),
783
+ pl.ds(bfc_id * bfc, bfc),
784
+ )
785
+ # TODO(jevinjiang): can use mosaic to load with stride 0.
786
+ w1_scale = jnp.broadcast_to(
787
+ w1_scale_vmem[*w1_scale_slices], acc1.shape)
788
+ acc1 *= w1_scale
789
+
559
790
  w3 = w3_vmem[*w_slices]
791
+
560
792
  acc3 = jnp.dot(t,
561
793
  w3,
562
794
  preferred_element_type=jnp.float32)
795
+
796
+ if w3_scale_vmem is not None:
797
+ w3_scale_slices = (
798
+ p_id,
799
+ bd1c_id,
800
+ pl.ds(0, 1),
801
+ pl.ds(bfc_id * bfc, bfc),
802
+ )
803
+ w3_scale = jnp.broadcast_to(
804
+ w3_scale_vmem[*w3_scale_slices], acc3.shape)
805
+ acc3 *= w3_scale
806
+
563
807
  acc_slices = (pl.ds(btc_id * btc,
564
808
  btc), pl.ds(bfc_id * bfc, bfc))
565
809
  if should_init and p_id == bd1c_id == 0:
810
+ if b1_vmem is not None:
811
+ b1_scale_slices = (
812
+ pl.ds(0, 1),
813
+ pl.ds(bfc_id * bfc, bfc),
814
+ )
815
+ b1 = jnp.broadcast_to(
816
+ b1_vmem[*b1_scale_slices], acc1.shape)
817
+ acc1 += b1
818
+ if b3_vmem is not None:
819
+ b3_scale_slices = (
820
+ pl.ds(0, 1),
821
+ pl.ds(bfc_id * bfc, bfc),
822
+ )
823
+ b3 = jnp.broadcast_to(
824
+ b3_vmem[*b3_scale_slices], acc1.shape)
825
+ acc3 += b3
826
+
566
827
  acc1_vmem[*acc_slices] = acc1
567
828
  acc3_vmem[*acc_slices] = acc3
568
829
  else:
@@ -575,22 +836,28 @@ def _fused_ep_moe_kernel(
575
836
  acc1_vmem,
576
837
  acc3_vmem,
577
838
  w2_vmem,
839
+ w2_scale_vmem,
840
+ b2_vmem,
578
841
  res_b32_vmem,
579
842
  dyn_sz,
580
843
  should_init,
581
844
  ):
582
- assert res_b32_vmem.shape == (bt * num_devices, bd2_per_packing)
583
- assert w2_vmem.shape == (t_packing, bf, bd2_per_packing), (
584
- w2_vmem.shape,
585
- t_packing,
586
- bf,
587
- bd2_per_packing,
588
- )
845
+ assert res_b32_vmem.shape == (bt * num_devices, bd2_per_t_packing)
846
+ assert w2_vmem.shape == (t_packing, bf, bd2_per_t_packing)
589
847
  assert acc1_vmem.shape == acc3_vmem.shape == (bt * num_devices, bf)
590
848
  assert bd2 % (t_packing * 128) == 0, (bd2, t_packing)
591
849
  assert bd2c % (t_packing * 128) == 0, (bd2c, t_packing)
592
850
  assert t_dtype in (jnp.float32, jnp.bfloat16)
593
851
 
852
+ if w2_scale_vmem is not None:
853
+ assert w2_scale_vmem.shape == (
854
+ t_packing,
855
+ bf // subc_quant_wsz,
856
+ 1,
857
+ bd2_per_t_packing,
858
+ )
859
+ assert bfc == subc_quant_wsz
860
+
594
861
  num_loops = cdiv(dyn_sz, btc)
595
862
  assert bd2c % (t_packing * 128) == 0, (bd2c, t_packing)
596
863
 
@@ -598,22 +865,47 @@ def _fused_ep_moe_kernel(
598
865
  for bd2c_id in range(cdiv(bd2, bd2c)):
599
866
  res_lst = []
600
867
  for p_id in range(t_packing):
601
- res = jnp.zeros((btc, bd2c_per_packing), dtype=jnp.float32)
868
+ res = jnp.zeros((btc, bd2c_per_t_packing),
869
+ dtype=jnp.float32)
870
+
871
+ if b2_vmem is not None and should_init:
872
+ b2_scale_slices = (
873
+ p_id,
874
+ pl.ds(0, 1),
875
+ pl.ds(bd2c_id * bd2c_per_t_packing,
876
+ bd2c_per_t_packing),
877
+ )
878
+ b2 = jnp.broadcast_to(b2_vmem[*b2_scale_slices],
879
+ res.shape)
880
+ res += b2
881
+
602
882
  for bfc_id in range(cdiv(bf, bfc)):
603
883
  acc_slices = (pl.ds(btc_id * btc,
604
884
  btc), pl.ds(bfc_id * bfc, bfc))
605
885
  acc1 = acc1_vmem[*acc_slices]
606
886
  acc3 = acc3_vmem[*acc_slices]
607
- act = jax.nn.silu(acc1) * acc3
887
+ act = activation_fn(acc1, acc3, act_fn)
608
888
  w2 = w2_vmem[
609
889
  p_id,
610
890
  pl.ds(bfc_id * bfc, bfc),
611
891
  pl.ds(bd2c_id *
612
- bd2c_per_packing, bd2c_per_packing),
892
+ bd2c_per_t_packing, bd2c_per_t_packing),
613
893
  ]
614
- res += jnp.dot(act,
615
- w2,
616
- preferred_element_type=jnp.float32)
894
+ acc = jnp.dot(act,
895
+ w2,
896
+ preferred_element_type=jnp.float32)
897
+ if w2_scale_vmem is not None:
898
+ w2_scale_slices = (
899
+ p_id,
900
+ bfc_id,
901
+ pl.ds(0, 1),
902
+ pl.ds(bd2c_id * bd2c_per_t_packing,
903
+ bd2c_per_t_packing),
904
+ )
905
+ w2_scale = jnp.broadcast_to(
906
+ w2_scale_vmem[*w2_scale_slices], acc.shape)
907
+ acc *= w2_scale
908
+ res += acc
617
909
  res = pltpu.bitcast(res, jnp.uint32)
618
910
  if t_packing == 2:
619
911
  res = res >> 16 << (16 * p_id)
@@ -626,7 +918,7 @@ def _fused_ep_moe_kernel(
626
918
  res |= res_lst[i]
627
919
  sliced_res_vmem = res_b32_vmem.at[
628
920
  pl.ds(btc_id * btc, btc),
629
- pl.ds(bd2c_id * bd2c_per_packing, bd2c_per_packing),
921
+ pl.ds(bd2c_id * bd2c_per_t_packing, bd2c_per_t_packing),
630
922
  ]
631
923
  if should_init:
632
924
  sliced_res_vmem[...] = res
@@ -655,21 +947,33 @@ def _fused_ep_moe_kernel(
655
947
  e_id = my_id * local_num_experts + local_e_id
656
948
  dyn_sz = expert_sizes_x2_smem[bt_sem_id, 0, e_id]
657
949
 
658
- bd1_per_packing = bd1 // t_packing
659
- bd2_per_packing = bd2 // t_packing
950
+ bd1_per_t_packing = bd1 // t_packing
951
+ bd2_per_t_packing = bd2 // t_packing
660
952
 
661
953
  for bf_id in range(num_bf):
662
954
  for bd1_id in range(num_bd1):
663
955
  start_fetch_next_bw(local_e_id, bw_sem_id, bf_id, bd1_id, 0)
956
+ w1_scale_vmem = (None if b_w1_scale_x2_vmem is None else
957
+ b_w1_scale_x2_vmem.at[bw_sem_id])
958
+ w3_scale_vmem = (None if b_w3_scale_x2_vmem is None else
959
+ b_w3_scale_x2_vmem.at[bw_sem_id])
960
+ b1_vmem = None if b_b1_x2_vmem is None else b_b1_x2_vmem.at[
961
+ bf_id % 2]
962
+ b3_vmem = None if b_b3_x2_vmem is None else b_b3_x2_vmem.at[
963
+ bf_id % 2]
664
964
  wait_fetch_bw1(local_e_id, bw_sem_id, bf_id, bd1_id)
665
965
  wait_fetch_bw3(local_e_id, bw_sem_id, bf_id, bd1_id)
666
966
 
667
967
  dynamic_ffn1(
668
968
  t_b32_vmem=a2a_s_b32_vmem.at[
669
969
  ...,
670
- pl.ds(bd1_id * bd1_per_packing, bd1_per_packing)],
970
+ pl.ds(bd1_id * bd1_per_t_packing, bd1_per_t_packing)],
671
971
  w1_vmem=b_w1_x2_vmem.at[bw_sem_id],
972
+ w1_scale_vmem=w1_scale_vmem,
973
+ b1_vmem=b1_vmem,
672
974
  w3_vmem=b_w3_x2_vmem.at[bw_sem_id],
975
+ w3_scale_vmem=w3_scale_vmem,
976
+ b3_vmem=b3_vmem,
673
977
  acc1_vmem=b_acc1_vmem,
674
978
  acc3_vmem=b_acc3_vmem,
675
979
  dyn_sz=dyn_sz,
@@ -684,13 +988,19 @@ def _fused_ep_moe_kernel(
684
988
  if bf_id == bd2_id == 0:
685
989
  wait_a2a_gather_send(bt_id, e_sem_id, local_e_id - 2)
686
990
 
991
+ w2_scale_vmem = (None if b_w2_scale_x2_vmem is None else
992
+ b_w2_scale_x2_vmem.at[bw_sem_id])
993
+ b2_vmem = None if b_b2_x2_vmem is None else b_b2_x2_vmem.at[
994
+ bd2_id % 2]
687
995
  dynamic_ffn2(
688
996
  acc1_vmem=b_acc1_vmem,
689
997
  acc3_vmem=b_acc3_vmem,
690
998
  w2_vmem=b_w2_x2_vmem.at[bw_sem_id],
999
+ w2_scale_vmem=w2_scale_vmem,
1000
+ b2_vmem=b2_vmem,
691
1001
  res_b32_vmem=a2a_s_acc_b32_vmem.at[
692
1002
  ...,
693
- pl.ds(bd2_id * bd2_per_packing, bd2_per_packing)],
1003
+ pl.ds(bd2_id * bd2_per_t_packing, bd2_per_t_packing)],
694
1004
  dyn_sz=dyn_sz,
695
1005
  should_init=(bf_id == 0),
696
1006
  )
@@ -757,7 +1067,7 @@ def _fused_ep_moe_kernel(
757
1067
  b_gating = b_gating_x2_vmem[bt_sem_id]
758
1068
  b_gating_score = jax.nn.softmax(b_gating, axis=-1)
759
1069
  top_k_logits_lst, t2e_routing, expert_sizes, expert_starts = get_top_k(
760
- b_gating_score, top_k)
1070
+ b_gating_score, top_k, renormalize_topk_logits)
761
1071
 
762
1072
  all_reduce_metadata(bt_sem_id, t2e_routing, expert_starts,
763
1073
  expert_sizes)
@@ -827,6 +1137,9 @@ def _fused_ep_moe_kernel(
827
1137
  static_argnames=[
828
1138
  "mesh",
829
1139
  "top_k",
1140
+ "renormalize_topk_logits",
1141
+ "act_fn",
1142
+ "subc_quant_wsz",
830
1143
  "bt",
831
1144
  "bf",
832
1145
  "bd1",
@@ -845,7 +1158,18 @@ def fused_ep_moe(
845
1158
  w2: jax.Array, # (num_experts, intermediate_size, hidden_size)
846
1159
  gating_output: jax.Array, # (num_tokens, num_experts)
847
1160
  top_k: int,
1161
+ renormalize_topk_logits: bool = False,
1162
+ act_fn: str = "silu",
848
1163
  *,
1164
+ subc_quant_wsz: int | None = None,
1165
+ w1_scale: (
1166
+ jax.Array | None
1167
+ ) = None, # (num_experts, 2, cdiv(hidden_size, subc_quant_wsz), intermediate_size)
1168
+ w2_scale: (
1169
+ jax.Array | None
1170
+ ) = None, # (num_experts, cdiv(intermediate_size, subc_quant_wsz), hidden_size)
1171
+ b1: jax.Array | None = None, # (num_experts, 2, intermediate_size)
1172
+ b2: jax.Array | None = None, # (num_experts, hidden_size)
849
1173
  # Kernel tuning parameters.
850
1174
  bt: int,
851
1175
  bf: int,
@@ -855,18 +1179,19 @@ def fused_ep_moe(
855
1179
  bfc: int,
856
1180
  bd1c: int,
857
1181
  bd2c: int,
858
- ep_axis_name: str = 'model',
1182
+ ep_axis_name: str = "model",
859
1183
  ):
1184
+ # TODO(jevinjiang): move all these assertions to validation function.
860
1185
  # Assert all other axes have length of 1
861
- assert len(mesh.shape) == 2, "Expect 2D mesh in tpu-inference"
862
- assert 'data' in mesh.shape and mesh.shape['data'] == 1, \
863
- "Expect data axis size of 1 in tpu-inference"
1186
+ assert len(mesh.shape) == 2, "Expect 2D mesh"
1187
+ assert ("data" in mesh.shape
1188
+ and mesh.shape["data"] == 1), "Expect data axis size of 1"
864
1189
 
865
1190
  ep_size = mesh.shape[ep_axis_name]
866
1191
  num_devices = ep_size
867
1192
 
868
1193
  num_tokens, actual_hidden_size = tokens.shape
869
- num_experts, intermediate_size, _ = w2.shape
1194
+ num_experts, actual_intermediate_size, _ = w2.shape
870
1195
 
871
1196
  assert num_tokens % ep_size == 0
872
1197
  assert num_experts % ep_size == 0
@@ -874,26 +1199,18 @@ def fused_ep_moe(
874
1199
  local_num_tokens = num_tokens // ep_size
875
1200
  # local_num_experts = num_experts // ep_size
876
1201
  padded_num_experts = align_to(num_experts, 128)
877
-
878
1202
  t_dtype = tokens.dtype
879
1203
  t_packing = get_dtype_packing(t_dtype)
880
- hidden_size = align_to(actual_hidden_size, 128 * t_packing)
881
- if hidden_size != actual_hidden_size:
882
- tokens = jnp.pad(
883
- tokens,
884
- ((0, 0), (0, hidden_size - actual_hidden_size)),
885
- constant_values=0,
886
- )
887
- tokens = tokens.reshape(-1, t_packing, hidden_size // t_packing)
888
- bt = min(bt, local_num_tokens)
889
- bf = min(bf, intermediate_size)
890
- bd1 = min(bd1, hidden_size)
891
- bd2 = min(bd2, hidden_size)
892
1204
 
893
- btc = min(btc, bt * num_devices)
894
- bfc = min(bfc, bf)
895
- bd1c = min(bd1c, bd1)
896
- bd2c = min(bd2c, bd2)
1205
+ if subc_quant_wsz is not None:
1206
+ if subc_quant_wsz % 256 != 0:
1207
+ raise NotImplementedError(
1208
+ "Sub-quantized window is not aligned to 256.")
1209
+ # We force compute size of contracting dim to subc_quant_wsz. So we can
1210
+ # apply same scale after matmul and accumulation.
1211
+ bd1c = subc_quant_wsz * t_packing
1212
+ bfc = subc_quant_wsz
1213
+
897
1214
  assert bfc % 128 == 0
898
1215
  assert bd1c % (t_packing * 128) == 0
899
1216
  assert bd2c % (t_packing * 128) == 0
@@ -901,6 +1218,41 @@ def fused_ep_moe(
901
1218
  assert bd1 % bd1c == 0
902
1219
  assert bd2 % bd2c == 0
903
1220
 
1221
+ btc = min(btc, bt * num_devices)
1222
+ hidden_size = align_to(actual_hidden_size, 128 * t_packing)
1223
+ # TODO(jevinjiang): instead of padding outside the kernel, we can try dynammic
1224
+ # masking inside the kernel.
1225
+ hidden_size = align_to(hidden_size, bd1)
1226
+ hidden_size = align_to(hidden_size, bd2)
1227
+ intermediate_size = align_to(actual_intermediate_size, bf)
1228
+
1229
+ # TODO(jevinjiang): we should dump scale as the kernel expected shape in the
1230
+ # checkpoint offline or reshape right after weight loading.
1231
+ if w1_scale is not None:
1232
+ assert w1_scale.shape[0] == w1.shape[0]
1233
+ assert w1_scale.shape[1] == w1.shape[1] == 2
1234
+ assert w1_scale.shape[2] == cdiv(w1.shape[2], subc_quant_wsz)
1235
+ assert w1_scale.shape[3] == w1.shape[3]
1236
+ w1_scale = jnp.expand_dims(w1_scale.astype(jnp.float32), axis=-2)
1237
+
1238
+ if w2_scale is not None:
1239
+ assert w2_scale.shape[0] == w2.shape[0]
1240
+ assert w2_scale.shape[1] == cdiv(w2.shape[1], subc_quant_wsz)
1241
+ assert w2_scale.shape[2] == w2.shape[2]
1242
+ w2_scale = jnp.expand_dims(w2_scale.astype(jnp.float32), axis=-2)
1243
+
1244
+ if b1 is not None:
1245
+ assert b1.shape[0] == w1.shape[0]
1246
+ assert b1.shape[1] == w1.shape[1] == 2
1247
+ assert b1.shape[2] == w1.shape[3]
1248
+ b1 = jnp.expand_dims(b1.astype(jnp.float32), axis=-2)
1249
+
1250
+ if b2 is not None:
1251
+ assert b2.shape[0] == w2.shape[0]
1252
+ assert b2.shape[1] == w2.shape[2]
1253
+ b2 = jnp.expand_dims(b2.astype(jnp.float32), axis=-2)
1254
+
1255
+ # Prepare inputs for the kernel.
904
1256
  if padded_num_experts != gating_output.shape[-1]:
905
1257
  gating_output = jnp.pad(
906
1258
  gating_output,
@@ -908,13 +1260,92 @@ def fused_ep_moe(
908
1260
  constant_values=-jnp.inf,
909
1261
  )
910
1262
 
911
- scope_name = f"fused_moe_k-{top_k}_bt-{bt}-{btc}_bf-{bf}-{bfc}_bd1-{bd1}-{bd1c}_bd2-{bd2}-{bd2c}"
1263
+ if (hidden_size != actual_hidden_size
1264
+ or intermediate_size != actual_intermediate_size):
1265
+ tokens = jnp.pad(
1266
+ tokens,
1267
+ ((0, 0), (0, hidden_size - actual_hidden_size)),
1268
+ constant_values=0,
1269
+ )
1270
+ w1 = jnp.pad(
1271
+ w1,
1272
+ (
1273
+ (0, 0),
1274
+ (0, 0),
1275
+ (0, hidden_size - actual_hidden_size),
1276
+ (0, intermediate_size - actual_intermediate_size),
1277
+ ),
1278
+ constant_values=0,
1279
+ )
1280
+ w2 = jnp.pad(
1281
+ w2,
1282
+ (
1283
+ (0, 0),
1284
+ (0, intermediate_size - actual_intermediate_size),
1285
+ (0, hidden_size - actual_hidden_size),
1286
+ ),
1287
+ constant_values=0,
1288
+ )
1289
+ if w1_scale is not None:
1290
+ w1_scale = jnp.pad(
1291
+ w1_scale,
1292
+ (
1293
+ (0, 0),
1294
+ (0, 0),
1295
+ (0,
1296
+ cdiv(hidden_size, subc_quant_wsz) - w1_scale.shape[-3]),
1297
+ (0, 0),
1298
+ (0, intermediate_size - w1_scale.shape[-1]),
1299
+ ),
1300
+ constant_values=0,
1301
+ )
1302
+ if w2_scale is not None:
1303
+ w2_scale = jnp.pad(
1304
+ w2_scale,
1305
+ (
1306
+ (0, 0),
1307
+ (0, cdiv(intermediate_size, subc_quant_wsz) -
1308
+ w2_scale.shape[-3]),
1309
+ (0, 0),
1310
+ (0, hidden_size - w2_scale.shape[-1]),
1311
+ ),
1312
+ constant_values=0,
1313
+ )
1314
+ if b1 is not None:
1315
+ b1 = jnp.pad(
1316
+ b1,
1317
+ (
1318
+ (0, 0),
1319
+ (0, 0),
1320
+ (0, 0),
1321
+ (0, intermediate_size - b1.shape[-1]),
1322
+ ),
1323
+ constant_values=0,
1324
+ )
1325
+ if b2 is not None:
1326
+ b2 = jnp.pad(
1327
+ b2,
1328
+ (
1329
+ (0, 0),
1330
+ (0, 0),
1331
+ (0, hidden_size - b2.shape[-1]),
1332
+ ),
1333
+ constant_values=0,
1334
+ )
1335
+
1336
+ tokens = tokens.reshape(-1, t_packing, hidden_size // t_packing)
1337
+
1338
+ hbm_block_spec = pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM)
1339
+ scope_name = f"fused_moe_k-{top_k}_renorm-{renormalize_topk_logits}_bt-{bt}-{btc}_bf-{bf}-{bfc}_bd1-{bd1}-{bd1c}_bd2-{bd2}-{bd2c}"
912
1340
  fused_moe = jax.named_scope(scope_name)(
913
1341
  pl.pallas_call(
914
1342
  functools.partial(
915
1343
  _fused_ep_moe_kernel,
916
1344
  top_k=top_k,
1345
+ renormalize_topk_logits=renormalize_topk_logits,
917
1346
  ep_axis_name=ep_axis_name,
1347
+ act_fn=act_fn,
1348
+ subc_quant_wsz=subc_quant_wsz,
918
1349
  bt=bt,
919
1350
  bf=bf,
920
1351
  bd1=bd1,
@@ -929,11 +1360,17 @@ def fused_ep_moe(
929
1360
  grid_spec=pltpu.PrefetchScalarGridSpec(
930
1361
  num_scalar_prefetch=0,
931
1362
  in_specs=[
932
- pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM),
933
- pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM),
934
- pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM),
935
- pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM),
936
- pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM),
1363
+ hbm_block_spec, # tokens_hbm
1364
+ hbm_block_spec, # w1_hbm
1365
+ hbm_block_spec, # w2_hbm
1366
+ None
1367
+ if w1_scale is None else hbm_block_spec, # w1_scale_hbm
1368
+ None
1369
+ if w2_scale is None else hbm_block_spec, # w2_scale_hbm
1370
+ None if b1 is None else hbm_block_spec, # b1_hbm
1371
+ None if b2 is None else hbm_block_spec, # b2_hbm
1372
+ hbm_block_spec, # gating_output_hbm
1373
+ hbm_block_spec, # a2a_g_hbm
937
1374
  ],
938
1375
  out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM),
939
1376
  scratch_shapes=([
@@ -984,6 +1421,67 @@ def fused_ep_moe(
984
1421
  pltpu.VMEM((2, t_packing, bd1 // t_packing, bf), w1.dtype),
985
1422
  # b_w2_x2_vmem
986
1423
  pltpu.VMEM((2, t_packing, bf, bd2 // t_packing), w2.dtype),
1424
+ # b_w1_scale_x2_vmem
1425
+ (None if w1_scale is None else pltpu.VMEM(
1426
+ (
1427
+ 2,
1428
+ t_packing,
1429
+ bd1 // t_packing // subc_quant_wsz,
1430
+ 1,
1431
+ bf,
1432
+ ),
1433
+ jnp.float32,
1434
+ )),
1435
+ # b_w3_scale_x2_vmem
1436
+ (None if w1_scale is None else pltpu.VMEM(
1437
+ (
1438
+ 2,
1439
+ t_packing,
1440
+ bd1 // t_packing // subc_quant_wsz,
1441
+ 1,
1442
+ bf,
1443
+ ),
1444
+ jnp.float32,
1445
+ )),
1446
+ # b_w2_scale_x2_vmem
1447
+ (None if w2_scale is None else pltpu.VMEM(
1448
+ (
1449
+ 2,
1450
+ t_packing,
1451
+ bf // subc_quant_wsz,
1452
+ 1,
1453
+ bd2 // t_packing,
1454
+ ),
1455
+ jnp.float32,
1456
+ )),
1457
+ # b_b1_x2_vmem
1458
+ (None if b1 is None else pltpu.VMEM(
1459
+ (
1460
+ 2,
1461
+ 1,
1462
+ bf,
1463
+ ),
1464
+ jnp.float32,
1465
+ )),
1466
+ # b_b3_x2_vmem
1467
+ (None if b1 is None else pltpu.VMEM(
1468
+ (
1469
+ 2,
1470
+ 1,
1471
+ bf,
1472
+ ),
1473
+ jnp.float32,
1474
+ )),
1475
+ # b_b2_x2_vmem
1476
+ (None if b2 is None else pltpu.VMEM(
1477
+ (
1478
+ 2,
1479
+ t_packing,
1480
+ 1,
1481
+ bd2 // t_packing,
1482
+ ),
1483
+ jnp.float32,
1484
+ )),
987
1485
  # b_acc_vmem
988
1486
  pltpu.VMEM((bt * num_devices, 1, bf * 2), jnp.float32),
989
1487
  # local_sems
@@ -1006,21 +1504,50 @@ def fused_ep_moe(
1006
1504
  ))
1007
1505
 
1008
1506
  @jax.jit
1009
- @functools.partial(
1010
- shard_map.shard_map,
1507
+ @jax.shard_map(
1011
1508
  mesh=mesh,
1012
- in_specs=(P(ep_axis_name), P(ep_axis_name), P(ep_axis_name),
1013
- P(ep_axis_name), P()),
1509
+ in_specs=(
1510
+ P(ep_axis_name), # tokens_hbm
1511
+ P(ep_axis_name), # w1_hbm
1512
+ P(ep_axis_name), # w2_hbm
1513
+ None if w1_scale is None else P(ep_axis_name), # w1_scale_hbm
1514
+ None if w2_scale is None else P(ep_axis_name), # w2_scale_hbm
1515
+ None if b1 is None else P(ep_axis_name), # b1_hbm
1516
+ None if b2 is None else P(ep_axis_name), # b2_hbm
1517
+ P(ep_axis_name), # gating_output_hbm
1518
+ P(), # a2a_g_hbm
1519
+ ),
1014
1520
  out_specs=P(ep_axis_name),
1015
- check_rep=False,
1521
+ check_vma=False,
1016
1522
  )
1017
- def kernel(tokens, w1, w2, gating_output, a2a_g_hbm_scratch):
1523
+ def kernel(
1524
+ tokens,
1525
+ w1,
1526
+ w2,
1527
+ w1_scale,
1528
+ w2_scale,
1529
+ b1,
1530
+ b2,
1531
+ gating_output,
1532
+ a2a_g_hbm_scratch,
1533
+ ):
1018
1534
  return fused_moe(
1019
- pltpu.with_memory_space_constraint(tokens, pltpu.HBM),
1020
- pltpu.with_memory_space_constraint(w1, pltpu.HBM),
1021
- pltpu.with_memory_space_constraint(w2, pltpu.HBM),
1022
- pltpu.with_memory_space_constraint(gating_output, pltpu.HBM),
1023
- pltpu.with_memory_space_constraint(a2a_g_hbm_scratch, pltpu.HBM),
1535
+ pltpu.with_memory_space_constraint(tokens,
1536
+ pltpu.HBM), # tokens_hbm
1537
+ pltpu.with_memory_space_constraint(w1, pltpu.HBM), # w1_hbm
1538
+ pltpu.with_memory_space_constraint(w2, pltpu.HBM), # w2_hbm
1539
+ (None if w1_scale is None else pltpu.with_memory_space_constraint(
1540
+ w1_scale, pltpu.HBM)), # w1_scale_hbm
1541
+ (None if w2_scale is None else pltpu.with_memory_space_constraint(
1542
+ w2_scale, pltpu.HBM)), # w2_scale_hbm
1543
+ (None if b1 is None else pltpu.with_memory_space_constraint(
1544
+ b1, pltpu.HBM)), # b1_hbm
1545
+ (None if b2 is None else pltpu.with_memory_space_constraint(
1546
+ b2, pltpu.HBM)), # b2_hbm
1547
+ pltpu.with_memory_space_constraint(gating_output,
1548
+ pltpu.HBM), # gating_output_hbm
1549
+ pltpu.with_memory_space_constraint(a2a_g_hbm_scratch,
1550
+ pltpu.HBM), # a2a_g_hbm
1024
1551
  )
1025
1552
 
1026
1553
  a2a_g_hbm_scratch = pl.empty(
@@ -1029,6 +1556,10 @@ def fused_ep_moe(
1029
1556
  tokens,
1030
1557
  w1,
1031
1558
  w2,
1559
+ w1_scale,
1560
+ w2_scale,
1561
+ b1,
1562
+ b2,
1032
1563
  gating_output,
1033
1564
  a2a_g_hbm_scratch,
1034
1565
  )