tpu-inference 0.11.1.dev202511150811__py3-none-any.whl → 0.11.1.dev202511270815__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference has been flagged as potentially problematic.

Files changed (49)
  1. tests/kernels/fused_moe_v1_test.py +303 -34
  2. tests/lora/test_layers.py +0 -6
  3. tests/lora/utils.py +0 -8
  4. tpu_inference/__init__.py +22 -3
  5. tpu_inference/core/disagg_utils.py +6 -8
  6. tpu_inference/distributed/tpu_connector.py +2 -3
  7. tpu_inference/distributed/utils.py +3 -2
  8. tpu_inference/envs.py +1 -1
  9. tpu_inference/executors/ray_distributed_executor.py +27 -11
  10. tpu_inference/kernels/fused_moe/v1/kernel.py +641 -110
  11. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +77 -54
  12. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +141 -107
  13. tpu_inference/layers/common/attention_interface.py +7 -1
  14. tpu_inference/layers/common/sharding.py +2 -1
  15. tpu_inference/layers/vllm/fused_moe.py +74 -25
  16. tpu_inference/layers/vllm/quantization/common.py +6 -1
  17. tpu_inference/layers/vllm/quantization/mxfp4.py +135 -61
  18. tpu_inference/layers/vllm/quantization/unquantized.py +107 -113
  19. tpu_inference/layers/vllm/sharding.py +2 -2
  20. tpu_inference/lora/torch_punica_tpu.py +1 -2
  21. tpu_inference/models/common/model_loader.py +43 -11
  22. tpu_inference/models/jax/llama3.py +2 -1
  23. tpu_inference/models/jax/llama_eagle3.py +8 -5
  24. tpu_inference/models/jax/llama_guard_4.py +361 -0
  25. tpu_inference/models/jax/qwen2.py +2 -1
  26. tpu_inference/models/jax/qwen2_5_vl.py +163 -48
  27. tpu_inference/models/jax/qwen3.py +2 -1
  28. tpu_inference/models/jax/utils/weight_utils.py +198 -143
  29. tpu_inference/models/vllm/vllm_model_wrapper.py +13 -5
  30. tpu_inference/platforms/tpu_platform.py +15 -2
  31. tpu_inference/runner/compilation_manager.py +58 -33
  32. tpu_inference/runner/kv_cache_manager.py +9 -3
  33. tpu_inference/runner/structured_decoding_manager.py +2 -3
  34. tpu_inference/runner/tpu_runner.py +203 -102
  35. tpu_inference/spec_decode/jax/eagle3.py +19 -2
  36. tpu_inference/tpu_info.py +4 -3
  37. tpu_inference/utils.py +5 -4
  38. tpu_inference/worker/tpu_worker.py +160 -23
  39. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202511270815.dist-info}/METADATA +3 -2
  40. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202511270815.dist-info}/RECORD +43 -48
  41. tpu_inference/mock/__init__.py +0 -0
  42. tpu_inference/mock/vllm_config_utils.py +0 -28
  43. tpu_inference/mock/vllm_envs.py +0 -1219
  44. tpu_inference/mock/vllm_logger.py +0 -212
  45. tpu_inference/mock/vllm_logging_utils.py +0 -15
  46. tpu_inference/models/jax/phi3.py +0 -376
  47. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202511270815.dist-info}/WHEEL +0 -0
  48. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202511270815.dist-info}/licenses/LICENSE +0 -0
  49. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202511270815.dist-info}/top_level.txt +0 -0
@@ -110,7 +110,8 @@ def tensor_sharded_gmm_merged_column_parallel(
     # adapted from https://github.com/pytorch/xla/blob/1d409399474197c484894be90b75d9855393dda5/torch_xla/experimental/custom_kernel.py#L1401
     m, k, g = lhs.shape[0], lhs.shape[1], rhs.shape[0]
     n = rhs.shape[1] if transpose_rhs else rhs.shape[2]
-    tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
+    tm, tk, tn = _get_tiling_size_for_gmm_kernel(m // mesh.shape["data"], k, n,
+                                                 g)
 
     _gmm = functools.partial(
         gmm,
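
Note: the hunks in this block appear to come from tpu_inference/layers/vllm/fused_moe.py, judging by the function names and the +74/-25 count in the file list. The tiling change above sizes the GMM tiles for the per-shard row count, m // mesh.shape["data"], because lhs is now sharded over the "data" mesh axis (see the in_specs change in the next hunk); under shard_map the grouped-matmul kernel only ever sees its local block of rows. A toy sketch of the arithmetic, with made-up axis sizes and a stand-in for _get_tiling_size_for_gmm_kernel (whose real heuristic is internal to the package):

    # Hypothetical mesh axis sizes; the real mesh comes from the TPU runner.
    mesh_shape = {"data": 2, "model": 4}        # stand-in for mesh.shape
    m, k, n, num_groups = 8192, 2048, 5632, 16  # global GMM problem size

    m_local = m // mesh_shape["data"]           # rows each data-parallel shard sees

    def toy_tiling(m, k, n, g):
        # Stand-in for _get_tiling_size_for_gmm_kernel: clamp tile sizes to the
        # problem size. The package's real heuristic is more involved.
        return min(m, 512), min(k, 2048), min(n, 2048)

    print(toy_tiling(m, k, n, num_groups))        # old: tiles sized for global m
    print(toy_tiling(m_local, k, n, num_groups))  # new: tiles sized per shard
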
@@ -123,14 +124,26 @@ def tensor_sharded_gmm_merged_column_parallel(
     gmm_result = shard_map(
         _gmm,
         mesh=mesh,
-        in_specs=(P(), P(None, "model", None), P()),
-        out_specs=(P(None, "model")),
+        in_specs=(P("data", None), P(None, "model", None), P("data")),
+        out_specs=(P("data", "model")),
         check_rep=False,
     )(lhs, rhs, group_sizes)
 
     if rhs_bias is not None:
-        rhs_bis = jnp.repeat(rhs_bias, group_sizes, 0, total_repeat_length=m)
-        gmm_result = (gmm_result + rhs_bis).astype(gmm_result.dtype)
+
+        def _add_bias(gmm_result_local, rhs_bias_local, group_sizes_global):
+            rhs_bis = jnp.repeat(rhs_bias_local,
+                                 group_sizes_global,
+                                 0,
+                                 total_repeat_length=m // mesh.shape["data"])
+            return (gmm_result_local + rhs_bis).astype(gmm_result_local.dtype)
+
+        gmm_result = shard_map(
+            _add_bias,
+            mesh=mesh,
+            in_specs=(P("data", "model"), P(None, "model"), P("data")),
+            out_specs=(P("data", "model")),
+        )(gmm_result, rhs_bias, group_sizes)
 
     n_shards = mesh.shape["model"]
     output_sizes = [intermediate_size, intermediate_size]
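
The bias addition is now wrapped in its own shard_map so each "data" shard repeats the per-expert bias only for the rows it owns (total_repeat_length is the per-shard m), while the bias stays sharded over "model" along its columns. Below is a self-contained sketch of the same pattern with invented shapes and group sizes (add_bias here is not the package's _add_bias); the XLA flag fakes four CPU devices so it runs without TPUs, and must be set before JAX initializes:

    import os
    os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.experimental.shard_map import shard_map  # jax.shard_map on newer JAX
    from jax.sharding import Mesh, PartitionSpec as P

    mesh = Mesh(np.array(jax.devices()).reshape(2, 2), axis_names=("data", "model"))

    num_experts, m, n = 4, 8, 8                      # toy sizes
    gmm_out = jnp.zeros((m, n))                      # stands in for gmm_result
    bias = jnp.arange(num_experts * n, dtype=jnp.float32).reshape(num_experts, n)
    group_sizes = jnp.ones((m,), dtype=jnp.int32)    # per-shard expert row counts (toy)

    def add_bias(out_local, bias_local, group_sizes_local):
        # Repeat each expert's bias row by its token count, capped at the number
        # of rows this data shard owns (m // size of the "data" axis).
        rep = jnp.repeat(bias_local, group_sizes_local, axis=0,
                         total_repeat_length=m // mesh.shape["data"])
        return (out_local + rep).astype(out_local.dtype)

    out = shard_map(
        add_bias,
        mesh=mesh,
        in_specs=(P("data", "model"), P(None, "model"), P("data")),
        out_specs=P("data", "model"),
        check_rep=False,
    )(gmm_out, bias, group_sizes)
    print(out.shape)  # (8, 8), still sharded over ("data", "model")

The row-parallel variant further down applies the same trick, except that after the all-reduce the GMM output is replicated across "model", so the bias there is passed in fully replicated (in_specs P() for rhs_bias).
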
@@ -150,7 +163,8 @@ def tensor_sharded_gmm_row_parallel(
     # adapted from https://github.com/pytorch/xla/blob/1d409399474197c484894be90b75d9855393dda5/torch_xla/experimental/custom_kernel.py#L1401
     m, k, g = lhs.shape[0], lhs.shape[1], rhs.shape[0]
     n = rhs.shape[1] if transpose_rhs else rhs.shape[2]
-    tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
+    tm, tk, tn = _get_tiling_size_for_gmm_kernel(m // mesh.shape["data"], k, n,
+                                                 g)
 
     _gmm = functools.partial(
         gmm,
@@ -167,14 +181,25 @@ def tensor_sharded_gmm_row_parallel(
     gmm_result = shard_map(
         _gmm_all_reduce,
         mesh=mesh,
-        in_specs=(P(None, "model"), P(None, None, "model"), P()),
-        out_specs=(P()),
+        in_specs=(P("data", "model"), P(None, None, "model"), P("data")),
+        out_specs=(P("data")),
         check_rep=False,
     )(lhs, rhs, group_sizes)
-
     if rhs_bias is not None:
-        rhs_bias = jnp.repeat(rhs_bias, group_sizes, 0, total_repeat_length=m)
-        gmm_result = (gmm_result + rhs_bias).astype(gmm_result.dtype)
+
+        def _add_bias(gmm_result_local, rhs_bias_local, group_sizes_global):
+            rhs_bis = jnp.repeat(rhs_bias_local,
+                                 group_sizes_global,
+                                 0,
+                                 total_repeat_length=m // mesh.shape["data"])
+            return (gmm_result_local + rhs_bis).astype(gmm_result_local.dtype)
+
+        gmm_result = shard_map(
+            _add_bias,
+            mesh=mesh,
+            in_specs=(P("data"), P(), P("data")),
+            out_specs=(P("data")),
+        )(gmm_result, rhs_bias, group_sizes)
 
     return gmm_result
 
@@ -366,15 +391,27 @@ def fused_moe_func(
     topk_weights = topk_weights / topk_weights.sum(axis=-1, keepdims=True)
     topk_weights = topk_weights.astype(dtype)
 
-    topk_indices_flat = topk_indices.flatten()
-    topk_argsort_indices = jnp.argsort(topk_indices_flat)
-    topk_argsort_revert_indices = jnp.argsort(topk_argsort_indices)
-    token_indices = jnp.arange(num_tokens, dtype=jnp.int32).repeat(topk)
-    token_indices_sorted = token_indices[topk_argsort_indices]
-    group_sizes = jnp.bincount(topk_indices_flat, length=global_num_experts)
-
-    x = hidden_states[token_indices_sorted]
-
+    def _process_tokens_locally(hidden_states_local, topk_indices_local):
+        num_tokens_local = hidden_states_local.shape[0]
+        topk_indices_flat = topk_indices_local.flatten()
+        topk_argsort_indices = jnp.argsort(topk_indices_flat)
+        topk_argsort_revert_indices = jnp.argsort(topk_argsort_indices)
+        token_indices = jnp.arange(num_tokens_local,
+                                   dtype=jnp.int32).repeat(topk)
+        token_indices_sorted = token_indices[topk_argsort_indices]
+        group_sizes_local = jnp.bincount(topk_indices_flat,
+                                         length=global_num_experts)
+
+        x = hidden_states_local[token_indices_sorted]
+        return x, group_sizes_local, topk_argsort_revert_indices
+
+    x, group_sizes, topk_argsort_revert_indices = shard_map(
+        _process_tokens_locally,
+        mesh=mesh,
+        in_specs=(P("data", None), P("data", None)),
+        out_specs=(P("data", None), P("data"), P("data")),
+        check_rep=False,
+    )(hidden_states, topk_indices)
     if use_ep:
         x = expert_sharded_gmm(
             x,
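
The routing bookkeeping (duplicating each token topk times, sorting the copies by expert id so each expert's rows are contiguous for the grouped matmul, counting rows per expert, and keeping the inverse permutation) now runs per data shard inside shard_map, over num_tokens_local rows instead of the global num_tokens. The per-shard body is ordinary JAX, so it can be illustrated without a mesh; the numbers here are invented:

    import jax.numpy as jnp

    topk = 2
    global_num_experts = 4
    hidden = jnp.arange(3 * 5, dtype=jnp.float32).reshape(3, 5)  # 3 tokens, dim 5
    topk_indices = jnp.array([[0, 2], [1, 2], [3, 0]])           # experts per token

    flat = topk_indices.flatten()                 # expert id per (token, slot) pair
    order = jnp.argsort(flat)                     # groups rows by expert
    revert = jnp.argsort(order)                   # undoes the grouping later
    token_ids = jnp.arange(3, dtype=jnp.int32).repeat(topk)
    x = hidden[token_ids[order]]                  # expert-grouped activations
    group_sizes = jnp.bincount(flat, length=global_num_experts)

    print(group_sizes)       # [2 1 2 1] rows per expert
    print(x.shape)           # (6, 5)
    print(token_ids[order])  # [0 2 1 0 1 2]: source token of each grouped row
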
@@ -411,7 +448,7 @@ def fused_moe_func(
         )
     else:
         x = jax.lax.with_sharding_constraint(
-            x, NamedSharding(mesh, P(None, "model")))
+            x, NamedSharding(mesh, P("data", "model")))
         x = tensor_sharded_gmm_row_parallel(
             x,
             w2,
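
In the non-EP path the activations are now constrained to P("data", "model") rather than P(None, "model"), so the row-parallel GMM that follows keeps its rows split across "data" while the contraction over the "model"-sharded dimension is finished with an all-reduce (the _gmm_all_reduce wrapper seen earlier). A minimal sketch of that row-parallel pattern, with a plain matmul standing in for the grouped matmul and invented shapes; as above, the XLA flag fakes CPU devices:

    import os
    os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.experimental.shard_map import shard_map
    from jax.sharding import Mesh, PartitionSpec as P

    mesh = Mesh(np.array(jax.devices()).reshape(2, 2), axis_names=("data", "model"))
    x = jnp.ones((8, 4))  # rows over "data", contracting dim over "model"
    w = jnp.ones((4, 6))  # contracting dim over "model"

    def matmul_all_reduce(x_local, w_local):
        # Each shard multiplies its slice of k, then the partial products are
        # summed across "model", which is what _gmm_all_reduce does for the GMM.
        return jax.lax.psum(x_local @ w_local, axis_name="model")

    y = shard_map(
        matmul_all_reduce,
        mesh=mesh,
        in_specs=(P("data", "model"), P("model", None)),
        out_specs=P("data", None),
        check_rep=False,
    )(x, w)
    print(y.shape)  # (8, 6): replicated over "model", still sharded over "data"
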
@@ -421,13 +458,25 @@ def fused_moe_func(
             mesh=mesh,
         )
 
-    x = x[topk_argsort_revert_indices].reshape(-1, topk, hidden_size)
-    x = x * jnp.expand_dims(topk_weights, axis=-1)
-    x = x.sum(axis=-2)
+    def _finalize_output(x_local, topk_argsort_revert_indices_local,
+                         topk_weights_local):
+        x_local = x_local[topk_argsort_revert_indices_local].reshape(
+            -1, topk, hidden_size)
+        x_local = x_local * jnp.expand_dims(topk_weights_local, axis=-1)
+        x_local = x_local.sum(axis=-2)
+        return x_local
+
+    x = shard_map(
+        _finalize_output,
+        mesh=mesh,
+        in_specs=(P("data", None), P("data"), P("data", None)),
+        out_specs=(P("data", None)),
+        check_rep=False,
+    )(x, topk_argsort_revert_indices, topk_weights)
     x = x.reshape(orig_shape)
 
     if reduce_results:
-        x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P()))
+        x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P("data")))
     return x
 
 
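
_finalize_output does the inverse of the routing step, again per data shard: undo the expert grouping, weight each of the topk expert outputs by its routing probability, and sum them per token. On one shard this is just the following reshape and weighted sum, shown with toy values (the permutation and weights are made up):

    import jax.numpy as jnp

    topk, hidden_size = 2, 5
    expert_out = jnp.ones((6, hidden_size))   # expert-grouped rows from the GMMs
    revert = jnp.array([0, 3, 2, 4, 5, 1])    # inverse permutation (toy)
    topk_weights = jnp.array([[0.75, 0.25],
                              [0.5, 0.5],
                              [0.9, 0.1]])    # (num_tokens, topk)

    x = expert_out[revert].reshape(-1, topk, hidden_size)  # back to token order
    x = x * jnp.expand_dims(topk_weights, axis=-1)         # weight each expert copy
    x = x.sum(axis=-2)                                     # mix the topk experts
    print(x.shape)  # (3, 5): one combined row per token

Because every intermediate now carries a "data" dimension, the final reduce_results constraint also changes from fully replicated (P()) to data-sharded (P("data")).
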
@@ -61,7 +61,12 @@ class JaxCommonLinearConfig:
                 " bad performance.", type(layer))
 
         self.bias_sharding = P(self.weight_sharding[0])
-        self.n_shards = self.mesh.shape.get(self.weight_sharding[0], 1)
+        if isinstance(self.weight_sharding[0], tuple):
+            self.n_shards = 1
+            for axis in self.weight_sharding[0]:
+                self.n_shards *= self.mesh.shape.get(axis, 1)
+        else:
+            self.n_shards = self.mesh.shape.get(self.weight_sharding[0], 1)
 
     def get_input_sharding(self, x: torchax.tensor.Tensor):
         if self.enable_sequence_parallelism:
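
This hunk appears to be from tpu_inference/layers/vllm/quantization/common.py (the +6/-1 entry in the file list). It generalizes the shard count so that a weight dimension sharded over a tuple of mesh axes, e.g. P(("data", "model"), None), reports the product of those axis sizes instead of defaulting to 1. A toy re-implementation with a dict standing in for self.mesh.shape:

    mesh_shape = {"data": 2, "model": 4}  # stand-in for self.mesh.shape

    def num_shards(spec_entry):
        # spec_entry is one element of a PartitionSpec: an axis name, a tuple of
        # axis names, or None for an unsharded dimension.
        if isinstance(spec_entry, tuple):
            n = 1
            for axis in spec_entry:
                n *= mesh_shape.get(axis, 1)
            return n
        return mesh_shape.get(spec_entry, 1)

    print(num_shards("model"))            # 4
    print(num_shards(("data", "model")))  # 8
    print(num_shards(None))               # 1
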
@@ -24,6 +24,8 @@ from vllm.model_executor.layers.quantization.mxfp4 import (Mxfp4Backend,
 from vllm.model_executor.layers.quantization.utils.quant_utils import \
     is_layer_skipped
 
+from tpu_inference import envs
+from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe
 from tpu_inference.layers.common.quant_methods import (MXFP4,
                                                        get_tpu_quant_method)
 from tpu_inference.layers.vllm.fused_moe import fused_moe_func_padded
@@ -103,13 +105,30 @@ class VllmMxfp4Config(Mxfp4Config, JaxCommonConfig):
 
 class VllmMxfp4MoEMethod(Mxfp4MoEMethod):
 
-    def __init__(self, moe: FusedMoEConfig, mesh: Mesh):
+    def __init__(self,
+                 moe: FusedMoEConfig,
+                 mesh: Mesh,
+                 ep_axis_name: str = 'model'):
         FusedMoEMethodBase.__init__(self, moe)
 
         # We piggyback on triton implementation as it applies minimal hardware
         # specific post processing to the weights.
         self.mxfp4_backend = Mxfp4Backend.TRITON
+
         self.mesh = mesh
+        self.use_kernel = envs.USE_MOE_EP_KERNEL
+        self.ep_axis_name = ep_axis_name
+        # TODO: Use autotune table once we have it.
+        self.block_size = {
+            "bt": 64,
+            "bf": 1024,
+            "bd1": 1536,
+            "bd2": 1536,
+            "btc": 64,
+            "bfc": 1024,
+            "bd1c": 1536,
+            "bd2c": 1536,
+        }
 
     def get_fused_moe_quant_config(
             self, layer: torch.nn.Module) -> FusedMoEQuantConfig | None:
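
The remaining hunks appear to be from tpu_inference/layers/vllm/quantization/mxfp4.py (the +135/-61 entry in the file list). The constructor now reads the USE_MOE_EP_KERNEL toggle from tpu_inference.envs, records which mesh axis expert parallelism runs over (default "model"), and hard-codes block sizes until an autotune table exists; the dict is later splatted directly into the kernel call as keyword arguments (**self.block_size in the apply hunk further down). A trivial sketch of that plumbing with a stand-in function, since only the keyword names are known here, not the real fused_ep_moe signature:

    block_size = {
        "bt": 64, "bf": 1024, "bd1": 1536, "bd2": 1536,
        "btc": 64, "bfc": 1024, "bd1c": 1536, "bd2c": 1536,
    }

    def toy_kernel(tokens, *, bt, bf, bd1, bd2, btc, bfc, bd1c, bd2c):
        # A real kernel would use these as tile sizes; here we just echo a few.
        return {"token_block": bt, "feature_block": bf, "ffn_blocks": (bd1, bd2)}

    print(toy_kernel("tokens", **block_size))
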
@@ -122,6 +141,7 @@ class VllmMxfp4MoEMethod(Mxfp4MoEMethod):
 
     def process_weights_after_loading(self, layer: torch.nn.Module):
         assert isinstance(layer, FusedMoE)
+        assert layer.moe_config.has_bias, "mxfp4 quantization alwyas use bias."
 
         w13_weight = u8_unpack_e2m1(t2j(layer.w13_weight, use_dlpack=False))
         w13_weight_scale = e8m0_to_fp32(
@@ -157,57 +177,95 @@ class VllmMxfp4MoEMethod(Mxfp4MoEMethod):
         w3_bias = w13_bias[:, 1::2]
         w13_bias = jnp.concat([w1_bias, w3_bias], axis=1)
 
-        # TODO(kyuyeunk): Add weight processing logic for the new kernel.
-        if layer.use_ep:
+        if self.use_kernel and layer.use_ep:
+            # Kernel expects:
+            # w13: (num_experts, 2, hidden_size, intermediate_size)
+            # w2: (num_experts, intermediate_size, hidden_size)
+            # Current format:
+            # w13_weight: (num_experts, 2*intermediate_size, hidden_size)
+            # w2_weight: (num_experts, hidden_size, intermediate_size)
+            num_experts = w13_weight.shape[0]
+            intermediate_size = w13_weight.shape[1] // 2
+            hidden_size = w13_weight.shape[2]
+
+            # Reshape and transpose w13_weight to (num_experts, 2, hidden_size, intermediate_size)
+            w13_reshaped = w13_weight.reshape(num_experts, 2,
+                                              intermediate_size, hidden_size)
+            w13_weight_transposed = jnp.transpose(w13_reshaped, (0, 1, 3, 2))
+
+            # Transpose w2_weight to (num_experts, intermediate_size, hidden_size)
+            w2_weight_transposed = jnp.transpose(w2_weight, (0, 2, 1))
+
+            # Apply EP sharding
             w13_weight = jax.device_put(
-                w13_weight,
-                Format(Layout((0, 1, 2)),
-                       NamedSharding(self.mesh, P("model", None, None))))
+                w13_weight_transposed,
+                Format(Layout((0, 1, 2, 3)),
+                       NamedSharding(self.mesh, P("model", None, None, None))))
             w2_weight = jax.device_put(
-                w2_weight,
+                w2_weight_transposed,
                 Format(Layout((0, 1, 2)),
                        NamedSharding(self.mesh, P("model", None, None))))
 
-            w13_bias = jax.device_put(
-                w13_bias,
-                Format(Layout((0, 1)),
-                       NamedSharding(self.mesh, P("model", None))))
-            w2_bias = jax.device_put(
-                w2_bias,
-                Format(Layout((0, 1)),
-                       NamedSharding(self.mesh, P("model", None))))
+            if self.moe.has_bias:
+                w13_bias = w13_bias.reshape(num_experts, 2, intermediate_size)
+
+                # Apply EP sharding
+                w13_bias = jax.device_put(
+                    w13_bias,
+                    Format(Layout((0, 1, 2)),
+                           NamedSharding(self.mesh, P("model", None, None))))
+                w2_bias = jax.device_put(
+                    w2_bias,
+                    Format(Layout((0, 1)),
+                           NamedSharding(self.mesh, P("model", None))))
 
         else:
-            intermediate_size = w13_weight.shape[1] // 2
-            assert intermediate_size == w2_weight.shape[-1]
-            output_sizes = [intermediate_size, intermediate_size]
-            n_shards = self.mesh.shape["model"]
-            assert intermediate_size % n_shards == 0
-            w13_weight = reorder_concatenated_tensor_for_sharding(w13_weight,
-                                                                  output_sizes,
-                                                                  n_shards,
-                                                                  dim=1)
-            w13_weight = jax.device_put(
-                w13_weight,
-                Format(Layout((0, 1, 2)),
-                       NamedSharding(self.mesh, P(None, "model", None))))
-            w2_weight = jax.device_put(
-                w2_weight,
-                Format(Layout((0, 1, 2)),
-                       NamedSharding(self.mesh, P(None, None, "model"))))
-
-            w13_bias = reorder_concatenated_tensor_for_sharding(w13_bias,
-                                                                output_sizes,
-                                                                n_shards,
-                                                                dim=1)
-            w13_bias = jax.device_put(
-                w13_bias,
-                Format(Layout((0, 1)),
-                       NamedSharding(self.mesh, P(None, "model"))))
-            w2_bias = jax.device_put(
-                w2_bias,
-                Format(Layout((0, 1)), NamedSharding(self.mesh, P(None,
-                                                                  None))))
+            if layer.use_ep:
+                w13_weight = jax.device_put(
+                    w13_weight,
+                    Format(Layout((0, 1, 2)),
+                           NamedSharding(self.mesh, P("model", None, None))))
+                w2_weight = jax.device_put(
+                    w2_weight,
+                    Format(Layout((0, 1, 2)),
+                           NamedSharding(self.mesh, P("model", None, None))))
+
+                w13_bias = jax.device_put(
+                    w13_bias,
+                    Format(Layout((0, 1)),
+                           NamedSharding(self.mesh, P("model", None))))
+                w2_bias = jax.device_put(
+                    w2_bias,
+                    Format(Layout((0, 1)),
+                           NamedSharding(self.mesh, P("model", None))))
+
+            else:
+                intermediate_size = w13_weight.shape[1] // 2
+                assert intermediate_size == w2_weight.shape[-1]
+                output_sizes = [intermediate_size, intermediate_size]
+                n_shards = self.mesh.shape["model"]
+                assert intermediate_size % n_shards == 0
+                w13_weight = reorder_concatenated_tensor_for_sharding(
+                    w13_weight, output_sizes, n_shards, dim=1)
+                w13_weight = jax.device_put(
+                    w13_weight,
+                    Format(Layout((0, 1, 2)),
+                           NamedSharding(self.mesh, P(None, "model", None))))
+                w2_weight = jax.device_put(
+                    w2_weight,
+                    Format(Layout((0, 1, 2)),
+                           NamedSharding(self.mesh, P(None, None, "model"))))
+
+                w13_bias = reorder_concatenated_tensor_for_sharding(
+                    w13_bias, output_sizes, n_shards, dim=1)
+                w13_bias = jax.device_put(
+                    w13_bias,
+                    Format(Layout((0, 1)),
+                           NamedSharding(self.mesh, P(None, "model"))))
+                w2_bias = jax.device_put(
+                    w2_bias,
+                    Format(Layout((0, 1)),
+                           NamedSharding(self.mesh, P(None, None))))
 
         layer.w13_weight = Parameter(torch_view(w13_weight),
                                      requires_grad=False)
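
The kernel branch reorders the checkpoint layout before sharding: w13 (already split into stacked w1/w3 by the slicing above) goes from (num_experts, 2*intermediate, hidden) to (num_experts, 2, hidden, intermediate), w2 is transposed to (num_experts, intermediate, hidden), and the bias becomes (num_experts, 2, intermediate). All three then get a leading ("model", ...) sharding so each expert-parallel rank holds a contiguous slice of experts. The same reshape/transpose on toy shapes:

    import jax.numpy as jnp

    num_experts, intermediate, hidden = 4, 6, 8  # toy sizes
    w13 = jnp.zeros((num_experts, 2 * intermediate, hidden))
    w2 = jnp.zeros((num_experts, hidden, intermediate))
    w13_bias = jnp.zeros((num_experts, 2 * intermediate))

    # (E, 2I, H) -> (E, 2, I, H) -> (E, 2, H, I), as in the kernel branch above.
    w13_kernel = jnp.transpose(
        w13.reshape(num_experts, 2, intermediate, hidden), (0, 1, 3, 2))
    w2_kernel = jnp.transpose(w2, (0, 2, 1))     # (E, H, I) -> (E, I, H)
    w13_bias_kernel = w13_bias.reshape(num_experts, 2, intermediate)

    print(w13_kernel.shape, w2_kernel.shape, w13_bias_kernel.shape)
    # (4, 2, 8, 6) (4, 6, 8) (4, 2, 6)
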
@@ -246,21 +304,37 @@ class VllmMxfp4MoEMethod(Mxfp4MoEMethod):
             raise NotImplementedError(
                 "Only softmax is supported for scoring_func")
 
-        # Use the original implementation
-        output = fused_moe_func_padded(
-            jax_view(x),
-            jax_view(layer.w13_weight),
-            jax_view(layer.w2_weight),
-            jax_view(layer.w13_bias) if self.moe.has_bias else None,
-            jax_view(layer.w2_bias) if self.moe.has_bias else None,
-            jax_view(router_logits),
-            topk=top_k,
-            global_num_experts=global_num_experts,
-            renormalize=renormalize,
-            reduce_results=layer.reduce_results,
-            mesh=self.mesh,
-            use_ep=layer.use_ep,
-            activation=activation,
-        )
+        if self.use_kernel and layer.use_ep:
+            output = fused_ep_moe(
+                mesh=self.mesh,
+                tokens=jax_view(x),
+                w1=jax_view(layer.w13_weight),
+                w2=jax_view(layer.w2_weight),
+                b1=jax_view(layer.w13_bias),
+                b2=jax_view(layer.w2_bias),
+                gating_output=jax_view(router_logits),
+                top_k=top_k,
+                ep_axis_name=self.ep_axis_name,
+                renormalize_topk_logits=renormalize,
+                act_fn=activation,
+                **self.block_size,
+            )
+        else:
+            # Use the original implementation
+            output = fused_moe_func_padded(
+                jax_view(x),
+                jax_view(layer.w13_weight),
+                jax_view(layer.w2_weight),
+                jax_view(layer.w13_bias),
+                jax_view(layer.w2_bias),
+                jax_view(router_logits),
+                topk=top_k,
+                global_num_experts=global_num_experts,
+                renormalize=renormalize,
+                reduce_results=layer.reduce_results,
+                mesh=self.mesh,
+                use_ep=layer.use_ep,
+                activation=activation,
+            )
 
         return torch_view(output)
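
In apply(), the flag-gated kernel path calls fused_ep_moe with the biases passed unconditionally, consistent with the new has_bias assertion, and the fallback path likewise drops the old "if self.moe.has_bias else None" guards around fused_moe_func_padded. Both paths forward the same renormalize flag; the sketch below shows what that flag means for the routing weights, using toy logits rather than the real router output (the exact ordering of softmax and top-k inside the kernel is not shown here):

    import jax
    import jax.numpy as jnp

    logits = jnp.array([[2.0, 0.5, 0.1, -1.0]])         # one token, four experts
    probs = jax.nn.softmax(logits, axis=-1)
    topk_weights, topk_idx = jax.lax.top_k(probs, k=2)   # keep the two best experts

    # With renormalize=True the kept weights are rescaled to sum to 1 per token,
    # matching the topk_weights / topk_weights.sum(...) line in fused_moe_func.
    renormalized = topk_weights / topk_weights.sum(axis=-1, keepdims=True)
    print(topk_idx)      # [[0 1]]
    print(topk_weights)  # raw softmax mass of the chosen experts
    print(renormalized)  # rescaled weights used to mix expert outputs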