tpu-inference 0.0.1rc1__py3-none-any.whl → 0.11.1.dev202511180814__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (56)
  1. tests/kernels/fused_moe_v1_test.py +34 -303
  2. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +2 -2
  3. tests/lora/test_layers.py +6 -0
  4. tests/lora/utils.py +8 -0
  5. tests/test_envs.py +11 -32
  6. tests/test_utils.py +2 -1
  7. tpu_inference/__init__.py +3 -22
  8. tpu_inference/core/disagg_utils.py +8 -6
  9. tpu_inference/distributed/tpu_connector.py +4 -3
  10. tpu_inference/distributed/utils.py +2 -3
  11. tpu_inference/envs.py +8 -61
  12. tpu_inference/executors/ray_distributed_executor.py +2 -9
  13. tpu_inference/kernels/fused_moe/v1/kernel.py +110 -641
  14. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +54 -77
  15. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +145 -266
  16. tpu_inference/layers/common/attention_interface.py +1 -7
  17. tpu_inference/layers/common/sharding.py +5 -5
  18. tpu_inference/layers/vllm/fused_moe.py +208 -170
  19. tpu_inference/layers/vllm/quantization/common.py +1 -6
  20. tpu_inference/layers/vllm/quantization/mxfp4.py +73 -138
  21. tpu_inference/layers/vllm/quantization/unquantized.py +64 -58
  22. tpu_inference/layers/vllm/sharding.py +2 -2
  23. tpu_inference/lora/torch_punica_tpu.py +2 -1
  24. tpu_inference/mock/__init__.py +0 -0
  25. tpu_inference/mock/vllm_config_utils.py +28 -0
  26. tpu_inference/mock/vllm_envs.py +1219 -0
  27. tpu_inference/mock/vllm_logger.py +212 -0
  28. tpu_inference/mock/vllm_logging_utils.py +15 -0
  29. tpu_inference/models/common/model_loader.py +10 -43
  30. tpu_inference/models/jax/llama3.py +1 -2
  31. tpu_inference/models/jax/llama_eagle3.py +5 -8
  32. tpu_inference/models/jax/phi3.py +376 -0
  33. tpu_inference/models/jax/qwen2.py +1 -2
  34. tpu_inference/models/jax/qwen2_5_vl.py +48 -163
  35. tpu_inference/models/jax/qwen3.py +1 -2
  36. tpu_inference/models/jax/utils/quantization/quantization_utils.py +6 -3
  37. tpu_inference/models/jax/utils/weight_utils.py +143 -198
  38. tpu_inference/models/vllm/vllm_model_wrapper.py +8 -14
  39. tpu_inference/platforms/tpu_platform.py +31 -37
  40. tpu_inference/runner/compilation_manager.py +58 -141
  41. tpu_inference/runner/kv_cache.py +1 -1
  42. tpu_inference/runner/kv_cache_manager.py +18 -17
  43. tpu_inference/runner/persistent_batch_manager.py +2 -40
  44. tpu_inference/runner/structured_decoding_manager.py +3 -2
  45. tpu_inference/runner/tpu_runner.py +147 -271
  46. tpu_inference/runner/utils.py +2 -2
  47. tpu_inference/spec_decode/jax/eagle3.py +21 -71
  48. tpu_inference/tpu_info.py +3 -4
  49. tpu_inference/utils.py +13 -36
  50. tpu_inference/worker/tpu_worker.py +25 -162
  51. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/METADATA +3 -4
  52. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/RECORD +55 -50
  53. tpu_inference/models/jax/llama_guard_4.py +0 -361
  54. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/WHEEL +0 -0
  55. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/licenses/LICENSE +0 -0
  56. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/top_level.txt +0 -0
tpu_inference/layers/vllm/quantization/mxfp4.py
@@ -24,11 +24,9 @@ from vllm.model_executor.layers.quantization.mxfp4 import (Mxfp4Backend,
 from vllm.model_executor.layers.quantization.utils.quant_utils import \
     is_layer_skipped
 
-from tpu_inference import envs
-from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe
 from tpu_inference.layers.common.quant_methods import (MXFP4,
                                                         get_tpu_quant_method)
-from tpu_inference.layers.vllm.fused_moe import fused_moe_func
+from tpu_inference.layers.vllm.fused_moe import fused_moe_func_padded
 from tpu_inference.layers.vllm.linear_common import \
     reorder_concatenated_tensor_for_sharding
 from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
@@ -87,14 +85,17 @@ class VllmMxfp4Config(Mxfp4Config, JaxCommonConfig):
                 fused_mapping=self.packed_modules_mapping,
             ):
                 return VllmUnquantizedLinearMethod(linear_config)
+            # TODO: Add support for MXFP4 Linear Method.
+            # MXFP4 LinearMethod is available in AMD-Quark, refer to that
+            # implementation if you are interested in enabling MXFP4 here.
             logger.warning_once(
                 "MXFP4 linear layer is not implemented - falling back to "
                 "UnquantizedLinearMethod.")
             return VllmUnquantizedLinearMethod(linear_config)
         elif isinstance(layer, FusedMoE):
-            moe_config = self.get_moe_config(layer)
-            return VllmMxfp4MoEMethod(moe_config, self.mesh)
+            return VllmMxfp4MoEMethod(layer.moe_config, self.mesh)
         elif isinstance(layer, Attention):
+            # TODO: Add support for MXFP4 Attention.
             logger.warning_once("MXFP4 attention layer is not implemented. "
                                 "Skipping quantization for this layer.")
             return None
@@ -102,30 +103,13 @@ class VllmMxfp4Config(Mxfp4Config, JaxCommonConfig):
 
 class VllmMxfp4MoEMethod(Mxfp4MoEMethod):
 
-    def __init__(self,
-                 moe: FusedMoEConfig,
-                 mesh: Mesh,
-                 ep_axis_name: str = 'model'):
+    def __init__(self, moe: FusedMoEConfig, mesh: Mesh):
         FusedMoEMethodBase.__init__(self, moe)
 
         # We piggyback on triton implementation as it applies minimal hardware
         # specific post processing to the weights.
         self.mxfp4_backend = Mxfp4Backend.TRITON
-
         self.mesh = mesh
-        self.use_kernel = envs.USE_MOE_EP_KERNEL and moe.use_ep
-        self.ep_axis_name = ep_axis_name
-        # TODO: Use autotune table once we have it.
-        self.block_size = {
-            "bt": 64,
-            "bf": 1024,
-            "bd1": 1536,
-            "bd2": 1536,
-            "btc": 64,
-            "bfc": 1024,
-            "bd1c": 1536,
-            "bd2c": 1536,
-        }
 
     def get_fused_moe_quant_config(
             self, layer: torch.nn.Module) -> FusedMoEQuantConfig | None:
@@ -138,7 +122,6 @@ class VllmMxfp4MoEMethod(Mxfp4MoEMethod):
 
     def process_weights_after_loading(self, layer: torch.nn.Module):
         assert isinstance(layer, FusedMoE)
-        assert layer.moe_config.has_bias, "mxfp4 quantization alwyas use bias."
 
         w13_weight = u8_unpack_e2m1(t2j(layer.w13_weight, use_dlpack=False))
         w13_weight_scale = e8m0_to_fp32(
@@ -157,8 +140,6 @@ class VllmMxfp4MoEMethod(Mxfp4MoEMethod):
         w2_weight = dequantize_block_weight(w2_weight, w2_weight_scale,
                                             MXFP4_BLOCK_SIZE, jnp.bfloat16)
 
-        num_experts, hidden_size, intermediate_size = w2_weight.shape
-
         # Because we have dequantized weights, scales are not used anymore.
         delattr(layer, "w13_weight_scale")
         delattr(layer, "w2_weight_scale")
@@ -176,89 +157,63 @@ class VllmMxfp4MoEMethod(Mxfp4MoEMethod):
         w3_bias = w13_bias[:, 1::2]
         w13_bias = jnp.concat([w1_bias, w3_bias], axis=1)
 
-        if self.use_kernel:
-            # Kernel expects:
-            # w13: (num_experts, 2, hidden_size, intermediate_size)
-            # w2: (num_experts, intermediate_size, hidden_size)
-            # Current format:
-            # w13_weight: (num_experts, 2*intermediate_size, hidden_size)
-            # w2_weight: (num_experts, hidden_size, intermediate_size)
-
-            w13_reshaped = w13_weight.reshape(num_experts, 2,
-                                              intermediate_size, hidden_size)
-
-            # Transpose non-constracting dim to right most dim
-            w13_weight_transposed = jnp.swapaxes(w13_reshaped, 2, 3)
-            w2_weight_transposed = jnp.swapaxes(w2_weight, 1, 2)
-
-            # Apply EP sharding
-            ep_sharding = NamedSharding(self.mesh, P("model"))
-
+        # TODO(kyuyeunk): Add weight processing logic for the new kernel.
+        if layer.use_ep:
             w13_weight = jax.device_put(
-                w13_weight_transposed, Format(Layout((0, 1, 2, 3)),
-                                              ep_sharding))
-            w2_weight = jax.device_put(w2_weight_transposed,
-                                       Format(Layout((0, 1, 2)), ep_sharding))
-
-            w13_bias = w13_bias.reshape(num_experts, 2, intermediate_size)
-            w13_bias = jax.device_put(w13_bias,
-                                      Format(Layout((0, 1, 2)), ep_sharding))
-            w2_bias = jax.device_put(w2_bias,
-                                     Format(Layout((0, 1)), ep_sharding))
+                w13_weight,
+                Format(Layout((0, 1, 2)),
+                       NamedSharding(self.mesh, P("model", None, None))))
+            w2_weight = jax.device_put(
+                w2_weight,
+                Format(Layout((0, 1, 2)),
+                       NamedSharding(self.mesh, P("model", None, None))))
+
+            w13_bias = jax.device_put(
+                w13_bias,
+                Format(Layout((0, 1)),
+                       NamedSharding(self.mesh, P("model", None))))
+            w2_bias = jax.device_put(
+                w2_bias,
+                Format(Layout((0, 1)),
+                       NamedSharding(self.mesh, P("model", None))))
 
         else:
-            if layer.use_ep:
-                ep_sharding = NamedSharding(self.mesh, P("model"))
-                w13_weight = jax.device_put(
-                    w13_weight, Format(Layout((0, 1, 2)), ep_sharding))
-                w2_weight = jax.device_put(
-                    w2_weight, Format(Layout((0, 1, 2)), ep_sharding))
-
-                w13_bias = jax.device_put(w13_bias,
-                                          Format(Layout((0, 1)), ep_sharding))
-                w2_bias = jax.device_put(w2_bias,
-                                         Format(Layout((0, 1)), ep_sharding))
-
-            else:
-                output_sizes = [intermediate_size, intermediate_size]
-                n_shards = self.mesh.shape["model"]
-                assert intermediate_size % n_shards == 0
-
-                w13_weight = reorder_concatenated_tensor_for_sharding(
-                    w13_weight,
-                    output_sizes,
-                    n_shards,
-                    dim=1,
-                )
-                w13_weight = jax.device_put(
-                    w13_weight,
-                    Format(Layout((0, 1, 2)),
-                           NamedSharding(self.mesh, P(None, "model", None))))
-                w2_weight = jax.device_put(
-                    w2_weight,
-                    Format(Layout((0, 1, 2)),
-                           NamedSharding(self.mesh, P(None, None, "model"))))
-
-                w13_bias = reorder_concatenated_tensor_for_sharding(
-                    w13_bias,
-                    output_sizes,
-                    n_shards,
-                    dim=1,
-                )
-                w13_bias = jax.device_put(
-                    w13_bias,
-                    Format(Layout((0, 1)),
-                           NamedSharding(self.mesh, P(None, "model"))))
-                w2_bias = jax.device_put(
-                    w2_bias,
-                    Format(Layout((0, 1)),
-                           NamedSharding(self.mesh, P(None, None))))
+            intermediate_size = w13_weight.shape[1] // 2
+            assert intermediate_size == w2_weight.shape[-1]
+            output_sizes = [intermediate_size, intermediate_size]
+            n_shards = self.mesh.shape["model"]
+            assert intermediate_size % n_shards == 0
+            w13_weight = reorder_concatenated_tensor_for_sharding(w13_weight,
+                                                                  output_sizes,
+                                                                  n_shards,
+                                                                  dim=1)
+            w13_weight = jax.device_put(
+                w13_weight,
+                Format(Layout((0, 1, 2)),
+                       NamedSharding(self.mesh, P(None, "model", None))))
+            w2_weight = jax.device_put(
+                w2_weight,
+                Format(Layout((0, 1, 2)),
+                       NamedSharding(self.mesh, P(None, None, "model"))))
+
+            w13_bias = reorder_concatenated_tensor_for_sharding(w13_bias,
+                                                                output_sizes,
+                                                                n_shards,
+                                                                dim=1)
+            w13_bias = jax.device_put(
+                w13_bias,
+                Format(Layout((0, 1)),
+                       NamedSharding(self.mesh, P(None, "model"))))
+            w2_bias = jax.device_put(
+                w2_bias,
+                Format(Layout((0, 1)), NamedSharding(self.mesh, P(None,
+                                                                  None))))
 
         layer.w13_weight = Parameter(torch_view(w13_weight),
                                      requires_grad=False)
-        layer.w2_weight = Parameter(torch_view(w2_weight), requires_grad=False)
-
         layer.w13_bias = Parameter(torch_view(w13_bias), requires_grad=False)
+
+        layer.w2_weight = Parameter(torch_view(w2_weight), requires_grad=False)
         layer.w2_bias = Parameter(torch_view(w2_bias), requires_grad=False)
 
         pass
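
The tensor-parallel branch above splits the concatenated gate/up projection (dim 1 of w13) and the last dim of w2 across the "model" mesh axis. The following is a minimal standalone JAX sketch of that placement, not the package's code: it uses toy shapes, assumes the sharded dims divide the device count, and shows only the NamedSharding part (the real code additionally pins a device-local layout via Format/Layout).

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1D mesh over all local devices, with the axis named "model" as in the diff.
mesh = Mesh(np.array(jax.devices()), axis_names=("model",))

num_experts, intermediate_size, hidden_size = 8, 1024, 512
w13 = jnp.zeros((num_experts, 2 * intermediate_size, hidden_size), jnp.bfloat16)
w2 = jnp.zeros((num_experts, hidden_size, intermediate_size), jnp.bfloat16)

# Shard the concatenated intermediate dim of w13 and the last dim of w2.
w13 = jax.device_put(w13, NamedSharding(mesh, P(None, "model", None)))
w2 = jax.device_put(w2, NamedSharding(mesh, P(None, None, "model")))
print(w13.sharding, w2.sharding)
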
@@ -291,41 +246,21 @@ class VllmMxfp4MoEMethod(Mxfp4MoEMethod):
             raise NotImplementedError(
                 "Only softmax is supported for scoring_func")
 
-        x = jax_view(x)
-        w13_weight = jax_view(layer.w13_weight)
-        w2_weight = jax_view(layer.w2_weight)
-        w13_bias = jax_view(layer.w13_bias)
-        w2_bias = jax_view(layer.w2_bias)
-        gating_output = jax_view(router_logits)
-
-        if self.use_kernel:
-            output = fused_ep_moe(
-                mesh=self.mesh,
-                tokens=x,
-                w1=w13_weight,
-                w2=w2_weight,
-                b1=w13_bias,
-                b2=w2_bias,
-                gating_output=gating_output,
-                top_k=top_k,
-                ep_axis_name=self.ep_axis_name,
-                renormalize_topk_logits=renormalize,
-                act_fn=activation,
-                **self.block_size,
-            )
-        else:
-            output = fused_moe_func(
-                hidden_states=x,
-                w1=w13_weight,
-                w2=w2_weight,
-                w1_bias=w13_bias,
-                w2_bias=w2_bias,
-                gating_output=gating_output,
-                topk=top_k,
-                renormalize=renormalize,
-                mesh=self.mesh,
-                use_ep=layer.use_ep,
-                activation=activation,
-            )
+        # Use the original implementation
+        output = fused_moe_func_padded(
+            jax_view(x),
+            jax_view(layer.w13_weight),
+            jax_view(layer.w2_weight),
+            jax_view(layer.w13_bias) if self.moe.has_bias else None,
+            jax_view(layer.w2_bias) if self.moe.has_bias else None,
+            jax_view(router_logits),
+            topk=top_k,
+            global_num_experts=global_num_experts,
+            renormalize=renormalize,
+            reduce_results=layer.reduce_results,
+            mesh=self.mesh,
+            use_ep=layer.use_ep,
+            activation=activation,
+        )
 
         return torch_view(output)
tpu_inference/layers/vllm/quantization/unquantized.py
@@ -25,7 +25,7 @@ from tpu_inference import envs
 from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe
 from tpu_inference.layers.common.quant_methods import (UNQUANTIZED,
                                                         get_tpu_quant_method)
-from tpu_inference.layers.vllm.fused_moe import fused_moe_func
+from tpu_inference.layers.vllm.fused_moe import fused_moe_func_padded
 from tpu_inference.layers.vllm.linear_common import (
     reorder_concatenated_tensor_for_sharding,
     slice_sharded_tensor_for_concatenation, torch_to_jax_param)
@@ -108,8 +108,6 @@ class VllmUnquantizedLinearMethod(UnquantizedLinearMethod):
               layer: torch.nn.Module,
               x: torch.Tensor,
               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-        assert isinstance(layer, LinearBase)
-
         with jax.named_scope(layer._get_name()):
             if in_sharding := self.jax_config.get_input_sharding(x):
                 x.shard_(NamedSharding(self.jax_config.mesh, in_sharding))
@@ -168,18 +166,18 @@ class VllmUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
                  ep_axis_name: str = 'model'):
         super().__init__(moe)
         self.mesh = mesh
-        self.use_kernel = envs.USE_MOE_EP_KERNEL and moe.use_ep
+        self.use_kernel = envs.USE_MOE_EP_KERNEL
         self.ep_axis_name = ep_axis_name
         # TODO: Use autotune table once we have it.
         self.block_size = {
-            "bt": 64,
-            "bf": 1024,
-            "bd1": 1536,
-            "bd2": 1536,
-            "btc": 64,
-            "bfc": 1024,
-            "bd1c": 1536,
-            "bd2c": 1536,
+            "bt": 16,
+            "bf": 384,
+            "bd1": 512,
+            "bd2": 512,
+            "btc": 16,
+            "bfc": 384,
+            "bd1c": 256,
+            "bd2c": 256,
         }
 
     def select_gemm_impl(
@@ -196,8 +194,6 @@ class VllmUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
         w13_weight = t2j(layer.w13_weight, use_dlpack=False)
         w2_weight = t2j(layer.w2_weight, use_dlpack=False)
 
-        num_experts, hidden_size, intermediate_size = w2_weight.shape
-
         if self.moe.has_bias:
             w13_bias = t2j(layer.w13_bias, use_dlpack=False)
             w2_bias = t2j(layer.w2_bias, use_dlpack=False)
@@ -216,56 +212,76 @@ class VllmUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
             w3_bias = w13_bias[:, 1::2]
             w13_bias = jnp.concat([w1_bias, w3_bias], axis=1)
 
-        if self.use_kernel:
+        if self.use_kernel and layer.use_ep:
             # Kernel expects:
             # w13: (num_experts, 2, hidden_size, intermediate_size)
             # w2: (num_experts, intermediate_size, hidden_size)
             # Current format:
             # w13_weight: (num_experts, 2*intermediate_size, hidden_size)
             # w2_weight: (num_experts, hidden_size, intermediate_size)
+            num_experts = w13_weight.shape[0]
+            intermediate_size = w13_weight.shape[1] // 2
+            hidden_size = w13_weight.shape[2]
 
+            # Reshape and transpose w13_weight to (num_experts, 2, hidden_size, intermediate_size)
             w13_reshaped = w13_weight.reshape(num_experts, 2,
                                               intermediate_size, hidden_size)
+            w13_weight_transposed = jnp.transpose(w13_reshaped, (0, 1, 3, 2))
 
-            # Transpose non-constracting dim to right most dim
-            w13_weight_transposed = jnp.swapaxes(w13_reshaped, 2, 3)
-            w2_weight_transposed = jnp.swapaxes(w2_weight, 1, 2)
+            # Transpose w2_weight to (num_experts, intermediate_size, hidden_size)
+            w2_weight_transposed = jnp.transpose(w2_weight, (0, 2, 1))
 
             # Apply EP sharding
-            ep_sharding = NamedSharding(self.mesh, P("model"))
-
             w13_weight = jax.device_put(
-                w13_weight_transposed, Format(Layout((0, 1, 2, 3)),
-                                              ep_sharding))
-            w2_weight = jax.device_put(w2_weight_transposed,
-                                       Format(Layout((0, 1, 2)), ep_sharding))
+                w13_weight_transposed,
+                Format(Layout((0, 1, 2, 3)),
+                       NamedSharding(self.mesh, P("model", None, None, None))))
+            w2_weight = jax.device_put(
+                w2_weight_transposed,
+                Format(Layout((0, 1, 2)),
+                       NamedSharding(self.mesh, P("model", None, None))))
 
             if self.moe.has_bias:
                 w13_bias = w13_bias.reshape(num_experts, 2, intermediate_size)
+
+                # Apply EP sharding
                 w13_bias = jax.device_put(
-                    w13_bias, Format(Layout((0, 1, 2)), ep_sharding))
-                w2_bias = jax.device_put(w2_bias,
-                                         Format(Layout((0, 1)), ep_sharding))
-        else:
+                    w13_bias,
+                    Format(Layout((0, 1, 2)),
+                           NamedSharding(self.mesh, P("model", None, None))))
+                w2_bias = jax.device_put(
+                    w2_bias,
+                    Format(Layout((0, 1)),
+                           NamedSharding(self.mesh, P("model", None))))
 
+        else:
+            # Original logic for non-kernel path
             if layer.use_ep:
-                ep_sharding = NamedSharding(self.mesh, P("model"))
                 w13_weight = jax.device_put(
-                    w13_weight, Format(Layout((0, 1, 2)), ep_sharding))
+                    w13_weight,
+                    Format(Layout((0, 1, 2)),
+                           NamedSharding(self.mesh, P("model", None, None))))
                 w2_weight = jax.device_put(
-                    w2_weight, Format(Layout((0, 1, 2)), ep_sharding))
+                    w2_weight,
+                    Format(Layout((0, 1, 2)),
+                           NamedSharding(self.mesh, P("model", None, None))))
 
                 if self.moe.has_bias:
                     w13_bias = jax.device_put(
-                        w13_bias, Format(Layout((0, 1)), ep_sharding))
+                        w13_bias,
+                        Format(Layout((0, 1)),
+                               NamedSharding(self.mesh, P("model", None))))
                     w2_bias = jax.device_put(
-                        w2_bias, Format(Layout((0, 1)), ep_sharding))
+                        w2_bias,
+                        Format(Layout((0, 1)),
+                               NamedSharding(self.mesh, P("model", None))))
 
             else:
+                intermediate_size = w13_weight.shape[1] // 2
+                assert intermediate_size == w2_weight.shape[-1]
                 output_sizes = [intermediate_size, intermediate_size]
                 n_shards = self.mesh.shape["model"]
                 assert intermediate_size % n_shards == 0
-
                 w13_weight = reorder_concatenated_tensor_for_sharding(
                     w13_weight, output_sizes, n_shards, dim=1)
                 w13_weight = jax.device_put(
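
For reference, a small self-contained sketch (toy sizes, not the package's code) of the layout change the kernel path above performs: w13 goes from (num_experts, 2*intermediate, hidden) to (num_experts, 2, hidden, intermediate), and w2 from (num_experts, hidden, intermediate) to (num_experts, intermediate, hidden). It also checks that the jnp.transpose calls in the new code match the jnp.swapaxes calls they replace.

import jax.numpy as jnp

E, I, H = 4, 6, 8  # num_experts, intermediate_size, hidden_size (toy values)
w13 = jnp.arange(E * 2 * I * H, dtype=jnp.float32).reshape(E, 2 * I, H)
w2 = jnp.arange(E * H * I, dtype=jnp.float32).reshape(E, H, I)

# New-code layout change.
w13_kernel = jnp.transpose(w13.reshape(E, 2, I, H), (0, 1, 3, 2))
w2_kernel = jnp.transpose(w2, (0, 2, 1))

assert w13_kernel.shape == (E, 2, H, I)
assert w2_kernel.shape == (E, I, H)
# Equivalent to the removed swapaxes-based version.
assert jnp.array_equal(w13_kernel, jnp.swapaxes(w13.reshape(E, 2, I, H), 2, 3))
assert jnp.array_equal(w2_kernel, jnp.swapaxes(w2, 1, 2))
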
@@ -326,40 +342,30 @@ class VllmUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
             raise NotImplementedError(
                 "Only softmax is supported for scoring_func")
 
-        x = jax_view(x)
-        w13_weight = jax_view(layer.w13_weight)
-        w2_weight = jax_view(layer.w2_weight)
-        w13_bias = w2_bias = None
-        if self.moe.has_bias:
-            w13_bias = jax_view(layer.w13_bias)
-            w2_bias = jax_view(layer.w2_bias)
-        gating_output = jax_view(router_logits)
-
         if self.use_kernel and layer.use_ep:
             output = fused_ep_moe(
                 mesh=self.mesh,
-                tokens=x,
-                w1=w13_weight,
-                w2=w2_weight,
-                b1=w13_bias,
-                b2=w2_bias,
-                gating_output=gating_output,
+                tokens=jax_view(x),
+                w1=jax_view(layer.w13_weight),
+                w2=jax_view(layer.w2_weight),
+                gating_output=jax_view(router_logits),
                 top_k=top_k,
                 ep_axis_name=self.ep_axis_name,
-                renormalize_topk_logits=renormalize,
-                act_fn=activation,
                 **self.block_size,
             )
         else:
-            output = fused_moe_func(
-                hidden_states=x,
-                w1=w13_weight,
-                w2=w2_weight,
-                w1_bias=w13_bias,
-                w2_bias=w2_bias,
-                gating_output=gating_output,
+            # Use the original implementation
+            output = fused_moe_func_padded(
+                jax_view(x),
+                jax_view(layer.w13_weight),
+                jax_view(layer.w2_weight),
+                jax_view(layer.w13_bias) if self.moe.has_bias else None,
+                jax_view(layer.w2_bias) if self.moe.has_bias else None,
+                jax_view(router_logits),
                 topk=top_k,
+                global_num_experts=global_num_experts,
                 renormalize=renormalize,
+                reduce_results=layer.reduce_results,
                 mesh=self.mesh,
                 use_ep=layer.use_ep,
                 activation=activation,
tpu_inference/layers/vllm/sharding.py
@@ -19,7 +19,6 @@ from vllm.lora.layers.base_linear import BaseLinearLayerWithLoRA
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 
-from tpu_inference import envs
 from tpu_inference.logger import init_logger
 
 P = PartitionSpec
@@ -212,7 +211,8 @@ def _shard_module_to_tpu(model: torch.nn.Module, mesh: Mesh) -> None:
 def _sharded_device_put(tensor: jax.Array, sharding) -> jax.Array:
     if isinstance(tensor, tuple):
         return tuple(_sharded_device_put(t, sharding) for t in tensor)
-    multihost_backend = envs.TPU_MULTIHOST_BACKEND
+    import os
+    multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
     if multihost_backend != "ray":
         return jax.device_put(tensor, sharding)
 
tpu_inference/lora/torch_punica_tpu.py
@@ -239,6 +239,7 @@ class PunicaWrapperTPU(PunicaWrapperBase):
         lora_index_to_id: list[Optional[int]],
         max_loras: int,
         vocab_size: int,
+        extra_vocab_size: int,
     ):
         # Pad the prompt mapping to avoid running into recompiles on the TPU
         # TODO: Should this happen inside mapping internally? If so how can we
@@ -257,7 +258,7 @@ class PunicaWrapperTPU(PunicaWrapperBase):
             lora_index_to_id,
             max_loras,
             vocab_size,
-            0,  # extra_vocab_size
+            extra_vocab_size,
             "cpu",
         )
         with torchax.default_env():
tpu_inference/mock/vllm_config_utils.py (new file)
@@ -0,0 +1,28 @@
+from dataclasses import dataclass, field
+from typing import Any, List, Mapping
+
+
+@dataclass
+class ModelConfig():
+    max_model_len: int = 2048
+    max_prefill_len: int = 1024
+    prefill_batch_size: int = 1
+    decode_batch_size: int = 1
+    block_size: int = 16
+    num_layers: int = 32
+    num_kv_heads: int = 32
+    head_dim: int = 128
+    vocab_size: int = 32000
+    model: str = "llama3"
+    hf_config: str = ""
+    architectures: List[str] = field(default_factory=list)
+    override_generation_config: dict[str, Any] = field(default_factory=dict)
+    hf_overrides: dict[str, Any] = field(default_factory=dict)
+
+
+@dataclass
+class VllmConfig():
+    additional_config: Mapping[str, Any] = field(default_factory=dict)
+    # Set default max_model_len to turn off warnings.
+    model_config: ModelConfig = field(
+        default_factory=lambda: ModelConfig(max_model_len=1024))
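
A hypothetical usage sketch of the new mock config dataclasses (assuming they are importable from tpu_inference.mock.vllm_config_utils, as the file path suggests), shown only to illustrate the defaults added above.

from tpu_inference.mock.vllm_config_utils import ModelConfig, VllmConfig

# VllmConfig defaults to a ModelConfig with max_model_len overridden to 1024.
cfg = VllmConfig()
assert cfg.model_config.max_model_len == 1024
assert cfg.model_config.block_size == 16

# Fields can be overridden per instance as with any dataclass.
custom = VllmConfig(model_config=ModelConfig(max_model_len=4096, num_layers=48))
assert custom.model_config.max_model_len == 4096
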