tpu-inference: tpu_inference-0.0.1rc1-py3-none-any.whl → tpu_inference-0.11.1.dev202511180814-py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic. Click here for more details.

Files changed (56):
  1. tests/kernels/fused_moe_v1_test.py +34 -303
  2. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +2 -2
  3. tests/lora/test_layers.py +6 -0
  4. tests/lora/utils.py +8 -0
  5. tests/test_envs.py +11 -32
  6. tests/test_utils.py +2 -1
  7. tpu_inference/__init__.py +3 -22
  8. tpu_inference/core/disagg_utils.py +8 -6
  9. tpu_inference/distributed/tpu_connector.py +4 -3
  10. tpu_inference/distributed/utils.py +2 -3
  11. tpu_inference/envs.py +8 -61
  12. tpu_inference/executors/ray_distributed_executor.py +2 -9
  13. tpu_inference/kernels/fused_moe/v1/kernel.py +110 -641
  14. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +54 -77
  15. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +145 -266
  16. tpu_inference/layers/common/attention_interface.py +1 -7
  17. tpu_inference/layers/common/sharding.py +5 -5
  18. tpu_inference/layers/vllm/fused_moe.py +208 -170
  19. tpu_inference/layers/vllm/quantization/common.py +1 -6
  20. tpu_inference/layers/vllm/quantization/mxfp4.py +73 -138
  21. tpu_inference/layers/vllm/quantization/unquantized.py +64 -58
  22. tpu_inference/layers/vllm/sharding.py +2 -2
  23. tpu_inference/lora/torch_punica_tpu.py +2 -1
  24. tpu_inference/mock/__init__.py +0 -0
  25. tpu_inference/mock/vllm_config_utils.py +28 -0
  26. tpu_inference/mock/vllm_envs.py +1219 -0
  27. tpu_inference/mock/vllm_logger.py +212 -0
  28. tpu_inference/mock/vllm_logging_utils.py +15 -0
  29. tpu_inference/models/common/model_loader.py +10 -43
  30. tpu_inference/models/jax/llama3.py +1 -2
  31. tpu_inference/models/jax/llama_eagle3.py +5 -8
  32. tpu_inference/models/jax/phi3.py +376 -0
  33. tpu_inference/models/jax/qwen2.py +1 -2
  34. tpu_inference/models/jax/qwen2_5_vl.py +48 -163
  35. tpu_inference/models/jax/qwen3.py +1 -2
  36. tpu_inference/models/jax/utils/quantization/quantization_utils.py +6 -3
  37. tpu_inference/models/jax/utils/weight_utils.py +143 -198
  38. tpu_inference/models/vllm/vllm_model_wrapper.py +8 -14
  39. tpu_inference/platforms/tpu_platform.py +31 -37
  40. tpu_inference/runner/compilation_manager.py +58 -141
  41. tpu_inference/runner/kv_cache.py +1 -1
  42. tpu_inference/runner/kv_cache_manager.py +18 -17
  43. tpu_inference/runner/persistent_batch_manager.py +2 -40
  44. tpu_inference/runner/structured_decoding_manager.py +3 -2
  45. tpu_inference/runner/tpu_runner.py +147 -271
  46. tpu_inference/runner/utils.py +2 -2
  47. tpu_inference/spec_decode/jax/eagle3.py +21 -71
  48. tpu_inference/tpu_info.py +3 -4
  49. tpu_inference/utils.py +13 -36
  50. tpu_inference/worker/tpu_worker.py +25 -162
  51. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/METADATA +3 -4
  52. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/RECORD +55 -50
  53. tpu_inference/models/jax/llama_guard_4.py +0 -361
  54. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/WHEEL +0 -0
  55. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/licenses/LICENSE +0 -0
  56. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/top_level.txt +0 -0
@@ -7,6 +7,7 @@ import jax.numpy as jnp
7
7
  from jax import lax
8
8
  from jax._src import dtypes
9
9
  from jax.experimental import pallas as pl
10
+ from jax.experimental import shard_map
10
11
  from jax.experimental.pallas import tpu as pltpu
11
12
 
12
13
  P = jax.sharding.PartitionSpec
@@ -34,49 +35,13 @@ def broadcast_minor(src, shape):
34
35
  axis=-1)[..., :shape[-1]]
35
36
 
36
37
 
37
- def swigluoai(gate: jax.Array,
38
- up: jax.Array,
39
- *,
40
- alpha: float = 1.702,
41
- limit: float = 7.0) -> jax.Array:
42
- """Activation used in some models such as GPT-OSS."""
43
- gate = jnp.clip(gate, a_max=limit)
44
- up = jnp.clip(up, a_min=-limit, a_max=limit)
45
- glu = gate * jax.nn.sigmoid(alpha * gate)
46
- return (up + 1.0) * glu
47
-
48
-
49
- def activation_fn(acc1, acc3, act_fn):
50
- if act_fn == "silu":
51
- return jax.nn.silu(acc1) * acc3
52
- elif act_fn == "gelu":
53
- return jax.nn.gelu(acc1) * acc3
54
- elif act_fn == "swigluoai":
55
- return swigluoai(acc1, acc3)
56
- else:
57
- raise RuntimeError(f"Unsupported activation function: {act_fn}")
58
-
59
-
60
38
  def ref_moe(
61
- tokens: jax.Array, # (num_tokens, hidden_size)
62
- w1: jax.Array, # (num_experts, 2, hidden_size, intermediate_size)
63
- w2: jax.Array, # (num_experts, intermediate_size, hidden_size)
64
- gating_output: jax.Array, # (num_tokens, num_experts)
65
- top_k: int,
66
- *,
67
- renormalize_topk_logits: bool = False,
68
- activation="silu",
69
- subc_quant_wsz: int | None = None,
70
- w1_scale:
71
- (
72
- jax.Array | None
73
- ) = None, # (num_experts, 2, cdiv(hidden_size, subc_quant_wsz), intermediate_size)
74
- w2_scale:
75
- (
76
- jax.Array | None
77
- ) = None, # (num_experts, cdiv(intermediate_size, subc_quant_wsz), hidden_size)
78
- b1: jax.Array | None = None, # (num_experts, 2, intermediate_size)
79
- b2: jax.Array | None = None, # (num_experts, hidden_size)
39
+ tokens: jax.Array, # (num_tokens, hidden_size)
40
+ w1: jax.Array, # (num_experts, 2, hidden_size, intermediate_size)
41
+ w2: jax.Array, # (num_experts, intermediate_size, hidden_size)
42
+ gating_output: jax.Array, # (num_tokens, num_experts)
43
+ top_k: int,
44
+ activation="silu",
80
45
  ):
81
46
  n_tokens = tokens.shape[0] # num_tokens
82
47
 
