tpu-inference 0.11.1.dev202511150811__py3-none-any.whl → 0.11.1.dev202511270815__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/kernels/fused_moe_v1_test.py +303 -34
- tests/lora/test_layers.py +0 -6
- tests/lora/utils.py +0 -8
- tpu_inference/__init__.py +22 -3
- tpu_inference/core/disagg_utils.py +6 -8
- tpu_inference/distributed/tpu_connector.py +2 -3
- tpu_inference/distributed/utils.py +3 -2
- tpu_inference/envs.py +1 -1
- tpu_inference/executors/ray_distributed_executor.py +27 -11
- tpu_inference/kernels/fused_moe/v1/kernel.py +641 -110
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +77 -54
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +141 -107
- tpu_inference/layers/common/attention_interface.py +7 -1
- tpu_inference/layers/common/sharding.py +2 -1
- tpu_inference/layers/vllm/fused_moe.py +74 -25
- tpu_inference/layers/vllm/quantization/common.py +6 -1
- tpu_inference/layers/vllm/quantization/mxfp4.py +135 -61
- tpu_inference/layers/vllm/quantization/unquantized.py +107 -113
- tpu_inference/layers/vllm/sharding.py +2 -2
- tpu_inference/lora/torch_punica_tpu.py +1 -2
- tpu_inference/models/common/model_loader.py +43 -11
- tpu_inference/models/jax/llama3.py +2 -1
- tpu_inference/models/jax/llama_eagle3.py +8 -5
- tpu_inference/models/jax/llama_guard_4.py +361 -0
- tpu_inference/models/jax/qwen2.py +2 -1
- tpu_inference/models/jax/qwen2_5_vl.py +163 -48
- tpu_inference/models/jax/qwen3.py +2 -1
- tpu_inference/models/jax/utils/weight_utils.py +198 -143
- tpu_inference/models/vllm/vllm_model_wrapper.py +13 -5
- tpu_inference/platforms/tpu_platform.py +15 -2
- tpu_inference/runner/compilation_manager.py +58 -33
- tpu_inference/runner/kv_cache_manager.py +9 -3
- tpu_inference/runner/structured_decoding_manager.py +2 -3
- tpu_inference/runner/tpu_runner.py +203 -102
- tpu_inference/spec_decode/jax/eagle3.py +19 -2
- tpu_inference/tpu_info.py +4 -3
- tpu_inference/utils.py +5 -4
- tpu_inference/worker/tpu_worker.py +160 -23
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202511270815.dist-info}/METADATA +3 -2
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202511270815.dist-info}/RECORD +43 -48
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +0 -28
- tpu_inference/mock/vllm_envs.py +0 -1219
- tpu_inference/mock/vllm_logger.py +0 -212
- tpu_inference/mock/vllm_logging_utils.py +0 -15
- tpu_inference/models/jax/phi3.py +0 -376
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202511270815.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202511270815.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202511270815.dist-info}/top_level.txt +0 -0
@@ -110,7 +110,8 @@ def tensor_sharded_gmm_merged_column_parallel(
     # adapted from https://github.com/pytorch/xla/blob/1d409399474197c484894be90b75d9855393dda5/torch_xla/experimental/custom_kernel.py#L1401
     m, k, g = lhs.shape[0], lhs.shape[1], rhs.shape[0]
     n = rhs.shape[1] if transpose_rhs else rhs.shape[2]
-    tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
+    tm, tk, tn = _get_tiling_size_for_gmm_kernel(m // mesh.shape["data"], k, n,
+                                                 g)
 
     _gmm = functools.partial(
         gmm,
@@ -123,14 +124,26 @@ def tensor_sharded_gmm_merged_column_parallel(
     gmm_result = shard_map(
         _gmm,
         mesh=mesh,
-        in_specs=(P(), P(None, "model", None), P()),
-        out_specs=(P(
+        in_specs=(P("data", None), P(None, "model", None), P("data")),
+        out_specs=(P("data", "model")),
         check_rep=False,
     )(lhs, rhs, group_sizes)
 
     if rhs_bias is not None:
-
-
+
+        def _add_bias(gmm_result_local, rhs_bias_local, group_sizes_global):
+            rhs_bis = jnp.repeat(rhs_bias_local,
+                                 group_sizes_global,
+                                 0,
+                                 total_repeat_length=m // mesh.shape["data"])
+            return (gmm_result_local + rhs_bis).astype(gmm_result_local.dtype)
+
+        gmm_result = shard_map(
+            _add_bias,
+            mesh=mesh,
+            in_specs=(P("data", "model"), P(None, "model"), P("data")),
+            out_specs=(P("data", "model")),
+        )(gmm_result, rhs_bias, group_sizes)
 
     n_shards = mesh.shape["model"]
     output_sizes = [intermediate_size, intermediate_size]
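The _add_bias helpers above apply the per-expert bias after the grouped GEMM, whose output rows are already sorted by expert. A minimal standalone sketch of that trick (illustrative only, not package code; all shapes and values are made up, and the per-shard shard_map wrapping is omitted):

```python
import jax.numpy as jnp

num_experts, hidden, tokens = 4, 8, 16
gmm_result = jnp.zeros((tokens, hidden))           # grouped-GEMM output, rows sorted by expert
rhs_bias = jnp.ones((num_experts, hidden))         # one bias row per expert
group_sizes = jnp.array([5, 3, 6, 2], jnp.int32)   # tokens routed to each expert (sums to `tokens`)

# Expand the per-expert bias to one row per routed token; total_repeat_length keeps
# the output shape static so the op stays jit-compatible.
expanded = jnp.repeat(rhs_bias, group_sizes, 0, total_repeat_length=tokens)
out = (gmm_result + expanded).astype(gmm_result.dtype)
assert out.shape == (tokens, hidden)
```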
@@ -150,7 +163,8 @@ def tensor_sharded_gmm_row_parallel(
     # adapted from https://github.com/pytorch/xla/blob/1d409399474197c484894be90b75d9855393dda5/torch_xla/experimental/custom_kernel.py#L1401
     m, k, g = lhs.shape[0], lhs.shape[1], rhs.shape[0]
     n = rhs.shape[1] if transpose_rhs else rhs.shape[2]
-    tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
+    tm, tk, tn = _get_tiling_size_for_gmm_kernel(m // mesh.shape["data"], k, n,
+                                                 g)
 
     _gmm = functools.partial(
         gmm,
@@ -167,14 +181,25 @@ def tensor_sharded_gmm_row_parallel(
     gmm_result = shard_map(
         _gmm_all_reduce,
         mesh=mesh,
-        in_specs=(P(
-        out_specs=(P()),
+        in_specs=(P("data", "model"), P(None, None, "model"), P("data")),
+        out_specs=(P("data")),
         check_rep=False,
     )(lhs, rhs, group_sizes)
-
     if rhs_bias is not None:
-
-
+
+        def _add_bias(gmm_result_local, rhs_bias_local, group_sizes_global):
+            rhs_bis = jnp.repeat(rhs_bias_local,
+                                 group_sizes_global,
+                                 0,
+                                 total_repeat_length=m // mesh.shape["data"])
+            return (gmm_result_local + rhs_bis).astype(gmm_result_local.dtype)
+
+        gmm_result = shard_map(
+            _add_bias,
+            mesh=mesh,
+            in_specs=(P("data"), P(), P("data")),
+            out_specs=(P("data")),
+        )(gmm_result, rhs_bias, group_sizes)
 
     return gmm_result
 
@@ -366,15 +391,27 @@ def fused_moe_func(
     topk_weights = topk_weights / topk_weights.sum(axis=-1, keepdims=True)
     topk_weights = topk_weights.astype(dtype)
 
- … (9 removed lines not rendered in this diff view)
+    def _process_tokens_locally(hidden_states_local, topk_indices_local):
+        num_tokens_local = hidden_states_local.shape[0]
+        topk_indices_flat = topk_indices_local.flatten()
+        topk_argsort_indices = jnp.argsort(topk_indices_flat)
+        topk_argsort_revert_indices = jnp.argsort(topk_argsort_indices)
+        token_indices = jnp.arange(num_tokens_local,
+                                   dtype=jnp.int32).repeat(topk)
+        token_indices_sorted = token_indices[topk_argsort_indices]
+        group_sizes_local = jnp.bincount(topk_indices_flat,
+                                         length=global_num_experts)
+
+        x = hidden_states_local[token_indices_sorted]
+        return x, group_sizes_local, topk_argsort_revert_indices
+
+    x, group_sizes, topk_argsort_revert_indices = shard_map(
+        _process_tokens_locally,
+        mesh=mesh,
+        in_specs=(P("data", None), P("data", None)),
+        out_specs=(P("data", None), P("data"), P("data")),
+        check_rep=False,
+    )(hidden_states, topk_indices)
     if use_ep:
         x = expert_sharded_gmm(
             x,
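The _process_tokens_locally helper above duplicates each token top-k times and sorts the copies by expert id so the grouped GEMM sees expert-contiguous rows, keeping an inverse permutation to restore the original order later. A standalone sketch of the same routing trick without the shard_map wrapping (illustrative only; shapes and the expert assignment are made up):

```python
import jax.numpy as jnp

num_tokens, topk, num_experts, hidden = 4, 2, 8, 16
hidden_states = jnp.ones((num_tokens, hidden))
topk_indices = jnp.array([[1, 3], [0, 3], [7, 1], [3, 0]])   # expert id per (token, slot)

topk_indices_flat = topk_indices.flatten()
sort_idx = jnp.argsort(topk_indices_flat)          # order duplicated tokens by expert id
revert_idx = jnp.argsort(sort_idx)                 # inverse permutation, used after the GEMMs
token_indices = jnp.arange(num_tokens, dtype=jnp.int32).repeat(topk)
x = hidden_states[token_indices[sort_idx]]         # (num_tokens * topk, hidden), expert-contiguous
group_sizes = jnp.bincount(topk_indices_flat, length=num_experts)

assert x.shape == (num_tokens * topk, hidden)
assert int(group_sizes.sum()) == num_tokens * topk
```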
@@ -411,7 +448,7 @@ def fused_moe_func(
         )
     else:
         x = jax.lax.with_sharding_constraint(
-            x, NamedSharding(mesh, P(
+            x, NamedSharding(mesh, P("data", "model")))
         x = tensor_sharded_gmm_row_parallel(
             x,
             w2,
@@ -421,13 +458,25 @@ def fused_moe_func(
             mesh=mesh,
         )
 
-
-
-
+    def _finalize_output(x_local, topk_argsort_revert_indices_local,
+                         topk_weights_local):
+        x_local = x_local[topk_argsort_revert_indices_local].reshape(
+            -1, topk, hidden_size)
+        x_local = x_local * jnp.expand_dims(topk_weights_local, axis=-1)
+        x_local = x_local.sum(axis=-2)
+        return x_local
+
+    x = shard_map(
+        _finalize_output,
+        mesh=mesh,
+        in_specs=(P("data", None), P("data"), P("data", None)),
+        out_specs=(P("data", None)),
+        check_rep=False,
+    )(x, topk_argsort_revert_indices, topk_weights)
     x = x.reshape(orig_shape)
 
     if reduce_results:
-        x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P()))
+        x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P("data")))
     return x
 
 
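_finalize_output above undoes the expert-contiguous ordering and combines the top-k expert outputs for each token using the router weights. A matching sketch (illustrative only, continuing the made-up shapes from the routing sketch; the shard_map wrapping is again omitted):

```python
import jax.numpy as jnp

num_tokens, topk, hidden = 4, 2, 16
x = jnp.ones((num_tokens * topk, hidden))          # per-(token, slot) expert outputs, expert-contiguous
revert_idx = jnp.arange(num_tokens * topk)         # inverse permutation from the routing step
topk_weights = jnp.full((num_tokens, topk), 0.5)

x = x[revert_idx].reshape(num_tokens, topk, hidden)            # restore original token order
x = (x * jnp.expand_dims(topk_weights, axis=-1)).sum(axis=-2)  # weighted sum over the top-k experts
assert x.shape == (num_tokens, hidden)
```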
@@ -61,7 +61,12 @@ class JaxCommonLinearConfig:
             " bad performance.", type(layer))
 
         self.bias_sharding = P(self.weight_sharding[0])
-
+        if isinstance(self.weight_sharding[0], tuple):
+            self.n_shards = 1
+            for axis in self.weight_sharding[0]:
+                self.n_shards *= self.mesh.shape.get(axis, 1)
+        else:
+            self.n_shards = self.mesh.shape.get(self.weight_sharding[0], 1)
 
     def get_input_sharding(self, x: torchax.tensor.Tensor):
         if self.enable_sequence_parallelism:
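The new n_shards logic above handles PartitionSpec entries that name either a single mesh axis or a tuple of axes. A small sketch of the same computation (illustrative only; the dict stands in for Mesh.shape, which maps axis names to sizes):

```python
# `mesh_shape` stands in for jax.sharding.Mesh.shape (axis name -> size); sizes are made up.
mesh_shape = {"data": 2, "model": 4}

def n_shards_for(spec_entry):
    # A PartitionSpec entry may be None, a single axis name, or a tuple of axis names.
    if isinstance(spec_entry, tuple):
        n = 1
        for axis in spec_entry:
            n *= mesh_shape.get(axis, 1)
        return n
    return mesh_shape.get(spec_entry, 1)

assert n_shards_for("model") == 4
assert n_shards_for(("data", "model")) == 8
assert n_shards_for(None) == 1
```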
@@ -24,6 +24,8 @@ from vllm.model_executor.layers.quantization.mxfp4 import (Mxfp4Backend,
 from vllm.model_executor.layers.quantization.utils.quant_utils import \
     is_layer_skipped
 
+from tpu_inference import envs
+from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe
 from tpu_inference.layers.common.quant_methods import (MXFP4,
                                                         get_tpu_quant_method)
 from tpu_inference.layers.vllm.fused_moe import fused_moe_func_padded
@@ -103,13 +105,30 @@ class VllmMxfp4Config(Mxfp4Config, JaxCommonConfig):
 
 class VllmMxfp4MoEMethod(Mxfp4MoEMethod):
 
-    def __init__(self,
+    def __init__(self,
+                 moe: FusedMoEConfig,
+                 mesh: Mesh,
+                 ep_axis_name: str = 'model'):
         FusedMoEMethodBase.__init__(self, moe)
 
         # We piggyback on triton implementation as it applies minimal hardware
         # specific post processing to the weights.
         self.mxfp4_backend = Mxfp4Backend.TRITON
+
         self.mesh = mesh
+        self.use_kernel = envs.USE_MOE_EP_KERNEL
+        self.ep_axis_name = ep_axis_name
+        # TODO: Use autotune table once we have it.
+        self.block_size = {
+            "bt": 64,
+            "bf": 1024,
+            "bd1": 1536,
+            "bd2": 1536,
+            "btc": 64,
+            "bfc": 1024,
+            "bd1c": 1536,
+            "bd2c": 1536,
+        }
 
     def get_fused_moe_quant_config(
             self, layer: torch.nn.Module) -> FusedMoEQuantConfig | None:
@@ -122,6 +141,7 @@ class VllmMxfp4MoEMethod(Mxfp4MoEMethod):
 
     def process_weights_after_loading(self, layer: torch.nn.Module):
         assert isinstance(layer, FusedMoE)
+        assert layer.moe_config.has_bias, "mxfp4 quantization alwyas use bias."
 
         w13_weight = u8_unpack_e2m1(t2j(layer.w13_weight, use_dlpack=False))
         w13_weight_scale = e8m0_to_fp32(
@@ -157,57 +177,95 @@ class VllmMxfp4MoEMethod(Mxfp4MoEMethod):
         w3_bias = w13_bias[:, 1::2]
         w13_bias = jnp.concat([w1_bias, w3_bias], axis=1)
 
-
-
+        if self.use_kernel and layer.use_ep:
+            # Kernel expects:
+            # w13: (num_experts, 2, hidden_size, intermediate_size)
+            # w2: (num_experts, intermediate_size, hidden_size)
+            # Current format:
+            # w13_weight: (num_experts, 2*intermediate_size, hidden_size)
+            # w2_weight: (num_experts, hidden_size, intermediate_size)
+            num_experts = w13_weight.shape[0]
+            intermediate_size = w13_weight.shape[1] // 2
+            hidden_size = w13_weight.shape[2]
+
+            # Reshape and transpose w13_weight to (num_experts, 2, hidden_size, intermediate_size)
+            w13_reshaped = w13_weight.reshape(num_experts, 2,
+                                              intermediate_size, hidden_size)
+            w13_weight_transposed = jnp.transpose(w13_reshaped, (0, 1, 3, 2))
+
+            # Transpose w2_weight to (num_experts, intermediate_size, hidden_size)
+            w2_weight_transposed = jnp.transpose(w2_weight, (0, 2, 1))
+
+            # Apply EP sharding
             w13_weight = jax.device_put(
-
-                Format(Layout((0, 1, 2)),
-                       NamedSharding(self.mesh, P("model", None, None))))
+                w13_weight_transposed,
+                Format(Layout((0, 1, 2, 3)),
+                       NamedSharding(self.mesh, P("model", None, None, None))))
             w2_weight = jax.device_put(
-
+                w2_weight_transposed,
                 Format(Layout((0, 1, 2)),
                        NamedSharding(self.mesh, P("model", None, None))))
 
-
-                w13_bias,
-
-
-
-
-
-
+            if self.moe.has_bias:
+                w13_bias = w13_bias.reshape(num_experts, 2, intermediate_size)
+
+                # Apply EP sharding
+                w13_bias = jax.device_put(
+                    w13_bias,
+                    Format(Layout((0, 1, 2)),
+                           NamedSharding(self.mesh, P("model", None, None))))
+                w2_bias = jax.device_put(
+                    w2_bias,
+                    Format(Layout((0, 1)),
+                           NamedSharding(self.mesh, P("model", None))))
 
         else:
- … (30 removed lines not rendered in this diff view)
+            if layer.use_ep:
+                w13_weight = jax.device_put(
+                    w13_weight,
+                    Format(Layout((0, 1, 2)),
+                           NamedSharding(self.mesh, P("model", None, None))))
+                w2_weight = jax.device_put(
+                    w2_weight,
+                    Format(Layout((0, 1, 2)),
+                           NamedSharding(self.mesh, P("model", None, None))))
+
+                w13_bias = jax.device_put(
+                    w13_bias,
+                    Format(Layout((0, 1)),
+                           NamedSharding(self.mesh, P("model", None))))
+                w2_bias = jax.device_put(
+                    w2_bias,
+                    Format(Layout((0, 1)),
+                           NamedSharding(self.mesh, P("model", None))))
+
+            else:
+                intermediate_size = w13_weight.shape[1] // 2
+                assert intermediate_size == w2_weight.shape[-1]
+                output_sizes = [intermediate_size, intermediate_size]
+                n_shards = self.mesh.shape["model"]
+                assert intermediate_size % n_shards == 0
+                w13_weight = reorder_concatenated_tensor_for_sharding(
+                    w13_weight, output_sizes, n_shards, dim=1)
+                w13_weight = jax.device_put(
+                    w13_weight,
+                    Format(Layout((0, 1, 2)),
+                           NamedSharding(self.mesh, P(None, "model", None))))
+                w2_weight = jax.device_put(
+                    w2_weight,
+                    Format(Layout((0, 1, 2)),
+                           NamedSharding(self.mesh, P(None, None, "model"))))
+
+                w13_bias = reorder_concatenated_tensor_for_sharding(
+                    w13_bias, output_sizes, n_shards, dim=1)
+                w13_bias = jax.device_put(
+                    w13_bias,
+                    Format(Layout((0, 1)),
+                           NamedSharding(self.mesh, P(None, "model"))))
+                w2_bias = jax.device_put(
+                    w2_bias,
+                    Format(Layout((0, 1)),
+                           NamedSharding(self.mesh, P(None, None))))
 
         layer.w13_weight = Parameter(torch_view(w13_weight),
                                      requires_grad=False)
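The kernel-path branch above repacks w13 from the loader layout (num_experts, 2*intermediate_size, hidden_size) into the layout the fused EP kernel expects, (num_experts, 2, hidden_size, intermediate_size). A sketch of just that reshaping (illustrative only, with made-up sizes):

```python
import jax.numpy as jnp

num_experts, intermediate_size, hidden_size = 4, 6, 8
w13 = jnp.zeros((num_experts, 2 * intermediate_size, hidden_size))   # loader layout

w13 = w13.reshape(num_experts, 2, intermediate_size, hidden_size)    # split the fused w1/w3 dim
w13 = jnp.transpose(w13, (0, 1, 3, 2))                               # kernel layout
assert w13.shape == (num_experts, 2, hidden_size, intermediate_size)
```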
@@ -246,21 +304,37 @@ class VllmMxfp4MoEMethod(Mxfp4MoEMethod):
             raise NotImplementedError(
                 "Only softmax is supported for scoring_func")
 
- … (16 removed lines not rendered in this diff view)
+        if self.use_kernel and layer.use_ep:
+            output = fused_ep_moe(
+                mesh=self.mesh,
+                tokens=jax_view(x),
+                w1=jax_view(layer.w13_weight),
+                w2=jax_view(layer.w2_weight),
+                b1=jax_view(layer.w13_bias),
+                b2=jax_view(layer.w2_bias),
+                gating_output=jax_view(router_logits),
+                top_k=top_k,
+                ep_axis_name=self.ep_axis_name,
+                renormalize_topk_logits=renormalize,
+                act_fn=activation,
+                **self.block_size,
+            )
+        else:
+            # Use the original implementation
+            output = fused_moe_func_padded(
+                jax_view(x),
+                jax_view(layer.w13_weight),
+                jax_view(layer.w2_weight),
+                jax_view(layer.w13_bias),
+                jax_view(layer.w2_bias),
+                jax_view(router_logits),
+                topk=top_k,
+                global_num_experts=global_num_experts,
+                renormalize=renormalize,
+                reduce_results=layer.reduce_results,
+                mesh=self.mesh,
+                use_ep=layer.use_ep,
+                activation=activation,
+            )
 
         return torch_view(output)