sglang 0.4.9.post4__py3-none-any.whl → 0.4.9.post6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/lang/chat_template.py +21 -0
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/model_config.py +7 -0
- sglang/srt/constrained/base_grammar_backend.py +10 -2
- sglang/srt/constrained/xgrammar_backend.py +7 -5
- sglang/srt/conversation.py +16 -1
- sglang/srt/debug_utils/__init__.py +0 -0
- sglang/srt/debug_utils/dump_comparator.py +131 -0
- sglang/srt/debug_utils/dumper.py +108 -0
- sglang/srt/debug_utils/text_comparator.py +172 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +13 -1
- sglang/srt/disaggregation/mooncake/conn.py +16 -0
- sglang/srt/disaggregation/prefill.py +13 -1
- sglang/srt/entrypoints/engine.py +4 -2
- sglang/srt/entrypoints/http_server.py +13 -1
- sglang/srt/entrypoints/openai/protocol.py +3 -1
- sglang/srt/entrypoints/openai/serving_base.py +5 -2
- sglang/srt/entrypoints/openai/serving_chat.py +132 -79
- sglang/srt/function_call/ebnf_composer.py +10 -3
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +164 -0
- sglang/srt/function_call/qwen3_coder_detector.py +1 -0
- sglang/srt/layers/attention/hybrid_attn_backend.py +100 -0
- sglang/srt/layers/attention/vision.py +56 -8
- sglang/srt/layers/layernorm.py +26 -1
- sglang/srt/layers/logits_processor.py +14 -3
- sglang/srt/layers/moe/ep_moe/layer.py +323 -242
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +83 -118
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=160,N=320,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -48
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +11 -8
- sglang/srt/layers/moe/token_dispatcher/__init__.py +0 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +48 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +19 -0
- sglang/srt/layers/moe/topk.py +90 -24
- sglang/srt/layers/multimodal.py +11 -8
- sglang/srt/layers/quantization/fp8.py +25 -247
- sglang/srt/layers/quantization/fp8_kernel.py +78 -48
- sglang/srt/layers/quantization/modelopt_quant.py +27 -10
- sglang/srt/layers/quantization/unquant.py +24 -76
- sglang/srt/layers/quantization/w4afp8.py +68 -17
- sglang/srt/lora/lora_registry.py +93 -29
- sglang/srt/managers/cache_controller.py +9 -7
- sglang/srt/managers/data_parallel_controller.py +4 -0
- sglang/srt/managers/io_struct.py +12 -0
- sglang/srt/managers/mm_utils.py +154 -35
- sglang/srt/managers/multimodal_processor.py +3 -14
- sglang/srt/managers/schedule_batch.py +14 -8
- sglang/srt/managers/scheduler.py +64 -1
- sglang/srt/managers/scheduler_input_blocker.py +106 -0
- sglang/srt/managers/tokenizer_manager.py +80 -15
- sglang/srt/managers/tp_worker.py +8 -0
- sglang/srt/mem_cache/hiradix_cache.py +5 -2
- sglang/srt/model_executor/model_runner.py +83 -27
- sglang/srt/models/deepseek_v2.py +75 -84
- sglang/srt/models/glm4_moe.py +1035 -0
- sglang/srt/models/glm4_moe_nextn.py +167 -0
- sglang/srt/models/interns1.py +328 -0
- sglang/srt/models/internvl.py +143 -47
- sglang/srt/models/llava.py +9 -5
- sglang/srt/models/minicpmo.py +4 -1
- sglang/srt/models/qwen2_moe.py +2 -2
- sglang/srt/models/qwen3_moe.py +17 -71
- sglang/srt/multimodal/processors/base_processor.py +20 -6
- sglang/srt/multimodal/processors/clip.py +2 -2
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +2 -2
- sglang/srt/multimodal/processors/gemma3.py +2 -2
- sglang/srt/multimodal/processors/gemma3n.py +2 -2
- sglang/srt/multimodal/processors/internvl.py +21 -8
- sglang/srt/multimodal/processors/janus_pro.py +2 -2
- sglang/srt/multimodal/processors/kimi_vl.py +2 -2
- sglang/srt/multimodal/processors/llava.py +4 -4
- sglang/srt/multimodal/processors/minicpm.py +2 -3
- sglang/srt/multimodal/processors/mlama.py +2 -2
- sglang/srt/multimodal/processors/mllama4.py +18 -111
- sglang/srt/multimodal/processors/phi4mm.py +2 -2
- sglang/srt/multimodal/processors/pixtral.py +2 -2
- sglang/srt/multimodal/processors/qwen_audio.py +2 -2
- sglang/srt/multimodal/processors/qwen_vl.py +2 -2
- sglang/srt/multimodal/processors/vila.py +3 -1
- sglang/srt/poll_based_barrier.py +31 -0
- sglang/srt/reasoning_parser.py +2 -1
- sglang/srt/server_args.py +65 -6
- sglang/srt/two_batch_overlap.py +8 -3
- sglang/srt/utils.py +96 -1
- sglang/srt/weight_sync/utils.py +119 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_utils.py +118 -5
- sglang/utils.py +19 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/METADATA +5 -4
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/RECORD +97 -80
- sglang/srt/debug_utils.py +0 -74
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/ep_moe/layer.py

@@ -1,5 +1,7 @@
+from __future__ import annotations
+
 import logging
-from typing import List, Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple
 
 import torch
 
@@ -30,13 +32,13 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.layers.quantization.fp8 import
+from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
 from sglang.srt.layers.quantization.fp8_kernel import (
     is_fp8_fnuz,
     sglang_per_token_group_quant_fp8,
     sglang_per_token_quant_fp8,
 )
-from sglang.srt.layers.quantization.unquant import
+from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
 from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -47,23 +49,40 @@ from sglang.srt.utils import (
     get_bool_env_var,
     is_hip,
     is_npu,
+    next_power_of_2,
 )
 
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.ep_moe.token_dispatcher import (
+        DeepEPLLOutput,
+        DeepEPNormalOutput,
+        DispatchOutput,
+    )
+
 _is_hip = is_hip()
 _is_npu = is_npu()
 _is_fp8_fnuz = is_fp8_fnuz()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
+use_flashinfer_trtllm_moe = (
+    global_server_args_dict["enable_flashinfer_trtllm_moe"]
+    and global_server_args_dict["enable_ep_moe"]
+)
 
 if not (_is_npu or _is_hip):
     from sgl_kernel import silu_and_mul
 
-    from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
-
 if _use_aiter:
     from aiter import ActivationType, QuantType
     from aiter.fused_moe import fused_moe
     from aiter.ops.shuffle import shuffle_weight
 
+if use_flashinfer_trtllm_moe:
+    try:
+        import flashinfer.fused_moe as fi_fused_moe
+    except ImportError:
+        fi_fused_moe = None
+        use_flashinfer_trtllm_moe = False
+
 logger = logging.getLogger(__name__)
 
 
@@ -140,7 +159,17 @@ class GroupedGemmRunner(torch.nn.Module):
         return c
 
 
-class EPMoE(torch.nn.Module):
+def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
+    # Guess tokens per expert assuming perfect expert distribution first.
+    num_tokens_per_expert = (num_tokens * top_k) // num_experts
+    # And pad the number to the next power of 2.
+    tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
+    # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
+    tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
+    return tile_tokens_dim
+
+
+class EPMoE(FusedMoE):
     """
     MoE Expert Parallel Impl
 
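For readers skimming the hunk above: `_get_tile_tokens_dim` turns the batch shape into a kernel tile size. A standalone sketch of the same arithmetic is below; the local `next_power_of_2` helper is an assumption standing in for `sglang.srt.utils.next_power_of_2`.

```python
def next_power_of_2(n: int) -> int:
    # Assumed behavior of sglang.srt.utils.next_power_of_2.
    return 1 if n <= 1 else 1 << (n - 1).bit_length()


def get_tile_tokens_dim(num_tokens: int, top_k: int, num_experts: int) -> int:
    # Expected tokens per expert under perfectly balanced routing.
    num_tokens_per_expert = (num_tokens * top_k) // num_experts
    # Round up to a power of 2, then clamp to the 8-64 range the kernel supports.
    return min(max(next_power_of_2(num_tokens_per_expert), 8), 64)


# 4096 tokens routed to 8 of 256 experts -> 128 tokens/expert -> capped at 64.
print(get_tile_tokens_dim(4096, 8, 256))  # 64
```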
@@ -162,51 +191,60 @@ class EPMoE(torch.nn.Module):
         routed_scaling_factor: Optional[float] = None,
         use_per_token_if_dynamic: bool = True,
     ):
-        super().__init__(
+        super().__init__(
+            num_experts=num_experts,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            top_k=top_k,
+            layer_id=layer_id,
+            params_dtype=params_dtype,
+            quant_config=quant_config,
+            tp_size=tp_size,
+            prefix=prefix,
+            activation=activation,
+            routed_scaling_factor=routed_scaling_factor,
+            enable_ep_moe=True,
+            skip_quant=True,
+        )
 
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
 
-        self.tp_size = (
-            tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
-        )
-        self.tp_rank = get_tensor_model_parallel_rank()
-
         self.layer_id = layer_id
-        self.
-
-        self.
-        self.start_expert_id = self.tp_rank * self.num_experts_per_partition
-        self.end_expert_id = self.start_expert_id + self.num_experts_per_partition - 1
+        self.num_local_experts, self.expert_map = self.determine_expert_map()
+        self.start_expert_id = self.ep_rank * self.num_local_experts
+        self.end_expert_id = self.start_expert_id + self.num_local_experts - 1
 
-        self.top_k = top_k
         self.intermediate_size = intermediate_size
-        self.activation = activation
-        self.routed_scaling_factor = routed_scaling_factor
         self.use_per_token_if_dynamic = use_per_token_if_dynamic
 
+        # TODO(ch-wan): move quant preparation to FusedMoE
         if quant_config is None:
-            self.quant_method: Optional[QuantizeMethodBase] =
+            self.quant_method: Optional[QuantizeMethodBase] = (
+                UnquantizedFusedMoEMethod()
+            )
             self.use_fp8_w8a8 = False
             self.use_block_quant = False
             self.block_shape = None
             self.activation_scheme = None
-            self.
+            self.w13_input_scale = None
+            self.w2_input_scale = None
+            self.w13_weight_scale = None
+            self.w2_weight_scale = None
         elif isinstance(quant_config, W4AFp8Config):
             self.quant_method: Optional[QuantizeMethodBase] = W4AFp8MoEMethod(
                 quant_config
             )
-            self.use_w4afp8 = True
             self.use_fp8_w8a8 = False
             self.use_block_quant = False
             self.fp8_dtype = torch.float8_e4m3fn
+            self.w13_input_scale = None
+            self.w2_input_scale = None
             self.w13_weight_scale = None
             self.w2_weight_scale = None
             self.activation_scheme = quant_config.moe_activation_scheme
-
-            self.quant_method: Optional[QuantizeMethodBase] =
-                quant_config
-            )
+        elif isinstance(quant_config, Fp8Config):
+            self.quant_method: Optional[QuantizeMethodBase] = Fp8MoEMethod(quant_config)
             self.use_fp8_w8a8 = True
             self.use_block_quant = getattr(self.quant_method, "block_quant", False)
             self.block_shape = (
@@ -216,11 +254,13 @@ class EPMoE(torch.nn.Module):
             )
             self.fp8_dtype = torch.float8_e4m3fn
             self.activation_scheme = quant_config.activation_scheme
-
+        else:
+            raise ValueError(f"Unsupported quant_config: {quant_config}")
 
+        self.quant_config = quant_config
         self.quant_method.create_weights(
             layer=self,
-
+            num_experts=self.num_local_experts,
             hidden_size=hidden_size,
             intermediate_size=self.intermediate_size,
             params_dtype=params_dtype,
@@ -229,19 +269,6 @@ class EPMoE(torch.nn.Module):
 
         self.grouped_gemm_runner = None
 
-        self.w13_weight_fp8 = (
-            self.w13_weight,
-            (
-                self.w13_weight_scale_inv
-                if self.use_block_quant
-                else self.w13_weight_scale
-            ),
-        )
-        self.w2_weight_fp8 = (
-            self.w2_weight,
-            self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
-        )
-
     # Adapted from https://github.com/vllm-project/vllm/blob/9fb52e523abf7bdaf7e60cf2971edb5a1b13dc08/vllm/model_executor/layers/fused_moe/layer.py#L544C1-L586C43
     # Modifications: use determine_expert_map as a class internal function, set 'global_num_experts' rather than '-1' for experts not assigned to the current rank.
     def determine_expert_map(self) -> Tuple[int, Optional[torch.Tensor]]:
@@ -260,8 +287,8 @@ class EPMoE(torch.nn.Module):
                 Contains global_num_experts for experts not assigned to the current rank.
                 Returns None if ep_size is 1.
         """
-        ep_size = self.
-        ep_rank = self.
+        ep_size = self.ep_size
+        ep_rank = self.ep_rank
         global_num_experts = self.num_experts
 
         assert ep_size > 0
@@ -271,7 +298,7 @@ class EPMoE(torch.nn.Module):
         local_num_experts = global_num_experts // ep_size
 
         expert_map = torch.full(
-            (global_num_experts,),
+            (global_num_experts,), global_num_experts, dtype=torch.int32
         )
         if ep_rank < (ep_size - 1):
            expert_map[
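The reworked `determine_expert_map` fills unowned slots with `global_num_experts` rather than `-1`. The sketch below reproduces just that mapping for the simple case where experts divide evenly across EP ranks (the remainder handling for the last rank is omitted); the sizes are made up for illustration.

```python
import torch


def build_expert_map(global_num_experts: int, ep_size: int, ep_rank: int) -> torch.Tensor:
    # Experts this rank does not own map to the sentinel value global_num_experts;
    # owned experts map to their local index (0..local_num_experts-1).
    local_num_experts = global_num_experts // ep_size
    expert_map = torch.full((global_num_experts,), global_num_experts, dtype=torch.int32)
    start = ep_rank * local_num_experts
    expert_map[start : start + local_num_experts] = torch.arange(
        local_num_experts, dtype=torch.int32
    )
    return expert_map


# 8 experts over 2 ranks: rank 1 owns global experts 4-7 as local 0-3.
print(build_expert_map(8, 2, 1))  # tensor([8, 8, 8, 8, 0, 1, 2, 3], dtype=torch.int32)
```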
@@ -296,6 +323,20 @@ class EPMoE(torch.nn.Module):
         hidden_states: torch.Tensor,
         topk_output: TopKOutput,
     ):
+
+        self.w13_weight_fp8 = (
+            self.w13_weight,
+            (
+                self.w13_weight_scale_inv
+                if self.use_block_quant
+                else self.w13_weight_scale
+            ),
+        )
+        self.w2_weight_fp8 = (
+            self.w2_weight,
+            self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
+        )
+
         assert self.quant_method is not None
         assert self.activation == "silu"
         hidden_states_shape = hidden_states.shape
@@ -435,7 +476,10 @@ class EPMoE(torch.nn.Module):
         return output
 
     def forward_normal(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
-
+        return self.quant_method.apply(self, hidden_states, topk_output)
+
+    def run_moe(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
+
         topk_weights, topk_ids, _ = topk_output
 
         hidden_states_shape = hidden_states.shape
@@ -448,53 +492,11 @@ class EPMoE(torch.nn.Module):
             use_per_token_if_dynamic=self.use_per_token_if_dynamic,
         )
 
-
-        local_topk_ids = topk_ids
-        if self.expert_map is not None:
-            "Translate info from expert_map to topk_ids"
-            local_topk_ids = torch.where(
-                self.expert_map[topk_ids] != self.num_experts,
-                self.expert_map[topk_ids],
-                self.num_experts,
-            )
-
-        output = cutlass_w4a8_moe(
-            self.start_expert_id,
-            self.end_expert_id,
-            self.num_experts,
-            hidden_states,
-            self.w13_weight,
-            self.w2_weight,
-            self.w13_weight_scale_inv,
-            self.w2_weight_scale_inv,
-            topk_weights,
-            topk_ids,
-            local_topk_ids,
-            self.quant_method.a_strides1,
-            self.quant_method.b_strides1,
-            self.quant_method.c_strides1,
-            self.quant_method.a_strides2,
-            self.quant_method.b_strides2,
-            self.quant_method.c_strides2,
-            self.quant_method.s_strides13,
-            self.quant_method.s_strides2,
-            self.quant_method.expert_offsets,
-            self.quant_method.problem_sizes1,
-            self.quant_method.problem_sizes2,
-            self.w13_input_scale,
-            self.w2_input_scale,
-        )
-        return output
-
-        if self.grouped_gemm_runner is None:
-            self.grouped_gemm_runner = GroupedGemmRunner(
-                hidden_states.device,
-                use_flashinfer=False,  # TODO: use flashinfer
-                use_per_token_if_dynamic=self.use_per_token_if_dynamic,
-            )
+        num_experts = self.num_experts
 
         reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
-            topk_ids,
+            topk_ids,
+            num_experts,
         )
 
         gateup_input = torch.empty(
@@ -502,7 +504,7 @@ class EPMoE(torch.nn.Module):
             device=hidden_states.device,
             dtype=(
                 self.fp8_dtype
-                if
+                if self.use_fp8_w8a8 and not self.use_block_quant
                 else hidden_states.dtype
             ),
         )
@@ -513,7 +515,7 @@ class EPMoE(torch.nn.Module):
         else:
             max_value = (
                 torch.max(hidden_states)
-                .repeat(self.
+                .repeat(self.num_local_experts)
                 .to(torch.float32)
             )
             self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
@@ -554,7 +556,7 @@ class EPMoE(torch.nn.Module):
         seg_indptr_cur_rank = seg_indptr[self.start_expert_id : self.end_expert_id + 2]
         weight_indices_cur_rank = torch.arange(
             0,
-            self.
+            self.num_local_experts,
             device=hidden_states_device,
             dtype=torch.int64,
         )
@@ -564,17 +566,13 @@ class EPMoE(torch.nn.Module):
             b=self.w13_weight,
             c=None,
             c_dtype=hidden_states_dtype,
-            batch_size=self.
+            batch_size=self.num_local_experts,
             weight_column_major=True,
             seg_indptr=seg_indptr_cur_rank,
             weight_indices=weight_indices_cur_rank,
             use_fp8_w8a8=self.use_fp8_w8a8,
             scale_a=self.w13_input_scale,
-            scale_b=
-                self.w13_weight_scale_inv
-                if self.use_block_quant
-                else self.w13_weight_scale
-            ),
+            scale_b=self.w13_weight_scale,
             block_shape=self.block_shape,
         )
         del gateup_input
@@ -631,7 +629,7 @@ class EPMoE(torch.nn.Module):
             down_input, self.w2_input_scale = sglang_per_token_quant_fp8(down_input)
         else:
             self.w2_input_scale = torch.ones(
-                self.
+                self.num_local_experts,
                 dtype=torch.float32,
                 device=hidden_states_device,
             )
@@ -647,17 +645,13 @@ class EPMoE(torch.nn.Module):
             a=down_input,
             b=self.w2_weight,
             c=down_output,
-            batch_size=self.
+            batch_size=self.num_local_experts,
             weight_column_major=True,
             seg_indptr=seg_indptr_cur_rank,
             weight_indices=weight_indices_cur_rank,
             use_fp8_w8a8=self.use_fp8_w8a8,
             scale_a=self.w2_input_scale,
-            scale_b=
-                self.w2_weight_scale_inv
-                if self.use_block_quant
-                else self.w2_weight_scale
-            ),
+            scale_b=self.w2_weight_scale,
             block_shape=self.block_shape,
         )
         del down_input
@@ -760,95 +754,14 @@ class EPMoE(torch.nn.Module):
             return
        expert_id = expert_id - self.start_expert_id
 
-
-
-
-
-
-
-
-
-                param.data,
-                loaded_weight,
-                weight_name,
-                shard_id,
-                expert_id,
-            )
-            return
-
-        if shard_id == "w2":
-            param.data[expert_id] = loaded_weight
-        elif shard_id == "w1":
-            param.data[expert_id][: self.intermediate_size, :] = loaded_weight
-        elif shard_id == "w3":
-            param.data[expert_id][self.intermediate_size :, :] = loaded_weight
-        else:
-            raise ValueError(f"Expected shard_id w1,w2 or w3 but got {shard_id}")
-
-    def _load_fp8_scale(
-        self,
-        param: torch.nn.Parameter,
-        loaded_weight: torch.Tensor,
-        weight_name: str,
-        shard_id: str,
-        expert_id: int,
-    ) -> None:
-        param_data = param.data
-
-        # Input scales can be loaded directly and should be equal.
-        if "input_scale" in weight_name:
-            if self.use_w4afp8:
-                if shard_id == "w1":
-                    param_data[expert_id][0] = loaded_weight
-                elif shard_id == "w3":
-                    param_data[expert_id][1] = loaded_weight
-                else:
-                    param_data[expert_id] = loaded_weight
-                return
-
-            if (
-                (shard_id == "w1" or shard_id == "w3")
-                and param_data[expert_id] != 1
-                and (param_data[expert_id] - loaded_weight).abs() > 1e-5
-            ):
-                raise ValueError(
-                    "input_scales of w1 and w3 of a layer "
-                    f"must be equal. But got {param_data[expert_id]} "
-                    f"vs. {loaded_weight}"
-                )
-            param_data[expert_id] = loaded_weight
-        # Weight scales
-        elif "weight_scale" in weight_name:
-            if self.use_block_quant:
-                block_n, block_k = self.block_shape[0], self.block_shape[1]
-                if shard_id == "w1":
-                    param_data[expert_id][
-                        : (self.intermediate_size + block_n - 1) // block_n, :
-                    ] = loaded_weight
-                elif shard_id == "w3":
-                    param_data[expert_id][
-                        (self.intermediate_size + block_n - 1) // block_n :, :
-                    ] = loaded_weight
-                else:  # w2
-                    param_data[expert_id] = loaded_weight
-            elif self.use_w4afp8:
-                if shard_id == "w1":
-                    param_data[expert_id][: self.intermediate_size, :] = loaded_weight
-                elif shard_id == "w3":
-                    param_data[expert_id][self.intermediate_size :, :] = loaded_weight
-                else:
-                    param_data[expert_id] = loaded_weight
-            # If we are in merged column case (gate_up_proj)
-            else:
-                if shard_id in ("w1", "w3"):
-                    # We have to keep the weight scales of w1 and w3 because
-                    # we need to re-quantize w1/w3 weights after weight loading.
-                    idx = 0 if shard_id == "w1" else 1
-                    param_data[expert_id][idx] = loaded_weight
-
-                # If we are in the row parallel case (down_proj)
-                else:
-                    param_data[expert_id] = loaded_weight
+        self._weight_loader_impl(
+            param=param,
+            loaded_weight=loaded_weight,
+            weight_name=weight_name,
+            shard_id=shard_id,
+            expert_id=expert_id,
+        )
+        return
 
 
 class DeepEPMoE(EPMoE):
@@ -887,24 +800,37 @@ class DeepEPMoE(EPMoE):
             routed_scaling_factor=routed_scaling_factor,
         )
         self.deepep_mode = deepep_mode
-
-
-
-
-
+
+        # TODO: move to the beginning of the file
+        from sglang.srt.distributed.parallel_state import get_tp_group
+        from sglang.srt.managers.schedule_batch import global_server_args_dict
+        from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
+
+        self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
+            group=get_tp_group().device_group,
+            router_topk=self.top_k,
+            permute_fusion=True,
+            num_experts=self.num_experts,
+            num_local_experts=self.num_local_experts,
+            hidden_size=hidden_size,
+            params_dtype=params_dtype,
+            deepep_mode=deepep_mode,
+            async_finish=True,  # TODO
+            return_recv_hook=True,
+        )
 
         if self.deepep_mode.enable_low_latency():
             assert (
                 deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
             ), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
         if _use_aiter:
-            # expert_mask is of size (self.
+            # expert_mask is of size (self.num_local_experts + 1),
             # the extra 1 is for invalid rank_id (in original deepep, the invalid rank_id is -1, but aiter does not allow -1, we use a mask to make those ids invalid)
             # for instance, if we have 4 experts on this rank, we would have a expert_mask like:
            #     self.expert_mask = [1, 1, 1, 1, 0]
            # idx from 0-3 is valid and will be processed, while idx == 4 will be masked out
            self.expert_mask = torch.zeros(
-                (self.
+                (self.num_local_experts + 1),
                device=torch.cuda.current_device(),
                dtype=torch.int,
            )
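The comment in this hunk describes the aiter expert mask: the rank's local experts stay valid and one extra trailing slot absorbs the `-1` ids DeepEP uses for unrouted entries. A minimal reproduction of that remapping (not the dispatcher's actual code path) looks like this:

```python
import torch

num_local_experts = 4
# One extra slot so remapped "-1" entries land on a masked-out expert.
expert_mask = torch.zeros(num_local_experts + 1, dtype=torch.int)
expert_mask[:num_local_experts] = 1  # -> [1, 1, 1, 1, 0]

topk_idx = torch.tensor([[0, 2, -1], [3, -1, 1]], dtype=torch.int32)
topk_idx_copy = topk_idx.clone()
topk_idx_copy[topk_idx_copy == -1] = num_local_experts  # -1 becomes idx 4, masked out

print(expert_mask.tolist())    # [1, 1, 1, 1, 0]
print(topk_idx_copy.tolist())  # [[0, 2, 4], [3, 4, 1]]
```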
@@ -933,37 +859,128 @@ class DeepEPMoE(EPMoE):
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-        reorder_topk_ids: torch.Tensor,
-        seg_indptr: torch.Tensor,
-        masked_m: torch.Tensor,
-        expected_m: int,
-        num_recv_tokens_per_expert: List[int],
         forward_batch: ForwardBatch,
     ):
+        dispatch_output = self.dispatch(
+            hidden_states, topk_idx, topk_weights, forward_batch
+        )
+        hidden_states = self.moe_impl(dispatch_output)
+        hidden_states = self.combine(
+            hidden_states,
+            dispatch_output.topk_idx,
+            dispatch_output.topk_weights,
+            forward_batch,
+        )
+        return hidden_states
+
+    def dispatch(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ):
+        return self.deepep_dispatcher.dispatch(
+            hidden_states=hidden_states,
+            topk_idx=topk_idx,
+            topk_weights=topk_weights,
+            forward_batch=forward_batch,
+        )
+
+    def moe_impl(self, dispatch_output: DispatchOutput):
         if _use_aiter:
             # in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
-            return self.forward_aiter(
-
-
-
-        if resolved_deepep_mode == DeepEPMode.normal:
-            if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
-                return self.forward_deepgemm_contiguous(
-                    hidden_states, topk_idx, topk_weights, num_recv_tokens_per_expert
-                )
+            return self.forward_aiter(dispatch_output)
+        if dispatch_output.format.is_deepep_normal():
+            if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
+                return self.forward_deepgemm_contiguous(dispatch_output)
             else:
-                return self.forward_normal(
-        elif
-            return self.forward_deepgemm_masked(
+                return self.forward_normal(dispatch_output)
+        elif dispatch_output.format.is_deepep_ll():
+            return self.forward_deepgemm_masked(dispatch_output)
         else:
             raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
 
-    def
+    def combine(
         self,
         hidden_states: torch.Tensor,
-
-
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ):
+        return self.deepep_dispatcher.combine(
+            hidden_states=hidden_states,
+            topk_idx=topk_idx,
+            topk_weights=topk_weights,
+            forward_batch=forward_batch,
+        )
+
+    def _prepare_for_normal(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+    ):
+        from sglang.srt.layers.moe.ep_moe.kernels import (
+            deepep_permute_triton_kernel,
+            deepep_run_moe_deep_preprocess,
+        )
+
+        if hidden_states.shape[0] == 0:
+            reorder_topk_ids = torch.empty(
+                (0,), device=hidden_states.device, dtype=torch.int64
+            )
+            seg_indptr = torch.zeros(
+                (self.num_experts + 1,),
+                device=hidden_states.device,
+                dtype=torch.int64,
+            )
+            return reorder_topk_ids, seg_indptr, hidden_states
+        else:
+            if _use_aiter:
+                # skip permutation here as aiter fused_moe has fused inside
+                reorder_topk_ids = torch.empty(
+                    (0,), device=hidden_states.device, dtype=torch.int64
+                )
+                seg_indptr = torch.zeros(
+                    (self.num_experts + 1,),
+                    device=hidden_states.device,
+                    dtype=torch.int64,
+                )
+                return reorder_topk_ids, seg_indptr, hidden_states
+
+            reorder_topk_ids, self.src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
+                topk_idx, self.num_experts
+            )
+            num_total_tokens = reorder_topk_ids.numel()
+            gateup_input = torch.empty(
+                (int(num_total_tokens), hidden_states.shape[1]),
+                device=hidden_states.device,
+                dtype=hidden_states.dtype,
+            )
+            # PreReorder
+            deepep_permute_triton_kernel[(hidden_states.shape[0],)](
+                hidden_states,
+                gateup_input,
+                self.src2dst,
+                topk_idx,
+                None,
+                self.router_topk,
+                hidden_states.shape[1],
+                BLOCK_SIZE=512,
+            )
+            return reorder_topk_ids, seg_indptr, gateup_input
+
+    def forward_normal(
+        self,
+        dispatch_output: DeepEPNormalOutput,
     ):
+        hidden_states, topk_idx = (
+            dispatch_output.hidden_states,
+            dispatch_output.topk_idx,
+        )
+        reorder_topk_ids, seg_indptr, hidden_states = self._prepare_for_normal(
+            hidden_states, topk_idx
+        )
         hidden_states_dtype = hidden_states.dtype
         hidden_states_device = hidden_states.device
 
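The hunk above reshapes `DeepEPMoE.forward` into three stages: `dispatch` hands tokens to the DeepEP dispatcher, `moe_impl` selects a kernel from the dispatch output's format, and `combine` gathers expert outputs back. The skeleton below only illustrates that control flow; the dispatcher and kernel objects are hypothetical stand-ins, not sglang classes.

```python
class DeepEPMoELayerSketch:
    """Control-flow sketch of the new forward(): dispatch -> moe_impl -> combine."""

    def __init__(self, dispatcher, normal_kernel, low_latency_kernel):
        self.dispatcher = dispatcher
        self.normal_kernel = normal_kernel
        self.low_latency_kernel = low_latency_kernel

    def forward(self, hidden_states, topk_idx, topk_weights, forward_batch):
        out = self.dispatcher.dispatch(hidden_states, topk_idx, topk_weights, forward_batch)
        hidden_states = self.moe_impl(out)
        # combine() reuses the routing info carried on the dispatch output.
        return self.dispatcher.combine(
            hidden_states, out.topk_idx, out.topk_weights, forward_batch
        )

    def moe_impl(self, dispatch_output):
        # Kernel choice keys off the dispatch output's format, mirroring
        # is_deepep_normal() / is_deepep_ll() in the diff.
        if dispatch_output.format.is_deepep_normal():
            return self.normal_kernel(dispatch_output)
        if dispatch_output.format.is_deepep_ll():
            return self.low_latency_kernel(dispatch_output)
        raise ValueError("unknown dispatch output format")
```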
@@ -977,13 +994,13 @@ class DeepEPMoE(EPMoE):
         if self.activation_scheme == "dynamic" and not self.use_block_quant:
             max_value = (
                 torch.max(hidden_states)
-                .repeat(self.
+                .repeat(self.num_local_experts)
                 .to(torch.float32)
             )
             self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
         weight_indices_cur_rank = torch.arange(
             0,
-            self.
+            self.num_local_experts,
             device=hidden_states.device,
             dtype=torch.int64,
         )
@@ -995,7 +1012,7 @@ class DeepEPMoE(EPMoE):
             b=self.w13_weight,
             c=None,
             c_dtype=hidden_states.dtype,
-            batch_size=self.
+            batch_size=self.num_local_experts,
             weight_column_major=True,
             seg_indptr=seg_indptr,
             weight_indices=weight_indices_cur_rank,
@@ -1029,7 +1046,7 @@ class DeepEPMoE(EPMoE):
         )
         if self.w2_input_scale is None and not self.use_block_quant:
             self.w2_input_scale = torch.ones(
-                self.
+                self.num_local_experts,
                 dtype=torch.float32,
                 device=hidden_states_device,
             )
@@ -1042,7 +1059,7 @@ class DeepEPMoE(EPMoE):
                 reorder_topk_ids,
                 self.w2_input_scale,
                 0,
-                self.
+                self.num_local_experts - 1,
                 BLOCK_SIZE=512,
             )
         else:
@@ -1062,7 +1079,7 @@ class DeepEPMoE(EPMoE):
             a=down_input,
             b=self.w2_weight,
             c=down_output,
-            batch_size=self.
+            batch_size=self.num_local_experts,
             weight_column_major=True,
             seg_indptr=seg_indptr,
             weight_indices=weight_indices_cur_rank,
@@ -1079,17 +1096,20 @@ class DeepEPMoE(EPMoE):
 
     def forward_aiter(
         self,
-
-        topk_idx: torch.Tensor,
-        topk_weights: torch.Tensor,
+        dispatch_output: DeepEPNormalOutput,
     ):
+        hidden_states, topk_idx, topk_weights = (
+            dispatch_output.hidden_states,
+            dispatch_output.topk_idx,
+            dispatch_output.topk_weights,
+        )
         if hidden_states.shape[0] == 0:
             return hidden_states
         # in original deepep, idx == -1 meaning invalid and will not be processed.
         # aiter does not accept -1, we use a expert mask to make these idx invalid
-        # (idx ==
+        # (idx == num_local_experts) meaning not used in aiter fused_moe
         topk_idx_copy = topk_idx.to(torch.int32)
-        topk_idx_copy[topk_idx_copy == -1] = self.
+        topk_idx_copy[topk_idx_copy == -1] = self.num_local_experts
 
         return fused_moe(
             hidden_states,
@@ -1110,11 +1130,11 @@ class DeepEPMoE(EPMoE):
 
     def forward_deepgemm_contiguous(
         self,
-
-        topk_idx,
-        topk_weights,
-        num_recv_tokens_per_expert: List[int],
+        dispatch_output: DeepEPNormalOutput,
     ):
+        hidden_states_fp8, topk_idx, topk_weights, num_recv_tokens_per_expert = (
+            dispatch_output
+        )
         hidden_states_fp8, hidden_states_scale = hidden_states_fp8
         assert self.quant_method is not None
         assert self.activation == "silu"
@@ -1234,10 +1254,9 @@ class DeepEPMoE(EPMoE):
 
     def forward_deepgemm_masked(
         self,
-
-        masked_m: torch.Tensor,
-        expected_m: int,
+        dispatch_output: DeepEPLLOutput,
     ):
+        hidden_states_fp8, _, _, masked_m, expected_m = dispatch_output
         assert self.quant_method is not None
         assert self.activation == "silu"
 
@@ -1315,12 +1334,74 @@ class DeepEPMoE(EPMoE):
         return down_output
 
 
+class FlashInferEPMoE(EPMoE):
+    def __init__(self, *args, **kwargs):
+        renormalize = kwargs.pop("renormalize", True)
+        num_fused_shared_experts = kwargs.pop("num_fused_shared_experts", 0)
+        use_grouped_topk = kwargs.pop("use_grouped_topk", False)
+        num_expert_group = kwargs.pop("num_expert_group", None)
+        topk_group = kwargs.pop("topk_group", None)
+        correction_bias = kwargs.pop("correction_bias", None)
+        super().__init__(*args, **kwargs)
+        self.renormalize = renormalize
+        self.num_fused_shared_experts = num_fused_shared_experts
+        self.use_grouped_topk = use_grouped_topk
+        if self.use_grouped_topk:
+            assert num_expert_group is not None and topk_group is not None
+        self.num_expert_group = num_expert_group
+        self.topk_group = topk_group
+        self.correction_bias = correction_bias
+        self.use_flashinfer_trtllm_moe = use_flashinfer_trtllm_moe
+
+    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
+        assert use_flashinfer_trtllm_moe
+        assert (
+            self.activation == "silu"
+        ), "Only silu is supported for flashinfer blockscale fp8 moe"
+        assert (
+            self.renormalize
+        ), "Renormalize is required for flashinfer blockscale fp8 moe"
+        assert (
+            self.num_fused_shared_experts == 0
+        ), "Fused shared experts are not supported for flashinfer blockscale fp8 moe"
+        a_q, a_sf = sglang_per_token_group_quant_fp8(hidden_states, self.block_shape[1])
+        # NOTE: scales of hidden states have to be transposed!
+        a_sf_t = a_sf.t().contiguous()
+        assert fi_fused_moe is not None
+        return fi_fused_moe.trtllm_fp8_block_scale_moe(
+            routing_logits=router_logits.to(torch.float32),
+            routing_bias=self.correction_bias.to(hidden_states.dtype),
+            hidden_states=a_q,
+            hidden_states_scale=a_sf_t,
+            gemm1_weights=self.w13_weight,
+            gemm1_weights_scale=self.w13_weight_scale_inv,
+            gemm2_weights=self.w2_weight,
+            gemm2_weights_scale=self.w2_weight_scale_inv,
+            num_experts=self.num_experts,
+            top_k=self.top_k,
+            n_group=self.num_expert_group,
+            topk_group=self.topk_group,
+            intermediate_size=self.w2_weight.shape[2],
+            local_expert_offset=self.start_expert_id,
+            local_num_experts=self.num_local_experts,
+            routed_scaling_factor=self.routed_scaling_factor,
+            tile_tokens_dim=_get_tile_tokens_dim(
+                hidden_states.shape[0], self.top_k, self.num_experts
+            ),
+            routing_method_type=2,  # DeepSeek-styled routing method
+            use_shuffled_weight=False,
+        )
+
+
 def get_moe_impl_class():
     if global_server_args_dict["enable_deepep_moe"]:
         return DeepEPMoE
-    if global_server_args_dict["
+    if global_server_args_dict["enable_flashinfer_cutlass_moe"]:
         # Must come before EPMoE because FusedMoE also supports enable_ep_moe
         return FusedMoE
+    if use_flashinfer_trtllm_moe:
+        # Must come before EPMoE because FusedMoE also supports enable_ep_moe
+        return FlashInferEPMoE
     if global_server_args_dict["enable_ep_moe"]:
         return EPMoE
     return FusedMoE
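With `FlashInferEPMoE` added, `get_moe_impl_class` resolves the implementation in a fixed precedence order. The snippet below restates that order with the server-args lookup stubbed as a plain dict; `pick_moe_impl` is an illustrative name, not an sglang function.

```python
def pick_moe_impl(args: dict, use_flashinfer_trtllm_moe: bool) -> str:
    # Mirrors get_moe_impl_class() after this change: DeepEP first, then the two
    # FlashInfer paths (which must win over EPMoE), then plain EP, else FusedMoE.
    if args.get("enable_deepep_moe"):
        return "DeepEPMoE"
    if args.get("enable_flashinfer_cutlass_moe"):
        return "FusedMoE"
    if use_flashinfer_trtllm_moe:
        return "FlashInferEPMoE"
    if args.get("enable_ep_moe"):
        return "EPMoE"
    return "FusedMoE"


print(pick_moe_impl({"enable_ep_moe": True}, use_flashinfer_trtllm_moe=False))  # EPMoE
```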