@@ -88,12 +53,7 @@ def ref_moe(
88
53
  top_k_logits, top_k_indices = lax.top_k(
89
54
  gating_logits, top_k) # [num_tokens, top_k], [num_tokens, top_k]
90
55
 
91
- if renormalize_topk_logits:
92
- top_k_logits = top_k_logits / jnp.sum(
93
- top_k_logits, axis=-1, keepdims=True)
94
-
95
56
  t_outputs = []
96
- hidden_size, intermediate_size = w1.shape[-2:]
97
57
 
98
58
  # Process each token individually
99
59
  for i in range(n_tokens):
@@ -105,24 +65,10 @@ def ref_moe(
105
65
  # Process each selected expert for the current token
106
66
  for expert_id in assigned_expert_ids:
107
67
  # Get expert weights
108
- expert_w1 = w1[expert_id, 0].astype(jnp.float32)
109
- expert_w3 = w1[expert_id, 1].astype(jnp.float32)
110
- if w1_scale is not None:
111
- expert_w1 *= jnp.repeat(w1_scale[expert_id, 0],
112
- subc_quant_wsz,
113
- axis=0)[:hidden_size]
114
- expert_w3 *= jnp.repeat(w1_scale[expert_id, 1],
115
- subc_quant_wsz,
116
- axis=0)[:hidden_size]
117
68
  expert_weight_1 = jnp.concat(
118
- [expert_w1, expert_w3],
69
+ [w1[expert_id, 0], w1[expert_id, 1]],
119
70
  axis=-1) # [d_model, 2 * intermediate_size]
120
- expert_weight_2 = w2[expert_id].astype(
121
- jnp.float32) # [intermediate_size, d_model]
122
- if w2_scale is not None:
123
- expert_weight_2 *= jnp.repeat(w2_scale[expert_id],
124
- subc_quant_wsz,
125
- axis=0)[:intermediate_size]
71
+ expert_weight_2 = w2[expert_id] # [intermediate_size, d_model]
126
72
 
127
73
  # First linear layer with SwiGLU activation
128
74
  gmm_1_out = curr_token @ expert_weight_1 # [1, 2 * intermediate_size]
@@ -131,17 +77,21 @@ def ref_moe(
131
77
  gmm1_w1_proj, gmm1_w3_proj = jnp.split(
132
78
  gmm_1_out, 2,
133
79
  axis=-1) # [1, intermediate_size], [1, intermediate_size]
134
- if b1 is not None:
135
- gmm1_w1_proj += b1[expert_id:expert_id + 1, 0]
136
- gmm1_w3_proj += b1[expert_id:expert_id + 1, 1]
137
80
 
138
81
  # Apply gated activation: activation(gate) * up
139
- act = activation_fn(gmm1_w1_proj, gmm1_w3_proj, activation)
82
+ if activation == "silu":
83
+ act = jax.nn.silu(
84
+ gmm1_w1_proj) * gmm1_w3_proj # [1, intermediate_size]
85
+ elif activation == "gelu":
86
+ act = jax.nn.gelu(
87
+ gmm1_w1_proj) * gmm1_w3_proj # [1, intermediate_size]
88
+ else:
89
+ raise ValueError(
90
+ f"Unsupported activation: {activation}. Use 'silu' or 'gelu'."
91
+ )
140
92
 
141
93
  # Second linear layer (down projection)
142
94
  gmm_2_out = act @ expert_weight_2 # [1, d_model]
143
- if b2 is not None:
144
- gmm_2_out += b2[expert_id:expert_id + 1]
145
95
  tok_expert_act.append(gmm_2_out)
146
96
 
147
97
  # Combine outputs from all selected experts
@@ -155,7 +105,7 @@ def ref_moe(
155
105
  axis=0,
156
106
  keepdims=True) # [1, d_model]
157
107
 
158
- t_outputs.append(weighted_output.astype(tokens.dtype))
108
+ t_outputs.append(weighted_output)
159
109
 
160
110
  return jnp.concatenate(t_outputs, axis=0) # [num_tokens, d_model]
161
111
 
@@ -165,13 +115,6 @@ def _fused_ep_moe_kernel(
165
115
  tokens_hbm, # (local_num_tokens, t_packing, hidden_size // t_packing)
166
116
  w1_hbm, # (local_num_experts, 2, hidden_size, intermediate_size)
167
117
  w2_hbm, # (local_num_experts, intermediate_size, hidden_size)
168
- # TODO(jevinjiang): We choose F32 scale for easier slicing. The extra
169
- # latency should be hidden in the pipeline overlaping. But is there a better
170
- # way to do this?
171
- w1_scale_hbm, # None | F32(local_num_experts, 2, cdiv(hidden_size, subc_quant_wsz), 1, intermediate_size)
172
- w2_scale_hbm, # None | F32(local_num_experts, cdiv(intermediate_size, subc_quant_wsz), 1, hidden_size)
173
- b1_hbm, # None | F32(local_num_experts, 2, 1, intermediate_size)
174
- b2_hbm, # None | F32(local_num_experts, 1, hidden_size)
175
118
  gating_hbm, # (local_num_tokens, padded_num_experts)
176
119
  a2a_g_hbm, # (num_experts, bt, t_packing, hidden_size // t_packing)
177
120
  # Output
@@ -193,12 +136,6 @@ def _fused_ep_moe_kernel(
193
136
  b_w1_x2_vmem, # <bw_sem_id> (2, t_packing, bd1 // t_packing, bf)
194
137
  b_w3_x2_vmem, # <bw_sem_id> (2, t_packing, bd1 // t_packing, bf)
195
138
  b_w2_x2_vmem, # <bw_sem_id> (2, t_packing, bf, bd2 // t_packing)
196
- b_w1_scale_x2_vmem, # None | <bw_sem_id> (2, t_packing, bd1 // t_packing // subc_quant_wsz, 1, bf)
197
- b_w3_scale_x2_vmem, # None | <bw_sem_id> (2, t_packing, bd1 // t_packing // subc_quant_wsz, 1, bf)
198
- b_w2_scale_x2_vmem, # None | <bw_sem_id> (2, t_packing, bf // subc_quant_wsz, 1, bd2 // t_packing)
199
- b_b1_x2_vmem, # None | <bw_sem_id> (2, 1, bf)
200
- b_b3_x2_vmem, # None | <bw_sem_id> (2, 1, bf)
201
- b_b2_x2_vmem, # None | <bw_sem_id> (2, t_packing, 1, bd2 // t_packing)
202
139
  b_acc_vmem, # F32(bt * num_devices, 1, bf * 2)
203
140
  ### Semaphores:
204
141
  local_sems, # (2, 5): 2 x [b_gating_sem, b_w1_sem, b_w2_sem, b_w3_sem, b_output_sem]
@@ -208,10 +145,7 @@ def _fused_ep_moe_kernel(
208
145
  a2a_acc_sem,
209
146
  *,
210
147
  top_k: int,
211
- renormalize_topk_logits: bool,
212
148
  ep_axis_name: str,
213
- act_fn: str,
214
- subc_quant_wsz: int | None = None,
215
149
  # Kernel tuning params.
216
150
  bt: int, # Block size of local_num_tokens.
217
151
  bf: int, # Block size of intermediate_size.
@@ -226,53 +160,34 @@ def _fused_ep_moe_kernel(
226
160
  num_devices = lax.axis_size(ep_axis_name)
227
161
  local_num_tokens = tokens_hbm.shape[0]
228
162
  local_num_experts, intermediate_size, hidden_size = w2_hbm.shape
163
+ # num_experts = local_num_experts * num_devices
164
+ # padded_num_experts = expert_starts_x2_smem.shape[-1]
229
165
  right_id = (my_id + 1) % num_devices
230
166
 
231
167
  t_dtype = tokens_hbm.dtype
232
168
  t_packing = get_dtype_packing(t_dtype)
233
169
  t_bitwidth = 32 // t_packing
234
170
  assert a2a_g_hbm.dtype == t_dtype
235
- assert w1_hbm.dtype == w2_hbm.dtype
171
+ assert w1_hbm.dtype == t_dtype
172
+ assert w2_hbm.dtype == t_dtype
236
173
 
237
- assert bd1 % bd1c == 0
238
- assert bd2 % bd2c == 0
239
- assert bf % bfc == 0
240
- assert hidden_size % t_packing == 0
241
- assert bd1 % t_packing == 0
242
- assert bd2 % t_packing == 0
243
- assert bd1c % t_packing == 0
244
- assert bd2c % t_packing == 0
245
-
246
- h_per_t_packing = hidden_size // t_packing
247
- assert tokens_hbm.shape[-1] == h_per_t_packing
248
- bd1_per_t_packing = bd1 // t_packing
249
- bd2_per_t_packing = bd2 // t_packing
250
- bd1c_per_t_packing = bd1c // t_packing
251
- bd2c_per_t_packing = bd2c // t_packing
252
-
253
- if subc_quant_wsz is not None:
254
- assert subc_quant_wsz % 256 == 0
255
- assert bd1c_per_t_packing == subc_quant_wsz
256
- assert bfc == subc_quant_wsz
257
- assert bd1 % subc_quant_wsz == 0
258
- assert bf % subc_quant_wsz == 0
259
- assert bd1_per_t_packing % subc_quant_wsz == 0
260
- assert h_per_t_packing % subc_quant_wsz == 0
174
+ h_per_packing = hidden_size // t_packing
175
+ assert tokens_hbm.shape[-1] == h_per_packing
176
+ bd1_per_packing = bd1 // t_packing
177
+ bd2_per_packing = bd2 // t_packing
178
+ bd1c_per_packing = bd1c // t_packing
179
+ bd2c_per_packing = bd2c // t_packing
261
180
 
262
181
  num_bt = cdiv(local_num_tokens, bt)
263
182
  num_bf = cdiv(intermediate_size, bf)
264
183
  num_bd1 = cdiv(hidden_size, bd1)
265
184
  num_bd2 = cdiv(hidden_size, bd2)
266
185
 
267
- def get_mesh_device_id(ep_rank):
268
- dp_rank = jax.lax.axis_index("data")
269
- return (dp_rank, ep_rank)
270
-
271
186
  def sync_barrier():
272
187
  barrier_sem = pltpu.get_barrier_semaphore()
273
188
  pltpu.semaphore_signal(
274
189
  barrier_sem,
275
- device_id=get_mesh_device_id(right_id),
190
+ device_id=(0, right_id),
276
191
  device_id_type=pltpu.DeviceIdType.MESH,
277
192
  )
278
193
  pltpu.semaphore_wait(barrier_sem, 1)
@@ -297,7 +212,7 @@ def _fused_ep_moe_kernel(
297
212
  sem=b_gating_sem,
298
213
  ).wait()
299
214
 
300
- def get_top_k(input, top_k, renormalize_topk_logits):
215
+ def get_top_k(input, top_k):
301
216
  assert len(input.shape) == 2, input.shape
302
217
  input = input.astype(jnp.float32)
303
218
  top_k_logits_lst = []
@@ -305,15 +220,11 @@ def _fused_ep_moe_kernel(
305
220
  t2e = jnp.zeros(input.shape, dtype=jnp.int32)
306
221
  t2e_routing = jnp.zeros(input.shape, dtype=jnp.int32)
307
222
  iota = jax.lax.broadcasted_iota(jnp.int32, input.shape, 1)
308
- top_k_logits_sum = jnp.zeros((input.shape[0], 128), jnp.float32)
309
-
310
223
  for k_id in range(top_k):
311
- # TODO(jevinjiang): return both top_k values and indices in Mosaic
224
+ # TODO(jevinjiang): return both top_k values and indices in op in Mosaic
312
225
  top_k_logits = jnp.broadcast_to(
313
226
  jnp.max(input, axis=1, keepdims=True),
314
227
  (input.shape[0], 128)).astype(input.dtype)
315
- if renormalize_topk_logits:
316
- top_k_logits_sum += top_k_logits
317
228
  top_k_logits_lst.append(top_k_logits)
318
229
  # TODO(jevinjiang): support bf16 argmax in Mosaic
319
230
  top_k_indices = jnp.broadcast_to(
@@ -325,11 +236,6 @@ def _fused_ep_moe_kernel(
325
236
  if k_id != top_k - 1:
326
237
  input = jnp.where(mask, -jnp.inf, input)
327
238
 
328
- if renormalize_topk_logits:
329
- for k_id in range(top_k):
330
- top_k_logits_lst[
331
- k_id] = top_k_logits_lst[k_id] / top_k_logits_sum
332
-
333
239
  expert_sizes = jnp.sum(t2e, axis=0, keepdims=True)
334
240
  expert_starts = jnp.zeros_like(expert_sizes)
335
241
  return top_k_logits_lst, t2e_routing, expert_sizes, expert_starts
@@ -371,7 +277,7 @@ def _fused_ep_moe_kernel(
371
277
  dst_ref=d2e_count_vmem.at[row_id],
372
278
  send_sem=send_sem,
373
279
  recv_sem=recv_sem,
374
- device_id=get_mesh_device_id(right_id),
280
+ device_id=(0, right_id),
375
281
  device_id_type=pltpu.DeviceIdType.MESH,
376
282
  ).wait()
377
283
  row_id = (row_id + num_devices - 1) % num_devices
@@ -453,8 +359,10 @@ def _fused_ep_moe_kernel(
453
359
  pl.ds(start, remote_sz)],
454
360
  send_sem=send_sems.at[e_sem_id],
455
361
  recv_sem=recv_sems.at[e_sem_id],
456
- device_id=get_mesh_device_id(recv_id),
457
- device_id_type=pltpu.DeviceIdType.MESH,
362
+ device_id=(
363
+ 0,
364
+ recv_id,
365
+ ),
458
366
  ).start()
459
367
  a2a_s_sends_x2_smem[e_sem_id] = send_sz
460
368
 
@@ -498,8 +406,7 @@ def _fused_ep_moe_kernel(
498
406
  dst_ref=a2a_g_hbm.at[my_e_id, pl.ds(0, remote_sz)],
499
407
  send_sem=send_sems.at[e_sem_id],
500
408
  recv_sem=a2a_gather_sem,
501
- device_id=get_mesh_device_id(recv_id),
502
- device_id_type=pltpu.DeviceIdType.MESH,
409
+ device_id=(0, recv_id),
503
410
  ).start()
504
411
  start += sz
505
412
 
@@ -528,173 +435,68 @@ def _fused_ep_moe_kernel(
528
435
 
529
436
  def start_fetch_bw1(local_e_id, bw1_sem_id, bf_id, bd1_id):
530
437
  for p in range(t_packing):
531
- offset = p * h_per_t_packing + bd1_id * bd1_per_t_packing
438
+ offset = p * h_per_packing + bd1_id * bd1_per_packing
532
439
  pltpu.make_async_copy(
533
440
  src_ref=w1_hbm.at[
534
441
  local_e_id,
535
442
  0,
536
- pl.ds(offset, bd1_per_t_packing),
443
+ pl.ds(offset, bd1_per_packing),
537
444
  pl.ds(bf_id * bf, bf),
538
445
  ],
539
446
  dst_ref=b_w1_x2_vmem.at[bw1_sem_id, p],
540
447
  sem=local_sems.at[bw1_sem_id, 1],
541
448
  ).start()
542
- if w1_scale_hbm is not None:
543
- assert subc_quant_wsz is not None
544
- pltpu.make_async_copy(
545
- src_ref=w1_scale_hbm.at[
546
- local_e_id,
547
- 0,
548
- pl.ds(
549
- offset // subc_quant_wsz,
550
- bd1_per_t_packing // subc_quant_wsz,
551
- ),
552
- pl.ds(0, 1),
553
- pl.ds(bf_id * bf, bf),
554
- ],
555
- dst_ref=b_w1_scale_x2_vmem.at[bw1_sem_id, p],
556
- sem=local_sems.at[bw1_sem_id, 1],
557
- ).start()
558
- if b1_hbm is not None and bd1_id == 0:
559
- pltpu.make_async_copy(
560
- src_ref=b1_hbm.at[local_e_id, 0,
561
- pl.ds(0, 1),
562
- pl.ds(bf_id * bf, bf)],
563
- dst_ref=b_b1_x2_vmem.at[bf_id % 2],
564
- sem=local_sems.at[bw1_sem_id, 1],
565
- ).start()
566
449
 
567
450
  def start_fetch_bw2(local_e_id, bw2_sem_id, bf_id, bd2_id):
568
451
  for p in range(t_packing):
569
- offset = p * h_per_t_packing + bd2_id * bd2_per_t_packing
452
+ offset = p * h_per_packing + bd2_id * bd2_per_packing
570
453
  pltpu.make_async_copy(
571
454
  src_ref=w2_hbm.at[
572
455
  local_e_id,
573
456
  pl.ds(bf_id * bf, bf),
574
- pl.ds(offset, bd2_per_t_packing),
457
+ pl.ds(offset, bd2_per_packing),
575
458
  ],
576
459
  dst_ref=b_w2_x2_vmem.at[bw2_sem_id, p],
577
460
  sem=local_sems.at[bw2_sem_id, 2],
578
461
  ).start()
579
- if w2_scale_hbm is not None:
580
- assert subc_quant_wsz is not None
581
- pltpu.make_async_copy(
582
- src_ref=w2_scale_hbm.at[
583
- local_e_id,
584
- pl.ds(bf_id * bf // subc_quant_wsz, bf //
585
- subc_quant_wsz),
586
- pl.ds(0, 1),
587
- pl.ds(offset, bd2_per_t_packing),
588
- ],
589
- dst_ref=b_w2_scale_x2_vmem.at[bw2_sem_id, p],
590
- sem=local_sems.at[bw2_sem_id, 2],
591
- ).start()
592
- if b2_hbm is not None and bf_id == 0:
593
- pltpu.make_async_copy(
594
- src_ref=b2_hbm.at[local_e_id,
595
- pl.ds(0, 1),
596
- pl.ds(offset, bd2_per_t_packing)],
597
- dst_ref=b_b2_x2_vmem.at[bd2_id % 2, p],
598
- sem=local_sems.at[bw2_sem_id, 2],
599
- ).start()
600
462
 
601
463
  def start_fetch_bw3(local_e_id, bw3_sem_id, bf_id, bd3_id):
602
464
  for p in range(t_packing):
603
- offset = p * h_per_t_packing + bd3_id * bd1_per_t_packing
465
+ offset = p * h_per_packing + bd3_id * bd1_per_packing
604
466
  pltpu.make_async_copy(
605
467
  src_ref=w1_hbm.at[
606
468
  local_e_id,
607
469
  1,
608
- pl.ds(offset, bd1_per_t_packing),
470
+ pl.ds(offset, bd1_per_packing),
609
471
  pl.ds(bf_id * bf, bf),
610
472
  ],
611
473
  dst_ref=b_w3_x2_vmem.at[bw3_sem_id, p],
612
474
  sem=local_sems.at[bw3_sem_id, 3],
613
475
  ).start()
614
- if w1_scale_hbm is not None:
615
- assert subc_quant_wsz is not None
616
- pltpu.make_async_copy(
617
- src_ref=w1_scale_hbm.at[
618
- local_e_id,
619
- 1,
620
- pl.ds(
621
- offset // subc_quant_wsz,
622
- bd1_per_t_packing // subc_quant_wsz,
623
- ),
624
- pl.ds(0, 1),
625
- pl.ds(bf_id * bf, bf),
626
- ],
627
- dst_ref=b_w3_scale_x2_vmem.at[bw3_sem_id, p],
628
- sem=local_sems.at[bw3_sem_id, 3],
629
- ).start()
630
- if b1_hbm is not None and bd3_id == 0:
631
- pltpu.make_async_copy(
632
- src_ref=b1_hbm.at[local_e_id, 1,
633
- pl.ds(0, 1),
634
- pl.ds(bf_id * bf, bf)],
635
- dst_ref=b_b3_x2_vmem.at[bf_id % 2],
636
- sem=local_sems.at[bw3_sem_id, 3],
637
- ).start()
638
476
 
639
477
  def wait_fetch_bw1(local_e_id, bw1_sem_id, bf_id, bd1_id):
640
- del local_e_id
478
+ del local_e_id, bf_id, bd1_id
641
479
  pltpu.make_async_copy(
642
480
  src_ref=b_w1_x2_vmem.at[bw1_sem_id],
643
481
  dst_ref=b_w1_x2_vmem.at[bw1_sem_id],
644
482
  sem=local_sems.at[bw1_sem_id, 1],
645
483
  ).wait()
646
- if w1_scale_hbm is not None:
647
- pltpu.make_async_copy(
648
- src_ref=b_w1_scale_x2_vmem.at[bw1_sem_id],
649
- dst_ref=b_w1_scale_x2_vmem.at[bw1_sem_id],
650
- sem=local_sems.at[bw1_sem_id, 1],
651
- ).wait()
652
- if b1_hbm is not None and bd1_id == 0:
653
- pltpu.make_async_copy(
654
- src_ref=b_b1_x2_vmem.at[bf_id % 2],
655
- dst_ref=b_b1_x2_vmem.at[bf_id % 2],
656
- sem=local_sems.at[bw1_sem_id, 1],
657
- ).wait()
658
484
 
659
485
  def wait_fetch_bw2(local_e_id, bw2_sem_id, bf_id, bd2_id):
660
- del local_e_id
486
+ del local_e_id, bf_id, bd2_id
661
487
  pltpu.make_async_copy(
662
488
  src_ref=b_w2_x2_vmem.at[bw2_sem_id],
663
489
  dst_ref=b_w2_x2_vmem.at[bw2_sem_id],
664
490
  sem=local_sems.at[bw2_sem_id, 2],
665
491
  ).wait()
666
- if w2_scale_hbm is not None:
667
- pltpu.make_async_copy(
668
- src_ref=b_w2_scale_x2_vmem.at[bw2_sem_id],
669
- dst_ref=b_w2_scale_x2_vmem.at[bw2_sem_id],
670
- sem=local_sems.at[bw2_sem_id, 2],
671
- ).wait()
672
- if b2_hbm is not None and bf_id == 0:
673
- pltpu.make_async_copy(
674
- src_ref=b_b2_x2_vmem.at[bd2_id % 2],
675
- dst_ref=b_b2_x2_vmem.at[bd2_id % 2],
676
- sem=local_sems.at[bw2_sem_id, 2],
677
- ).wait()
678
492
 
679
493
  def wait_fetch_bw3(local_e_id, bw3_sem_id, bf_id, bd3_id):
680
- del local_e_id
494
+ del local_e_id, bf_id, bd3_id
681
495
  pltpu.make_async_copy(
682
496
  src_ref=b_w3_x2_vmem.at[bw3_sem_id],
683
497
  dst_ref=b_w3_x2_vmem.at[bw3_sem_id],
684
498
  sem=local_sems.at[bw3_sem_id, 3],
685
499
  ).wait()
686
- if w1_scale_hbm is not None:
687
- pltpu.make_async_copy(
688
- src_ref=b_w3_scale_x2_vmem.at[bw3_sem_id],
689
- dst_ref=b_w3_scale_x2_vmem.at[bw3_sem_id],
690
- sem=local_sems.at[bw3_sem_id, 3],
691
- ).wait()
692
- if b1_hbm is not None and bd3_id == 0:
693
- pltpu.make_async_copy(
694
- src_ref=b_b3_x2_vmem.at[bf_id % 2],
695
- dst_ref=b_b3_x2_vmem.at[bf_id % 2],
696
- sem=local_sems.at[bw3_sem_id, 3],
697
- ).wait()
698
500
 
699
501
  def start_fetch_next_bw(local_e_id, bw_sem_id, bf_id, bd1_id, bd2_id):
700
502
  next_bd1_id = bd1_id + 1
@@ -718,38 +520,18 @@ def _fused_ep_moe_kernel(
718
520
  def dynamic_ffn1(
719
521
  t_b32_vmem,
720
522
  w1_vmem,
721
- w1_scale_vmem,
722
- b1_vmem,
723
523
  w3_vmem,
724
- w3_scale_vmem,
725
- b3_vmem,
726
524
  acc1_vmem,
727
525
  acc3_vmem,
728
526
  dyn_sz,
729
527
  should_init,
730
528
  ):
731
529
  assert t_b32_vmem.shape == (bt * num_devices, bd1 // t_packing)
732
- assert w1_vmem.shape == w3_vmem.shape == (t_packing, bd1_per_t_packing,
530
+ assert w1_vmem.shape == w3_vmem.shape == (t_packing, bd1_per_packing,
733
531
  bf)
734
532
  assert acc1_vmem.shape == acc3_vmem.shape == (bt * num_devices, bf)
735
533
  assert bd1 % (t_packing * 128) == 0, (bd1, t_packing)
736
534
  assert bd1c % (t_packing * 128) == 0, (bd1c, t_packing)
737
- if w1_scale_vmem is not None:
738
- assert w1_scale_vmem.shape == (
739
- t_packing,
740
- bd1_per_t_packing // subc_quant_wsz,
741
- 1,
742
- bf,
743
- )
744
- assert bd1c_per_t_packing == subc_quant_wsz
745
- if w3_scale_vmem is not None:
746
- assert w3_scale_vmem.shape == (
747
- t_packing,
748
- bd1_per_t_packing // subc_quant_wsz,
749
- 1,
750
- bf,
751
- )
752
- assert bd1c_per_t_packing == subc_quant_wsz
753
535
 
754
536
  num_loops = cdiv(dyn_sz, btc)
755
537
  repack_ty = jnp.dtype(f"int{t_bitwidth}")
@@ -758,7 +540,7 @@ def _fused_ep_moe_kernel(
758
540
  for bd1c_id in range(cdiv(bd1, bd1c)):
759
541
  t_b32 = t_b32_vmem[
760
542
  pl.ds(btc_id * btc, btc),
761
- pl.ds(bd1c_id * bd1c_per_t_packing, bd1c_per_t_packing),
543
+ pl.ds(bd1c_id * bd1c_per_packing, bd1c_per_packing),
762
544
  ]
763
545
  for p_id in range(t_packing):
764
546
  t = pltpu.bitcast(t_b32.astype(repack_ty), t_dtype)
@@ -766,64 +548,21 @@ def _fused_ep_moe_kernel(
766
548
  for bfc_id in range(cdiv(bf, bfc)):
767
549
  w_slices = (
768
550
  p_id,
769
- pl.ds(bd1c_id * bd1c_per_t_packing,
770
- bd1c_per_t_packing),
551
+ pl.ds(bd1c_id * bd1c_per_packing,
552
+ bd1c_per_packing),
771
553
  pl.ds(bfc_id * bfc, bfc),
772
554
  )
773
555
  w1 = w1_vmem[*w_slices]
774
556
  acc1 = jnp.dot(t,
775
557
  w1,
776
558
  preferred_element_type=jnp.float32)
777
-
778
- if w1_scale_vmem is not None:
779
- w1_scale_slices = (
780
- p_id,
781
- bd1c_id,
782
- pl.ds(0, 1),
783
- pl.ds(bfc_id * bfc, bfc),
784
- )
785
- # TODO(jevinjiang): can use mosaic to load with stride 0.
786
- w1_scale = jnp.broadcast_to(
787
- w1_scale_vmem[*w1_scale_slices], acc1.shape)
788
- acc1 *= w1_scale
789
-
790
559
  w3 = w3_vmem[*w_slices]
791
-
792
560
  acc3 = jnp.dot(t,
793
561
  w3,
794
562
  preferred_element_type=jnp.float32)
795
-
796
- if w3_scale_vmem is not None:
797
- w3_scale_slices = (
798
- p_id,
799
- bd1c_id,
800
- pl.ds(0, 1),
801
- pl.ds(bfc_id * bfc, bfc),
802
- )
803
- w3_scale = jnp.broadcast_to(
804
- w3_scale_vmem[*w3_scale_slices], acc3.shape)
805
- acc3 *= w3_scale
806
-
807
563
  acc_slices = (pl.ds(btc_id * btc,
808
564
  btc), pl.ds(bfc_id * bfc, bfc))
809
565
  if should_init and p_id == bd1c_id == 0:
810
- if b1_vmem is not None:
811
- b1_scale_slices = (
812
- pl.ds(0, 1),
813
- pl.ds(bfc_id * bfc, bfc),
814
- )
815
- b1 = jnp.broadcast_to(
816
- b1_vmem[*b1_scale_slices], acc1.shape)
817
- acc1 += b1
818
- if b3_vmem is not None:
819
- b3_scale_slices = (
820
- pl.ds(0, 1),
821
- pl.ds(bfc_id * bfc, bfc),
822
- )
823
- b3 = jnp.broadcast_to(
824
- b3_vmem[*b3_scale_slices], acc1.shape)
825
- acc3 += b3
826
-
827
566
  acc1_vmem[*acc_slices] = acc1
828
567
  acc3_vmem[*acc_slices] = acc3
829
568
  else:
@@ -836,28 +575,22 @@ def _fused_ep_moe_kernel(
836
575
  acc1_vmem,
837
576
  acc3_vmem,
838
577
  w2_vmem,
839
- w2_scale_vmem,
840
- b2_vmem,
841
578
  res_b32_vmem,
842
579
  dyn_sz,
843
580
  should_init,
844
581
  ):
845
- assert res_b32_vmem.shape == (bt * num_devices, bd2_per_t_packing)
846
- assert w2_vmem.shape == (t_packing, bf, bd2_per_t_packing)
582
+ assert res_b32_vmem.shape == (bt * num_devices, bd2_per_packing)
583
+ assert w2_vmem.shape == (t_packing, bf, bd2_per_packing), (
584
+ w2_vmem.shape,
585
+ t_packing,
586
+ bf,
587
+ bd2_per_packing,
588
+ )
847
589
  assert acc1_vmem.shape == acc3_vmem.shape == (bt * num_devices, bf)
848
590
  assert bd2 % (t_packing * 128) == 0, (bd2, t_packing)
849
591
  assert bd2c % (t_packing * 128) == 0, (bd2c, t_packing)
850
592
  assert t_dtype in (jnp.float32, jnp.bfloat16)
851
593
 
852
- if w2_scale_vmem is not None:
853
- assert w2_scale_vmem.shape == (
854
- t_packing,
855
- bf // subc_quant_wsz,
856
- 1,
857
- bd2_per_t_packing,
858
- )
859
- assert bfc == subc_quant_wsz
860
-
861
594
  num_loops = cdiv(dyn_sz, btc)
862
595
  assert bd2c % (t_packing * 128) == 0, (bd2c, t_packing)
863
596
 
@@ -865,47 +598,22 @@ def _fused_ep_moe_kernel(
865
598
  for bd2c_id in range(cdiv(bd2, bd2c)):
866
599
  res_lst = []
867
600
  for p_id in range(t_packing):
868
- res = jnp.zeros((btc, bd2c_per_t_packing),
869
- dtype=jnp.float32)
870
-
871
- if b2_vmem is not None and should_init:
872
- b2_scale_slices = (
873
- p_id,
874
- pl.ds(0, 1),
875
- pl.ds(bd2c_id * bd2c_per_t_packing,
876
- bd2c_per_t_packing),
877
- )
878
- b2 = jnp.broadcast_to(b2_vmem[*b2_scale_slices],
879
- res.shape)
880
- res += b2
881
-
601
+ res = jnp.zeros((btc, bd2c_per_packing), dtype=jnp.float32)
882
602
  for bfc_id in range(cdiv(bf, bfc)):
883
603
  acc_slices = (pl.ds(btc_id * btc,
884
604
  btc), pl.ds(bfc_id * bfc, bfc))
885
605
  acc1 = acc1_vmem[*acc_slices]
886
606
  acc3 = acc3_vmem[*acc_slices]
887
- act = activation_fn(acc1, acc3, act_fn)
607
+ act = jax.nn.silu(acc1) * acc3
888
608
  w2 = w2_vmem[
889
609
  p_id,
890
610
  pl.ds(bfc_id * bfc, bfc),
891
611
  pl.ds(bd2c_id *
892
- bd2c_per_t_packing, bd2c_per_t_packing),
612
+ bd2c_per_packing, bd2c_per_packing),
893
613
  ]
894
- acc = jnp.dot(act,
895
- w2,
896
- preferred_element_type=jnp.float32)
897
- if w2_scale_vmem is not None:
898
- w2_scale_slices = (
899
- p_id,
900
- bfc_id,
901
- pl.ds(0, 1),
902
- pl.ds(bd2c_id * bd2c_per_t_packing,
903
- bd2c_per_t_packing),
904
- )
905
- w2_scale = jnp.broadcast_to(
906
- w2_scale_vmem[*w2_scale_slices], acc.shape)
907
- acc *= w2_scale
908
- res += acc
614
+ res += jnp.dot(act,
615
+ w2,
616
+ preferred_element_type=jnp.float32)
909
617
  res = pltpu.bitcast(res, jnp.uint32)
910
618
  if t_packing == 2:
911
619
  res = res >> 16 << (16 * p_id)
@@ -918,7 +626,7 @@ def _fused_ep_moe_kernel(
918
626
  res |= res_lst[i]
919
627
  sliced_res_vmem = res_b32_vmem.at[
920
628
  pl.ds(btc_id * btc, btc),
921
- pl.ds(bd2c_id * bd2c_per_t_packing, bd2c_per_t_packing),
629
+ pl.ds(bd2c_id * bd2c_per_packing, bd2c_per_packing),
922
630
  ]
923
631
  if should_init:
924
632
  sliced_res_vmem[...] = res
@@ -947,33 +655,21 @@ def _fused_ep_moe_kernel(
947
655
  e_id = my_id * local_num_experts + local_e_id
948
656
  dyn_sz = expert_sizes_x2_smem[bt_sem_id, 0, e_id]
949
657
 
950
- bd1_per_t_packing = bd1 // t_packing
951
- bd2_per_t_packing = bd2 // t_packing
658
+ bd1_per_packing = bd1 // t_packing
659
+ bd2_per_packing = bd2 // t_packing
952
660
 
953
661
  for bf_id in range(num_bf):
954
662
  for bd1_id in range(num_bd1):
955
663
  start_fetch_next_bw(local_e_id, bw_sem_id, bf_id, bd1_id, 0)
956
- w1_scale_vmem = (None if b_w1_scale_x2_vmem is None else
957
- b_w1_scale_x2_vmem.at[bw_sem_id])
958
- w3_scale_vmem = (None if b_w3_scale_x2_vmem is None else
959
- b_w3_scale_x2_vmem.at[bw_sem_id])
960
- b1_vmem = None if b_b1_x2_vmem is None else b_b1_x2_vmem.at[
961
- bf_id % 2]
962
- b3_vmem = None if b_b3_x2_vmem is None else b_b3_x2_vmem.at[
963
- bf_id % 2]
964
664
  wait_fetch_bw1(local_e_id, bw_sem_id, bf_id, bd1_id)
965
665
  wait_fetch_bw3(local_e_id, bw_sem_id, bf_id, bd1_id)
966
666
 
967
667
  dynamic_ffn1(
968
668
  t_b32_vmem=a2a_s_b32_vmem.at[
969
669
  ...,
970
- pl.ds(bd1_id * bd1_per_t_packing, bd1_per_t_packing)],
670
+ pl.ds(bd1_id * bd1_per_packing, bd1_per_packing)],
971
671
  w1_vmem=b_w1_x2_vmem.at[bw_sem_id],
972
- w1_scale_vmem=w1_scale_vmem,
973
- b1_vmem=b1_vmem,
974
672
  w3_vmem=b_w3_x2_vmem.at[bw_sem_id],
975
- w3_scale_vmem=w3_scale_vmem,
976
- b3_vmem=b3_vmem,
977
673
  acc1_vmem=b_acc1_vmem,
978
674
  acc3_vmem=b_acc3_vmem,
979
675
  dyn_sz=dyn_sz,
@@ -988,19 +684,13 @@ def _fused_ep_moe_kernel(
988
684
  if bf_id == bd2_id == 0:
989
685
  wait_a2a_gather_send(bt_id, e_sem_id, local_e_id - 2)
990
686
 
991
- w2_scale_vmem = (None if b_w2_scale_x2_vmem is None else
992
- b_w2_scale_x2_vmem.at[bw_sem_id])
993
- b2_vmem = None if b_b2_x2_vmem is None else b_b2_x2_vmem.at[
994
- bd2_id % 2]
995
687
  dynamic_ffn2(
996
688
  acc1_vmem=b_acc1_vmem,
997
689
  acc3_vmem=b_acc3_vmem,
998
690
  w2_vmem=b_w2_x2_vmem.at[bw_sem_id],
999
- w2_scale_vmem=w2_scale_vmem,
1000
- b2_vmem=b2_vmem,
1001
691
  res_b32_vmem=a2a_s_acc_b32_vmem.at[
1002
692
  ...,
1003
- pl.ds(bd2_id * bd2_per_t_packing, bd2_per_t_packing)],
693
+ pl.ds(bd2_id * bd2_per_packing, bd2_per_packing)],
1004
694
  dyn_sz=dyn_sz,
1005
695
  should_init=(bf_id == 0),
1006
696
  )
@@ -1067,7 +757,7 @@ def _fused_ep_moe_kernel(
1067
757
  b_gating = b_gating_x2_vmem[bt_sem_id]
1068
758
  b_gating_score = jax.nn.softmax(b_gating, axis=-1)
1069
759
  top_k_logits_lst, t2e_routing, expert_sizes, expert_starts = get_top_k(
1070
- b_gating_score, top_k, renormalize_topk_logits)
760
+ b_gating_score, top_k)
1071
761
 
1072
762
  all_reduce_metadata(bt_sem_id, t2e_routing, expert_starts,
1073
763
  expert_sizes)
@@ -1137,9 +827,6 @@ def _fused_ep_moe_kernel(
1137
827
  static_argnames=[
1138
828
  "mesh",
1139
829
  "top_k",
1140
- "renormalize_topk_logits",
1141
- "act_fn",
1142
- "subc_quant_wsz",
1143
830
  "bt",
1144
831
  "bf",
1145
832
  "bd1",
@@ -1158,18 +845,7 @@ def fused_ep_moe(
1158
845
  w2: jax.Array, # (num_experts, intermediate_size, hidden_size)
1159
846
  gating_output: jax.Array, # (num_tokens, num_experts)
1160
847
  top_k: int,
1161
- renormalize_topk_logits: bool = False,
1162
- act_fn: str = "silu",
1163
848
  *,
1164
- subc_quant_wsz: int | None = None,
1165
- w1_scale: (
1166
- jax.Array | None
1167
- ) = None, # (num_experts, 2, cdiv(hidden_size, subc_quant_wsz), intermediate_size)
1168
- w2_scale: (
1169
- jax.Array | None
1170
- ) = None, # (num_experts, cdiv(intermediate_size, subc_quant_wsz), hidden_size)
1171
- b1: jax.Array | None = None, # (num_experts, 2, intermediate_size)
1172
- b2: jax.Array | None = None, # (num_experts, hidden_size)
1173
849
  # Kernel tuning parameters.
1174
850
  bt: int,
1175
851
  bf: int,
@@ -1179,19 +855,18 @@ def fused_ep_moe(
1179
855
  bfc: int,
1180
856
  bd1c: int,
1181
857
  bd2c: int,
1182
- ep_axis_name: str = "model",
858
+ ep_axis_name: str = 'model',
1183
859
  ):
1184
- # TODO(jevinjiang): move all these assertions to validation function.
1185
860
  # Assert all other axes have length of 1
1186
- assert len(mesh.shape) == 2, "Expect 2D mesh"
1187
- assert ("data" in mesh.shape
1188
- and mesh.shape["data"] == 1), "Expect data axis size of 1"
861
+ assert len(mesh.shape) == 2, "Expect 2D mesh in tpu-inference"
862
+ assert 'data' in mesh.shape and mesh.shape['data'] == 1, \
863
+ "Expect data axis size of 1 in tpu-inference"
1189
864
 
1190
865
  ep_size = mesh.shape[ep_axis_name]
1191
866
  num_devices = ep_size
1192
867
 
1193
868
  num_tokens, actual_hidden_size = tokens.shape
1194
- num_experts, actual_intermediate_size, _ = w2.shape
869
+ num_experts, intermediate_size, _ = w2.shape
1195
870
 
1196
871
  assert num_tokens % ep_size == 0
1197
872
  assert num_experts % ep_size == 0
@@ -1199,18 +874,26 @@ def fused_ep_moe(
1199
874
  local_num_tokens = num_tokens // ep_size
1200
875
  # local_num_experts = num_experts // ep_size
1201
876
  padded_num_experts = align_to(num_experts, 128)
877
+
1202
878
  t_dtype = tokens.dtype
1203
879
  t_packing = get_dtype_packing(t_dtype)
880
+ hidden_size = align_to(actual_hidden_size, 128 * t_packing)
881
+ if hidden_size != actual_hidden_size:
882
+ tokens = jnp.pad(
883
+ tokens,
884
+ ((0, 0), (0, hidden_size - actual_hidden_size)),
885
+ constant_values=0,
886
+ )
887
+ tokens = tokens.reshape(-1, t_packing, hidden_size // t_packing)
888
+ bt = min(bt, local_num_tokens)
889
+ bf = min(bf, intermediate_size)
890
+ bd1 = min(bd1, hidden_size)
891
+ bd2 = min(bd2, hidden_size)
1204
892
 
1205
- if subc_quant_wsz is not None:
1206
- if subc_quant_wsz % 256 != 0:
1207
- raise NotImplementedError(
1208
- "Sub-quantized window is not aligned to 256.")
1209
- # We force compute size of contracting dim to subc_quant_wsz. So we can
1210
- # apply same scale after matmul and accumulation.
1211
- bd1c = subc_quant_wsz * t_packing
1212
- bfc = subc_quant_wsz
1213
-
893
+ btc = min(btc, bt * num_devices)
894
+ bfc = min(bfc, bf)
895
+ bd1c = min(bd1c, bd1)
896
+ bd2c = min(bd2c, bd2)
1214
897
  assert bfc % 128 == 0
1215
898
  assert bd1c % (t_packing * 128) == 0
1216
899
  assert bd2c % (t_packing * 128) == 0
@@ -1218,41 +901,6 @@ def fused_ep_moe(
1218
901
  assert bd1 % bd1c == 0
1219
902
  assert bd2 % bd2c == 0
1220
903
 
1221
- btc = min(btc, bt * num_devices)
1222
- hidden_size = align_to(actual_hidden_size, 128 * t_packing)
1223
- # TODO(jevinjiang): instead of padding outside the kernel, we can try dynammic
1224
- # masking inside the kernel.
1225
- hidden_size = align_to(hidden_size, bd1)
1226
- hidden_size = align_to(hidden_size, bd2)
1227
- intermediate_size = align_to(actual_intermediate_size, bf)
1228
-
1229
- # TODO(jevinjiang): we should dump scale as the kernel expected shape in the
1230
- # checkpoint offline or reshape right after weight loading.
1231
- if w1_scale is not None:
1232
- assert w1_scale.shape[0] == w1.shape[0]
1233
- assert w1_scale.shape[1] == w1.shape[1] == 2
1234
- assert w1_scale.shape[2] == cdiv(w1.shape[2], subc_quant_wsz)
1235
- assert w1_scale.shape[3] == w1.shape[3]
1236
- w1_scale = jnp.expand_dims(w1_scale.astype(jnp.float32), axis=-2)
1237
-
1238
- if w2_scale is not None:
1239
- assert w2_scale.shape[0] == w2.shape[0]
1240
- assert w2_scale.shape[1] == cdiv(w2.shape[1], subc_quant_wsz)
1241
- assert w2_scale.shape[2] == w2.shape[2]
1242
- w2_scale = jnp.expand_dims(w2_scale.astype(jnp.float32), axis=-2)
1243
-
1244
- if b1 is not None:
1245
- assert b1.shape[0] == w1.shape[0]
1246
- assert b1.shape[1] == w1.shape[1] == 2
1247
- assert b1.shape[2] == w1.shape[3]
1248
- b1 = jnp.expand_dims(b1.astype(jnp.float32), axis=-2)
1249
-
1250
- if b2 is not None:
1251
- assert b2.shape[0] == w2.shape[0]
1252
- assert b2.shape[1] == w2.shape[2]
1253
- b2 = jnp.expand_dims(b2.astype(jnp.float32), axis=-2)
1254
-
1255
- # Prepare inputs for the kernel.
1256
904
  if padded_num_experts != gating_output.shape[-1]:
1257
905
  gating_output = jnp.pad(
1258
906
  gating_output,
@@ -1260,92 +908,13 @@ def fused_ep_moe(
1260
908
  constant_values=-jnp.inf,
1261
909
  )
1262
910
 
1263
- if (hidden_size != actual_hidden_size
1264
- or intermediate_size != actual_intermediate_size):
1265
- tokens = jnp.pad(
1266
- tokens,
1267
- ((0, 0), (0, hidden_size - actual_hidden_size)),
1268
- constant_values=0,
1269
- )
1270
- w1 = jnp.pad(
1271
- w1,
1272
- (
1273
- (0, 0),
1274
- (0, 0),
1275
- (0, hidden_size - actual_hidden_size),
1276
- (0, intermediate_size - actual_intermediate_size),
1277
- ),
1278
- constant_values=0,
1279
- )
1280
- w2 = jnp.pad(
1281
- w2,
1282
- (
1283
- (0, 0),
1284
- (0, intermediate_size - actual_intermediate_size),
1285
- (0, hidden_size - actual_hidden_size),
1286
- ),
1287
- constant_values=0,
1288
- )
1289
- if w1_scale is not None:
1290
- w1_scale = jnp.pad(
1291
- w1_scale,
1292
- (
1293
- (0, 0),
1294
- (0, 0),
1295
- (0,
1296
- cdiv(hidden_size, subc_quant_wsz) - w1_scale.shape[-3]),
1297
- (0, 0),
1298
- (0, intermediate_size - w1_scale.shape[-1]),
1299
- ),
1300
- constant_values=0,
1301
- )
1302
- if w2_scale is not None:
1303
- w2_scale = jnp.pad(
1304
- w2_scale,
1305
- (
1306
- (0, 0),
1307
- (0, cdiv(intermediate_size, subc_quant_wsz) -
1308
- w2_scale.shape[-3]),
1309
- (0, 0),
1310
- (0, hidden_size - w2_scale.shape[-1]),
1311
- ),
1312
- constant_values=0,
1313
- )
1314
- if b1 is not None:
1315
- b1 = jnp.pad(
1316
- b1,
1317
- (
1318
- (0, 0),
1319
- (0, 0),
1320
- (0, 0),
1321
- (0, intermediate_size - b1.shape[-1]),
1322
- ),
1323
- constant_values=0,
1324
- )
1325
- if b2 is not None:
1326
- b2 = jnp.pad(
1327
- b2,
1328
- (
1329
- (0, 0),
1330
- (0, 0),
1331
- (0, hidden_size - b2.shape[-1]),
1332
- ),
1333
- constant_values=0,
1334
- )
1335
-
1336
- tokens = tokens.reshape(-1, t_packing, hidden_size // t_packing)
1337
-
1338
- hbm_block_spec = pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM)
1339
- scope_name = f"fused_moe_k-{top_k}_renorm-{renormalize_topk_logits}_bt-{bt}-{btc}_bf-{bf}-{bfc}_bd1-{bd1}-{bd1c}_bd2-{bd2}-{bd2c}"
911
+ scope_name = f"fused_moe_k-{top_k}_bt-{bt}-{btc}_bf-{bf}-{bfc}_bd1-{bd1}-{bd1c}_bd2-{bd2}-{bd2c}"
1340
912
  fused_moe = jax.named_scope(scope_name)(
1341
913
  pl.pallas_call(
1342
914
  functools.partial(
1343
915
  _fused_ep_moe_kernel,
1344
916
  top_k=top_k,
1345
- renormalize_topk_logits=renormalize_topk_logits,
1346
917
  ep_axis_name=ep_axis_name,
1347
- act_fn=act_fn,
1348
- subc_quant_wsz=subc_quant_wsz,
1349
918
  bt=bt,
1350
919
  bf=bf,
1351
920
  bd1=bd1,
@@ -1360,17 +929,11 @@ def fused_ep_moe(
1360
929
  grid_spec=pltpu.PrefetchScalarGridSpec(
1361
930
  num_scalar_prefetch=0,
1362
931
  in_specs=[
1363
- hbm_block_spec, # tokens_hbm
1364
- hbm_block_spec, # w1_hbm
1365
- hbm_block_spec, # w2_hbm
1366
- None
1367
- if w1_scale is None else hbm_block_spec, # w1_scale_hbm
1368
- None
1369
- if w2_scale is None else hbm_block_spec, # w2_scale_hbm
1370
- None if b1 is None else hbm_block_spec, # b1_hbm
1371
- None if b2 is None else hbm_block_spec, # b2_hbm
1372
- hbm_block_spec, # gating_output_hbm
1373
- hbm_block_spec, # a2a_g_hbm
932
+ pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM),
933
+ pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM),
934
+ pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM),
935
+ pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM),
936
+ pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM),
1374
937
  ],
1375
938
  out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM),
1376
939
  scratch_shapes=([
@@ -1421,67 +984,6 @@ def fused_ep_moe(
1421
984
  pltpu.VMEM((2, t_packing, bd1 // t_packing, bf), w1.dtype),
1422
985
  # b_w2_x2_vmem
1423
986
  pltpu.VMEM((2, t_packing, bf, bd2 // t_packing), w2.dtype),
1424
- # b_w1_scale_x2_vmem
1425
- (None if w1_scale is None else pltpu.VMEM(
1426
- (
1427
- 2,
1428
- t_packing,
1429
- bd1 // t_packing // subc_quant_wsz,
1430
- 1,
1431
- bf,
1432
- ),
1433
- jnp.float32,
1434
- )),
1435
- # b_w3_scale_x2_vmem
1436
- (None if w1_scale is None else pltpu.VMEM(
1437
- (
1438
- 2,
1439
- t_packing,
1440
- bd1 // t_packing // subc_quant_wsz,
1441
- 1,
1442
- bf,
1443
- ),
1444
- jnp.float32,
1445
- )),
1446
- # b_w2_scale_x2_vmem
1447
- (None if w2_scale is None else pltpu.VMEM(
1448
- (
1449
- 2,
1450
- t_packing,
1451
- bf // subc_quant_wsz,
1452
- 1,
1453
- bd2 // t_packing,
1454
- ),
1455
- jnp.float32,
1456
- )),
1457
- # b_b1_x2_vmem
1458
- (None if b1 is None else pltpu.VMEM(
1459
- (
1460
- 2,
1461
- 1,
1462
- bf,
1463
- ),
1464
- jnp.float32,
1465
- )),
1466
- # b_b3_x2_vmem
1467
- (None if b1 is None else pltpu.VMEM(
1468
- (
1469
- 2,
1470
- 1,
1471
- bf,
1472
- ),
1473
- jnp.float32,
1474
- )),
1475
- # b_b2_x2_vmem
1476
- (None if b2 is None else pltpu.VMEM(
1477
- (
1478
- 2,
1479
- t_packing,
1480
- 1,
1481
- bd2 // t_packing,
1482
- ),
1483
- jnp.float32,
1484
- )),
1485
987
  # b_acc_vmem
1486
988
  pltpu.VMEM((bt * num_devices, 1, bf * 2), jnp.float32),
1487
989
  # local_sems
@@ -1504,50 +1006,21 @@ def fused_ep_moe(
1504
1006
  ))
1505
1007
 
1506
1008
  @jax.jit
1507
- @jax.shard_map(
1009
+ @functools.partial(
1010
+ shard_map.shard_map,
1508
1011
  mesh=mesh,
1509
- in_specs=(
1510
- P(ep_axis_name), # tokens_hbm
1511
- P(ep_axis_name), # w1_hbm
1512
- P(ep_axis_name), # w2_hbm
1513
- None if w1_scale is None else P(ep_axis_name), # w1_scale_hbm
1514
- None if w2_scale is None else P(ep_axis_name), # w2_scale_hbm
1515
- None if b1 is None else P(ep_axis_name), # b1_hbm
1516
- None if b2 is None else P(ep_axis_name), # b2_hbm
1517
- P(ep_axis_name), # gating_output_hbm
1518
- P(), # a2a_g_hbm
1519
- ),
1012
+ in_specs=(P(ep_axis_name), P(ep_axis_name), P(ep_axis_name),
1013
+ P(ep_axis_name), P()),
1520
1014
  out_specs=P(ep_axis_name),
1521
- check_vma=False,
1015
+ check_rep=False,
1522
1016
  )
1523
- def kernel(
1524
- tokens,
1525
- w1,
1526
- w2,
1527
- w1_scale,
1528
- w2_scale,
1529
- b1,
1530
- b2,
1531
- gating_output,
1532
- a2a_g_hbm_scratch,
1533
- ):
1017
+ def kernel(tokens, w1, w2, gating_output, a2a_g_hbm_scratch):
1534
1018
  return fused_moe(
1535
- pltpu.with_memory_space_constraint(tokens,
1536
- pltpu.HBM), # tokens_hbm
1537
- pltpu.with_memory_space_constraint(w1, pltpu.HBM), # w1_hbm
1538
- pltpu.with_memory_space_constraint(w2, pltpu.HBM), # w2_hbm
1539
- (None if w1_scale is None else pltpu.with_memory_space_constraint(
1540
- w1_scale, pltpu.HBM)), # w1_scale_hbm
1541
- (None if w2_scale is None else pltpu.with_memory_space_constraint(
1542
- w2_scale, pltpu.HBM)), # w2_scale_hbm
1543
- (None if b1 is None else pltpu.with_memory_space_constraint(
1544
- b1, pltpu.HBM)), # b1_hbm
1545
- (None if b2 is None else pltpu.with_memory_space_constraint(
1546
- b2, pltpu.HBM)), # b2_hbm
1547
- pltpu.with_memory_space_constraint(gating_output,
1548
- pltpu.HBM), # gating_output_hbm
1549
- pltpu.with_memory_space_constraint(a2a_g_hbm_scratch,
1550
- pltpu.HBM), # a2a_g_hbm
1019
+ pltpu.with_memory_space_constraint(tokens, pltpu.HBM),
1020
+ pltpu.with_memory_space_constraint(w1, pltpu.HBM),
1021
+ pltpu.with_memory_space_constraint(w2, pltpu.HBM),
1022
+ pltpu.with_memory_space_constraint(gating_output, pltpu.HBM),
1023
+ pltpu.with_memory_space_constraint(a2a_g_hbm_scratch, pltpu.HBM),
1551
1024
  )
1552
1025
 
1553
1026
  a2a_g_hbm_scratch = pl.empty(
@@ -1556,10 +1029,6 @@ def fused_ep_moe(
1556
1029
  tokens,
1557
1030
  w1,
1558
1031
  w2,
1559
- w1_scale,
1560
- w2_scale,
1561
- b1,
1562
- b2,
1563
1032
  gating_output,
1564
1033
  a2a_g_hbm_scratch,
1565
1034
  )