tpu-inference 0.11.1.dev202511150811__py3-none-any.whl → 0.11.1.dev202512030818__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of tpu-inference might be problematic.
- tests/kernels/fused_moe_v1_test.py +303 -34
- tests/lora/test_layers.py +0 -6
- tests/lora/utils.py +0 -8
- tests/test_envs.py +32 -11
- tests/test_utils.py +1 -2
- tpu_inference/__init__.py +22 -3
- tpu_inference/core/disagg_utils.py +6 -8
- tpu_inference/distributed/tpu_connector.py +3 -4
- tpu_inference/distributed/utils.py +3 -2
- tpu_inference/envs.py +61 -8
- tpu_inference/executors/ray_distributed_executor.py +31 -11
- tpu_inference/kernels/fused_moe/v1/kernel.py +641 -110
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +77 -54
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +213 -126
- tpu_inference/layers/common/attention_interface.py +7 -1
- tpu_inference/layers/common/sharding.py +5 -5
- tpu_inference/layers/vllm/fused_moe.py +74 -25
- tpu_inference/layers/vllm/quantization/common.py +6 -1
- tpu_inference/layers/vllm/quantization/mxfp4.py +137 -62
- tpu_inference/layers/vllm/quantization/unquantized.py +107 -113
- tpu_inference/layers/vllm/sharding.py +2 -2
- tpu_inference/lora/torch_punica_tpu.py +1 -2
- tpu_inference/models/common/model_loader.py +45 -11
- tpu_inference/models/jax/llama3.py +2 -1
- tpu_inference/models/jax/llama_eagle3.py +8 -5
- tpu_inference/models/jax/llama_guard_4.py +361 -0
- tpu_inference/models/jax/qwen2.py +2 -1
- tpu_inference/models/jax/qwen2_5_vl.py +163 -48
- tpu_inference/models/jax/qwen3.py +2 -1
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +3 -6
- tpu_inference/models/jax/utils/weight_utils.py +198 -143
- tpu_inference/models/vllm/vllm_model_wrapper.py +14 -7
- tpu_inference/platforms/tpu_platform.py +28 -22
- tpu_inference/runner/compilation_manager.py +144 -59
- tpu_inference/runner/kv_cache_manager.py +17 -18
- tpu_inference/runner/persistent_batch_manager.py +40 -2
- tpu_inference/runner/structured_decoding_manager.py +2 -3
- tpu_inference/runner/tpu_runner.py +271 -147
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/jax/eagle3.py +71 -21
- tpu_inference/tpu_info.py +4 -3
- tpu_inference/utils.py +36 -13
- tpu_inference/worker/tpu_worker.py +162 -25
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/METADATA +3 -2
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/RECORD +48 -53
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +0 -28
- tpu_inference/mock/vllm_envs.py +0 -1219
- tpu_inference/mock/vllm_logger.py +0 -212
- tpu_inference/mock/vllm_logging_utils.py +0 -15
- tpu_inference/models/jax/phi3.py +0 -376
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/top_level.txt +0 -0
tpu_inference/layers/common/sharding.py
@@ -1,6 +1,5 @@
 import json
 import math
-import os
 from dataclasses import asdict, dataclass
 from typing import TYPE_CHECKING, List, Optional
 
@@ -8,7 +7,7 @@ import jax.numpy as jnp
 import numpy as np
 from jax.sharding import Mesh
 
-from tpu_inference import utils
+from tpu_inference import envs, utils
 
 if TYPE_CHECKING:
     from vllm.v1.configs.vllm_config import VllmConfig
@@ -48,7 +47,7 @@ class ShardingAxisName2D:
 
 
 try:
-    _use_base_sharding =
+    _use_base_sharding = envs.NEW_MODEL_DESIGN
     if _use_base_sharding:
         ShardingAxisName = ShardingAxisNameBase
     else:
@@ -166,9 +165,10 @@ class ShardingConfigManager:
                 f"LoRA is not supported with data parallelism "
                 f"(DP size: {total_dp_size}). Please disable LoRA or "
                 f"set data parallelism to 1.")
-
+        if sharding_strategy.attention_data_parallelism > 1:
+            if not envs.NEW_MODEL_DESIGN:
                 raise ValueError(
-                    "Must run DP with NEW_MODEL_DESIGN enabled. Please set the "
+                    "Must run Attention DP with NEW_MODEL_DESIGN enabled. Please set the "
                     "NEW_MODEL_DESIGN=True.")
 
     @property
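Note: these sharding.py hunks swap direct environment lookups for typed flags on tpu_inference/envs.py (also touched in this release, +61 -8). As a rough, hypothetical sketch of how a boolean flag like NEW_MODEL_DESIGN is typically surfaced by such a module (the real envs.py may differ):

    import os

    def _bool_env(name: str, default: bool = False) -> bool:
        # Treat the usual truthy spellings as True; anything else is False.
        return os.environ.get(name, str(default)).lower() in ("1", "true")

    NEW_MODEL_DESIGN = _bool_env("NEW_MODEL_DESIGN")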
tpu_inference/layers/vllm/fused_moe.py
@@ -110,7 +110,8 @@ def tensor_sharded_gmm_merged_column_parallel(
     # adapted from https://github.com/pytorch/xla/blob/1d409399474197c484894be90b75d9855393dda5/torch_xla/experimental/custom_kernel.py#L1401
     m, k, g = lhs.shape[0], lhs.shape[1], rhs.shape[0]
     n = rhs.shape[1] if transpose_rhs else rhs.shape[2]
-    tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n,
+    tm, tk, tn = _get_tiling_size_for_gmm_kernel(m // mesh.shape["data"], k, n,
+                                                 g)
 
     _gmm = functools.partial(
         gmm,
@@ -123,14 +124,26 @@ def tensor_sharded_gmm_merged_column_parallel(
     gmm_result = shard_map(
         _gmm,
         mesh=mesh,
-        in_specs=(P(), P(None, "model", None), P()),
-        out_specs=(P(
+        in_specs=(P("data", None), P(None, "model", None), P("data")),
+        out_specs=(P("data", "model")),
         check_rep=False,
     )(lhs, rhs, group_sizes)
 
     if rhs_bias is not None:
-
-
+
+        def _add_bias(gmm_result_local, rhs_bias_local, group_sizes_global):
+            rhs_bis = jnp.repeat(rhs_bias_local,
+                                 group_sizes_global,
+                                 0,
+                                 total_repeat_length=m // mesh.shape["data"])
+            return (gmm_result_local + rhs_bis).astype(gmm_result_local.dtype)
+
+        gmm_result = shard_map(
+            _add_bias,
+            mesh=mesh,
+            in_specs=(P("data", "model"), P(None, "model"), P("data")),
+            out_specs=(P("data", "model")),
+        )(gmm_result, rhs_bias, group_sizes)
 
     n_shards = mesh.shape["model"]
     output_sizes = [intermediate_size, intermediate_size]
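Note: the new _add_bias helper relies on jnp.repeat to expand the per-expert bias so it lines up row-for-row with the expert-sorted gmm output. A toy, single-host illustration of just that alignment (shapes invented for the example, not the package's code):

    import jax.numpy as jnp

    g, n, m = 2, 4, 4                              # experts, features, rows
    bias = jnp.arange(g * n, dtype=jnp.float32).reshape(g, n)
    group_sizes = jnp.array([3, 1])                # rows owned by each expert
    gmm_out = jnp.zeros((m, n))                    # stand-in for the gmm result

    # Row i of `bias` is repeated group_sizes[i] times; total_repeat_length
    # pins the output shape so the op stays shape-static under jit.
    repeated = jnp.repeat(bias, group_sizes, 0, total_repeat_length=m)
    print((gmm_out + repeated)[2])                 # still expert 0's bias
    print((gmm_out + repeated)[3])                 # expert 1's bias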
@@ -150,7 +163,8 @@ def tensor_sharded_gmm_row_parallel(
     # adapted from https://github.com/pytorch/xla/blob/1d409399474197c484894be90b75d9855393dda5/torch_xla/experimental/custom_kernel.py#L1401
     m, k, g = lhs.shape[0], lhs.shape[1], rhs.shape[0]
     n = rhs.shape[1] if transpose_rhs else rhs.shape[2]
-    tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n,
+    tm, tk, tn = _get_tiling_size_for_gmm_kernel(m // mesh.shape["data"], k, n,
+                                                 g)
 
     _gmm = functools.partial(
         gmm,
@@ -167,14 +181,25 @@ def tensor_sharded_gmm_row_parallel(
     gmm_result = shard_map(
         _gmm_all_reduce,
         mesh=mesh,
-        in_specs=(P(
-        out_specs=(P()),
+        in_specs=(P("data", "model"), P(None, None, "model"), P("data")),
+        out_specs=(P("data")),
         check_rep=False,
     )(lhs, rhs, group_sizes)
-
     if rhs_bias is not None:
-
-
+
+        def _add_bias(gmm_result_local, rhs_bias_local, group_sizes_global):
+            rhs_bis = jnp.repeat(rhs_bias_local,
+                                 group_sizes_global,
+                                 0,
+                                 total_repeat_length=m // mesh.shape["data"])
+            return (gmm_result_local + rhs_bis).astype(gmm_result_local.dtype)
+
+        gmm_result = shard_map(
+            _add_bias,
+            mesh=mesh,
+            in_specs=(P("data"), P(), P("data")),
+            out_specs=(P("data")),
+        )(gmm_result, rhs_bias, group_sizes)
 
     return gmm_result
 
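Note: in the row-parallel path each "model" shard now also owns a "data" slice of the rows, while partial products over the sharded contraction dimension are still combined with an all-reduce inside shard_map. A minimal runnable sketch of that psum pattern on a one-axis mesh (toy shapes, not the package's _gmm_all_reduce):

    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax.experimental.shard_map import shard_map
    from jax.sharding import Mesh, PartitionSpec as P

    mesh = Mesh(np.array(jax.devices()[:1]), ("model",))
    lhs = jnp.ones((4, 8))             # contraction dim k=8, sharded on "model"
    rhs = jnp.ones((8, 3))

    def _mm_all_reduce(a, b):
        # Each shard multiplies its k-slice; psum sums partials across the axis.
        return jax.lax.psum(a @ b, "model")

    out = shard_map(_mm_all_reduce, mesh=mesh,
                    in_specs=(P(None, "model"), P("model", None)),
                    out_specs=P(None, None),
                    check_rep=False)(lhs, rhs)
    print(out[0, 0])                   # 8.0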
@@ -366,15 +391,27 @@ def fused_moe_func(
     topk_weights = topk_weights / topk_weights.sum(axis=-1, keepdims=True)
     topk_weights = topk_weights.astype(dtype)
 
-
-
-
-
-
-
-
-
-
+    def _process_tokens_locally(hidden_states_local, topk_indices_local):
+        num_tokens_local = hidden_states_local.shape[0]
+        topk_indices_flat = topk_indices_local.flatten()
+        topk_argsort_indices = jnp.argsort(topk_indices_flat)
+        topk_argsort_revert_indices = jnp.argsort(topk_argsort_indices)
+        token_indices = jnp.arange(num_tokens_local,
+                                   dtype=jnp.int32).repeat(topk)
+        token_indices_sorted = token_indices[topk_argsort_indices]
+        group_sizes_local = jnp.bincount(topk_indices_flat,
+                                         length=global_num_experts)
+
+        x = hidden_states_local[token_indices_sorted]
+        return x, group_sizes_local, topk_argsort_revert_indices
+
+    x, group_sizes, topk_argsort_revert_indices = shard_map(
+        _process_tokens_locally,
+        mesh=mesh,
+        in_specs=(P("data", None), P("data", None)),
+        out_specs=(P("data", None), P("data"), P("data")),
+        check_rep=False,
+    )(hidden_states, topk_indices)
     if use_ep:
         x = expert_sharded_gmm(
             x,
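Note: _process_tokens_locally is the classic sort-by-expert routing trick, now applied per "data" shard. Stripped of the shard_map wrapper, the core index bookkeeping looks like this (toy values, chosen for the example):

    import jax.numpy as jnp

    num_tokens, topk, num_experts, hidden = 4, 2, 3, 8
    hidden_states = jnp.arange(num_tokens * hidden,
                               dtype=jnp.float32).reshape(num_tokens, hidden)
    topk_indices = jnp.array([[0, 2], [1, 1], [2, 0], [0, 1]])

    flat = topk_indices.flatten()
    order = jnp.argsort(flat)                  # token slots grouped by expert id
    revert = jnp.argsort(order)                # inverse permutation for later
    token_ids = jnp.arange(num_tokens, dtype=jnp.int32).repeat(topk)
    x = hidden_states[token_ids[order]]        # expert-contiguous rows for the gmm
    group_sizes = jnp.bincount(flat, length=num_experts)
    print(group_sizes)                         # [3 3 2]
    assert (x[revert] == hidden_states[token_ids]).all()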
@@ -411,7 +448,7 @@
         )
     else:
         x = jax.lax.with_sharding_constraint(
-            x, NamedSharding(mesh, P(
+            x, NamedSharding(mesh, P("data", "model")))
         x = tensor_sharded_gmm_row_parallel(
             x,
             w2,
@@ -421,13 +458,25 @@
         mesh=mesh,
     )
 
-
-
-
+    def _finalize_output(x_local, topk_argsort_revert_indices_local,
+                         topk_weights_local):
+        x_local = x_local[topk_argsort_revert_indices_local].reshape(
+            -1, topk, hidden_size)
+        x_local = x_local * jnp.expand_dims(topk_weights_local, axis=-1)
+        x_local = x_local.sum(axis=-2)
+        return x_local
+
+    x = shard_map(
+        _finalize_output,
+        mesh=mesh,
+        in_specs=(P("data", None), P("data"), P("data", None)),
+        out_specs=(P("data", None)),
+        check_rep=False,
+    )(x, topk_argsort_revert_indices, topk_weights)
     x = x.reshape(orig_shape)
 
     if reduce_results:
-        x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P()))
+        x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P("data")))
     return x
 
 
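Note: _finalize_output then undoes that sort and folds the top-k expert outputs back into one row per token. The same three steps in isolation (identity permutation and uniform weights, just to show the shapes):

    import jax.numpy as jnp

    num_tokens, topk, hidden = 4, 2, 8
    x = jnp.ones((num_tokens * topk, hidden))      # expert outputs, expert-sorted
    revert = jnp.arange(num_tokens * topk)         # inverse permutation (identity here)
    topk_weights = jnp.full((num_tokens, topk), 0.5)

    out = x[revert].reshape(-1, topk, hidden)      # regroup the top-k rows per token
    out = out * jnp.expand_dims(topk_weights, -1)  # weight each expert's contribution
    out = out.sum(axis=-2)                         # combine: (num_tokens, hidden)
    print(out.shape, float(out[0, 0]))             # (4, 8) 1.0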
tpu_inference/layers/vllm/quantization/common.py
@@ -61,7 +61,12 @@ class JaxCommonLinearConfig:
                 " bad performance.", type(layer))
 
         self.bias_sharding = P(self.weight_sharding[0])
-
+        if isinstance(self.weight_sharding[0], tuple):
+            self.n_shards = 1
+            for axis in self.weight_sharding[0]:
+                self.n_shards *= self.mesh.shape.get(axis, 1)
+        else:
+            self.n_shards = self.mesh.shape.get(self.weight_sharding[0], 1)
 
     def get_input_sharding(self, x: torchax.tensor.Tensor):
         if self.enable_sequence_parallelism:
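Note: the new n_shards logic accounts for PartitionSpec entries that are tuples of mesh axes; the shard count is then the product of the axis sizes. A standalone sketch of the same computation (mesh shape chosen for the example):

    import numpy as np
    import jax
    from jax.sharding import Mesh

    mesh = Mesh(np.array(jax.devices()[:1]).reshape(1, 1), ("data", "model"))

    def shard_count(spec_entry):
        # A spec entry is a mesh axis name, a tuple of axis names, or None.
        if isinstance(spec_entry, tuple):
            n = 1
            for axis in spec_entry:
                n *= mesh.shape.get(axis, 1)
            return n
        return mesh.shape.get(spec_entry, 1)

    print(shard_count("model"))                    # size of the "model" axis
    print(shard_count(("data", "model")))          # product of both axis sizes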
tpu_inference/layers/vllm/quantization/mxfp4.py
@@ -24,6 +24,8 @@ from vllm.model_executor.layers.quantization.mxfp4 import (Mxfp4Backend,
 from vllm.model_executor.layers.quantization.utils.quant_utils import \
     is_layer_skipped
 
+from tpu_inference import envs
+from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe
 from tpu_inference.layers.common.quant_methods import (MXFP4,
                                                        get_tpu_quant_method)
 from tpu_inference.layers.vllm.fused_moe import fused_moe_func_padded
@@ -93,7 +95,8 @@ class VllmMxfp4Config(Mxfp4Config, JaxCommonConfig):
                 "UnquantizedLinearMethod.")
             return VllmUnquantizedLinearMethod(linear_config)
         elif isinstance(layer, FusedMoE):
-
+            moe_config = self.get_moe_config(layer)
+            return VllmMxfp4MoEMethod(moe_config, self.mesh)
         elif isinstance(layer, Attention):
             # TODO: Add support for MXFP4 Attention.
             logger.warning_once("MXFP4 attention layer is not implemented. "
@@ -103,13 +106,30 @@ class VllmMxfp4Config(Mxfp4Config, JaxCommonConfig):
 
 class VllmMxfp4MoEMethod(Mxfp4MoEMethod):
 
-    def __init__(self,
+    def __init__(self,
+                 moe: FusedMoEConfig,
+                 mesh: Mesh,
+                 ep_axis_name: str = 'model'):
         FusedMoEMethodBase.__init__(self, moe)
 
         # We piggyback on triton implementation as it applies minimal hardware
         # specific post processing to the weights.
         self.mxfp4_backend = Mxfp4Backend.TRITON
+
         self.mesh = mesh
+        self.use_kernel = envs.USE_MOE_EP_KERNEL
+        self.ep_axis_name = ep_axis_name
+        # TODO: Use autotune table once we have it.
+        self.block_size = {
+            "bt": 64,
+            "bf": 1024,
+            "bd1": 1536,
+            "bd2": 1536,
+            "btc": 64,
+            "bfc": 1024,
+            "bd1c": 1536,
+            "bd2c": 1536,
+        }
 
     def get_fused_moe_quant_config(
             self, layer: torch.nn.Module) -> FusedMoEQuantConfig | None:
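Note: the block_size dict is later splatted into the kernel call as keyword tuning parameters (see the fused_ep_moe call in the last hunk below). The pattern in miniature, with a hypothetical kernel signature standing in for fused_ep_moe:

    def fake_kernel(tokens, *, bt, bf, bd1, bd2, btc, bfc, bd1c, bd2c):
        # Stand-in for the kernel's tuning knobs; names are illustrative only.
        return bt + bf

    block_size = {"bt": 64, "bf": 1024, "bd1": 1536, "bd2": 1536,
                  "btc": 64, "bfc": 1024, "bd1c": 1536, "bd2c": 1536}
    print(fake_kernel("x", **block_size))          # 1088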
@@ -122,6 +142,7 @@ class VllmMxfp4MoEMethod(Mxfp4MoEMethod):
 
     def process_weights_after_loading(self, layer: torch.nn.Module):
         assert isinstance(layer, FusedMoE)
+        assert layer.moe_config.has_bias, "mxfp4 quantization always uses bias."
 
         w13_weight = u8_unpack_e2m1(t2j(layer.w13_weight, use_dlpack=False))
         w13_weight_scale = e8m0_to_fp32(
@@ -157,57 +178,95 @@ class VllmMxfp4MoEMethod(Mxfp4MoEMethod):
         w3_bias = w13_bias[:, 1::2]
         w13_bias = jnp.concat([w1_bias, w3_bias], axis=1)
 
-
-
+        if self.use_kernel and layer.use_ep:
+            # Kernel expects:
+            #   w13: (num_experts, 2, hidden_size, intermediate_size)
+            #   w2: (num_experts, intermediate_size, hidden_size)
+            # Current format:
+            #   w13_weight: (num_experts, 2*intermediate_size, hidden_size)
+            #   w2_weight: (num_experts, hidden_size, intermediate_size)
+            num_experts = w13_weight.shape[0]
+            intermediate_size = w13_weight.shape[1] // 2
+            hidden_size = w13_weight.shape[2]
+
+            # Reshape and transpose w13_weight to (num_experts, 2, hidden_size, intermediate_size)
+            w13_reshaped = w13_weight.reshape(num_experts, 2,
+                                              intermediate_size, hidden_size)
+            w13_weight_transposed = jnp.transpose(w13_reshaped, (0, 1, 3, 2))
+
+            # Transpose w2_weight to (num_experts, intermediate_size, hidden_size)
+            w2_weight_transposed = jnp.transpose(w2_weight, (0, 2, 1))
+
+            # Apply EP sharding
             w13_weight = jax.device_put(
-
-                Format(Layout((0, 1, 2)),
-                       NamedSharding(self.mesh, P("model", None, None))))
+                w13_weight_transposed,
+                Format(Layout((0, 1, 2, 3)),
+                       NamedSharding(self.mesh, P("model", None, None, None))))
             w2_weight = jax.device_put(
-
+                w2_weight_transposed,
                 Format(Layout((0, 1, 2)),
                        NamedSharding(self.mesh, P("model", None, None))))
 
-
-                w13_bias,
-
-
-
-
-
-
+            if self.moe.has_bias:
+                w13_bias = w13_bias.reshape(num_experts, 2, intermediate_size)
+
+                # Apply EP sharding
+                w13_bias = jax.device_put(
+                    w13_bias,
+                    Format(Layout((0, 1, 2)),
+                           NamedSharding(self.mesh, P("model", None, None))))
+                w2_bias = jax.device_put(
+                    w2_bias,
+                    Format(Layout((0, 1)),
+                           NamedSharding(self.mesh, P("model", None))))
 
         else:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            if layer.use_ep:
+                w13_weight = jax.device_put(
+                    w13_weight,
+                    Format(Layout((0, 1, 2)),
+                           NamedSharding(self.mesh, P("model", None, None))))
+                w2_weight = jax.device_put(
+                    w2_weight,
+                    Format(Layout((0, 1, 2)),
+                           NamedSharding(self.mesh, P("model", None, None))))
+
+                w13_bias = jax.device_put(
+                    w13_bias,
+                    Format(Layout((0, 1)),
+                           NamedSharding(self.mesh, P("model", None))))
+                w2_bias = jax.device_put(
+                    w2_bias,
+                    Format(Layout((0, 1)),
+                           NamedSharding(self.mesh, P("model", None))))
+
+            else:
+                intermediate_size = w13_weight.shape[1] // 2
+                assert intermediate_size == w2_weight.shape[-1]
+                output_sizes = [intermediate_size, intermediate_size]
+                n_shards = self.mesh.shape["model"]
+                assert intermediate_size % n_shards == 0
+                w13_weight = reorder_concatenated_tensor_for_sharding(
+                    w13_weight, output_sizes, n_shards, dim=1)
+                w13_weight = jax.device_put(
+                    w13_weight,
+                    Format(Layout((0, 1, 2)),
+                           NamedSharding(self.mesh, P(None, "model", None))))
+                w2_weight = jax.device_put(
+                    w2_weight,
+                    Format(Layout((0, 1, 2)),
+                           NamedSharding(self.mesh, P(None, None, "model"))))
+
+                w13_bias = reorder_concatenated_tensor_for_sharding(
+                    w13_bias, output_sizes, n_shards, dim=1)
+                w13_bias = jax.device_put(
+                    w13_bias,
+                    Format(Layout((0, 1)),
+                           NamedSharding(self.mesh, P(None, "model"))))
+                w2_bias = jax.device_put(
+                    w2_bias,
+                    Format(Layout((0, 1)),
+                           NamedSharding(self.mesh, P(None, None))))
 
         layer.w13_weight = Parameter(torch_view(w13_weight),
                                      requires_grad=False)
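Note: every weight-placement branch above follows one pattern: jax.device_put with a NamedSharding that shards the leading num_experts dimension across "model" for EP, plus a Format(Layout(...), ...) wrapper to pin the on-device layout. A minimal sketch with the layout wrapper omitted (toy shapes, not the package's call):

    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    mesh = Mesh(np.array(jax.devices()[:1]), ("model",))
    w13 = jnp.zeros((8, 2, 128, 64))   # (num_experts, 2, hidden, intermediate)

    # Shard experts across "model"; all other dims stay replicated per device.
    w13 = jax.device_put(w13, NamedSharding(mesh, P("model", None, None, None)))
    print(w13.sharding.spec)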
@@ -246,21 +305,37 @@ class VllmMxfp4MoEMethod(Mxfp4MoEMethod):
             raise NotImplementedError(
                 "Only softmax is supported for scoring_func")
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if self.use_kernel and layer.use_ep:
+            output = fused_ep_moe(
+                mesh=self.mesh,
+                tokens=jax_view(x),
+                w1=jax_view(layer.w13_weight),
+                w2=jax_view(layer.w2_weight),
+                b1=jax_view(layer.w13_bias),
+                b2=jax_view(layer.w2_bias),
+                gating_output=jax_view(router_logits),
+                top_k=top_k,
+                ep_axis_name=self.ep_axis_name,
+                renormalize_topk_logits=renormalize,
+                act_fn=activation,
+                **self.block_size,
+            )
+        else:
+            # Use the original implementation
+            output = fused_moe_func_padded(
+                jax_view(x),
+                jax_view(layer.w13_weight),
+                jax_view(layer.w2_weight),
+                jax_view(layer.w13_bias),
+                jax_view(layer.w2_bias),
+                jax_view(router_logits),
+                topk=top_k,
+                global_num_experts=global_num_experts,
+                renormalize=renormalize,
+                reduce_results=layer.reduce_results,
+                mesh=self.mesh,
+                use_ep=layer.use_ep,
+                activation=activation,
+            )
 
         return torch_view(output)