sglang 0.4.9.post3__py3-none-any.whl → 0.4.9.post5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/lang/chat_template.py +21 -0
- sglang/srt/_custom_ops.py +29 -1
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/model_config.py +5 -1
- sglang/srt/constrained/base_grammar_backend.py +10 -2
- sglang/srt/constrained/xgrammar_backend.py +7 -5
- sglang/srt/conversation.py +17 -2
- sglang/srt/debug_utils/__init__.py +0 -0
- sglang/srt/debug_utils/dump_comparator.py +131 -0
- sglang/srt/debug_utils/dumper.py +108 -0
- sglang/srt/debug_utils/text_comparator.py +172 -0
- sglang/srt/disaggregation/common/conn.py +34 -6
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +13 -1
- sglang/srt/disaggregation/mini_lb.py +3 -2
- sglang/srt/disaggregation/mooncake/conn.py +65 -20
- sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
- sglang/srt/disaggregation/nixl/conn.py +17 -13
- sglang/srt/disaggregation/prefill.py +13 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
- sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
- sglang/srt/distributed/parallel_state.py +70 -15
- sglang/srt/entrypoints/engine.py +5 -9
- sglang/srt/entrypoints/http_server.py +20 -32
- sglang/srt/entrypoints/openai/protocol.py +3 -3
- sglang/srt/entrypoints/openai/serving_chat.py +148 -72
- sglang/srt/function_call/base_format_detector.py +74 -12
- sglang/srt/function_call/deepseekv3_detector.py +26 -11
- sglang/srt/function_call/ebnf_composer.py +105 -66
- sglang/srt/function_call/function_call_parser.py +6 -4
- sglang/srt/function_call/glm4_moe_detector.py +164 -0
- sglang/srt/function_call/kimik2_detector.py +41 -16
- sglang/srt/function_call/llama32_detector.py +6 -3
- sglang/srt/function_call/mistral_detector.py +11 -3
- sglang/srt/function_call/pythonic_detector.py +16 -14
- sglang/srt/function_call/qwen25_detector.py +12 -3
- sglang/srt/function_call/{qwen3_detector.py → qwen3_coder_detector.py} +11 -9
- sglang/srt/layers/activation.py +11 -3
- sglang/srt/layers/attention/base_attn_backend.py +3 -1
- sglang/srt/layers/attention/hybrid_attn_backend.py +100 -0
- sglang/srt/layers/attention/vision.py +56 -8
- sglang/srt/layers/communicator.py +12 -12
- sglang/srt/layers/dp_attention.py +72 -24
- sglang/srt/layers/layernorm.py +26 -1
- sglang/srt/layers/logits_processor.py +46 -25
- sglang/srt/layers/moe/ep_moe/layer.py +172 -206
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=160,N=320,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +25 -224
- sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -48
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +11 -8
- sglang/srt/layers/moe/topk.py +88 -34
- sglang/srt/layers/multimodal.py +11 -8
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -9
- sglang/srt/layers/quantization/fp8.py +25 -247
- sglang/srt/layers/quantization/fp8_kernel.py +78 -48
- sglang/srt/layers/quantization/modelopt_quant.py +33 -14
- sglang/srt/layers/quantization/unquant.py +24 -76
- sglang/srt/layers/quantization/utils.py +0 -9
- sglang/srt/layers/quantization/w4afp8.py +68 -17
- sglang/srt/layers/radix_attention.py +5 -3
- sglang/srt/lora/lora_manager.py +133 -169
- sglang/srt/lora/lora_registry.py +188 -0
- sglang/srt/lora/mem_pool.py +2 -2
- sglang/srt/managers/cache_controller.py +62 -13
- sglang/srt/managers/io_struct.py +19 -1
- sglang/srt/managers/mm_utils.py +154 -35
- sglang/srt/managers/multimodal_processor.py +3 -14
- sglang/srt/managers/schedule_batch.py +27 -11
- sglang/srt/managers/scheduler.py +48 -26
- sglang/srt/managers/tokenizer_manager.py +62 -28
- sglang/srt/managers/tp_worker.py +5 -4
- sglang/srt/mem_cache/allocator.py +67 -7
- sglang/srt/mem_cache/hicache_storage.py +17 -1
- sglang/srt/mem_cache/hiradix_cache.py +35 -18
- sglang/srt/mem_cache/memory_pool_host.py +3 -0
- sglang/srt/model_executor/cuda_graph_runner.py +61 -25
- sglang/srt/model_executor/forward_batch_info.py +201 -29
- sglang/srt/model_executor/model_runner.py +109 -37
- sglang/srt/models/deepseek_v2.py +63 -30
- sglang/srt/models/glm4_moe.py +1035 -0
- sglang/srt/models/glm4_moe_nextn.py +167 -0
- sglang/srt/models/interns1.py +328 -0
- sglang/srt/models/internvl.py +143 -47
- sglang/srt/models/llava.py +9 -5
- sglang/srt/models/minicpmo.py +4 -1
- sglang/srt/models/mllama4.py +10 -3
- sglang/srt/models/qwen2_moe.py +2 -6
- sglang/srt/models/qwen3_moe.py +6 -8
- sglang/srt/multimodal/processors/base_processor.py +20 -6
- sglang/srt/multimodal/processors/clip.py +2 -2
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +2 -2
- sglang/srt/multimodal/processors/gemma3.py +2 -2
- sglang/srt/multimodal/processors/gemma3n.py +2 -2
- sglang/srt/multimodal/processors/internvl.py +21 -8
- sglang/srt/multimodal/processors/janus_pro.py +2 -2
- sglang/srt/multimodal/processors/kimi_vl.py +2 -2
- sglang/srt/multimodal/processors/llava.py +4 -4
- sglang/srt/multimodal/processors/minicpm.py +2 -3
- sglang/srt/multimodal/processors/mlama.py +2 -2
- sglang/srt/multimodal/processors/mllama4.py +18 -111
- sglang/srt/multimodal/processors/phi4mm.py +2 -2
- sglang/srt/multimodal/processors/pixtral.py +2 -2
- sglang/srt/multimodal/processors/qwen_audio.py +2 -2
- sglang/srt/multimodal/processors/qwen_vl.py +2 -2
- sglang/srt/multimodal/processors/vila.py +3 -1
- sglang/srt/reasoning_parser.py +48 -5
- sglang/srt/sampling/sampling_batch_info.py +6 -5
- sglang/srt/server_args.py +132 -60
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +37 -36
- sglang/srt/speculative/eagle_utils.py +51 -23
- sglang/srt/speculative/eagle_worker.py +59 -44
- sglang/srt/two_batch_overlap.py +9 -5
- sglang/srt/utils.py +113 -69
- sglang/srt/weight_sync/utils.py +119 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_activation.py +50 -1
- sglang/test/test_utils.py +65 -5
- sglang/utils.py +19 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/METADATA +6 -6
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/RECORD +127 -114
- sglang/srt/debug_utils.py +0 -74
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/ep_moe/layer.py

@@ -30,13 +30,13 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.layers.quantization.fp8 import
+from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
 from sglang.srt.layers.quantization.fp8_kernel import (
     is_fp8_fnuz,
     sglang_per_token_group_quant_fp8,
     sglang_per_token_quant_fp8,
 )
-from sglang.srt.layers.quantization.unquant import
+from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
 from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -47,23 +47,33 @@ from sglang.srt.utils import (
     get_bool_env_var,
     is_hip,
     is_npu,
+    next_power_of_2,
 )
 
 _is_hip = is_hip()
 _is_npu = is_npu()
 _is_fp8_fnuz = is_fp8_fnuz()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
+use_flashinfer_trtllm_moe = (
+    global_server_args_dict["enable_flashinfer_trtllm_moe"]
+    and global_server_args_dict["enable_ep_moe"]
+)
 
 if not (_is_npu or _is_hip):
     from sgl_kernel import silu_and_mul
 
-    from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
-
 if _use_aiter:
     from aiter import ActivationType, QuantType
     from aiter.fused_moe import fused_moe
     from aiter.ops.shuffle import shuffle_weight
 
+if use_flashinfer_trtllm_moe:
+    try:
+        import flashinfer.fused_moe as fi_fused_moe
+    except ImportError:
+        fi_fused_moe = None
+        use_flashinfer_trtllm_moe = False
+
 logger = logging.getLogger(__name__)
 
 
@@ -140,7 +150,17 @@ class GroupedGemmRunner(torch.nn.Module):
         return c
 
 
-class EPMoE(torch.nn.Module):
+def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
+    # Guess tokens per expert assuming perfect expert distribution first.
+    num_tokens_per_expert = (num_tokens * top_k) // num_experts
+    # And pad the number to the next power of 2.
+    tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
+    # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
+    tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
+    return tile_tokens_dim
+
+
+class EPMoE(FusedMoE):
     """
     MoE Expert Parallel Impl
 
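For orientation: `_get_tile_tokens_dim` sizes the per-CTA token tile from the expected per-expert load. Below is a standalone sketch of the same arithmetic for sanity-checking; `next_power_of_2` is re-implemented locally as a stand-in for the helper the diff imports from `sglang.srt.utils` (its exact behavior here is an assumption), and the batch/expert numbers are made up:

```python
def next_power_of_2(n: int) -> int:
    # Stand-in for sglang.srt.utils.next_power_of_2 (assumed behavior).
    return 1 if n <= 0 else 1 << (n - 1).bit_length()

def get_tile_tokens_dim(num_tokens: int, top_k: int, num_experts: int) -> int:
    # Same three steps as the diff: estimate per-expert load, round up, clamp to 8-64.
    num_tokens_per_expert = (num_tokens * top_k) // num_experts
    return min(max(next_power_of_2(num_tokens_per_expert), 8), 64)

# With top_k=8 and 256 experts:
assert get_tile_tokens_dim(64, 8, 256) == 8     # 2 tokens/expert -> clamped to the 8 floor
assert get_tile_tokens_dim(1024, 8, 256) == 32  # 32 tokens/expert -> already a power of 2
assert get_tile_tokens_dim(4096, 8, 256) == 64  # 128 tokens/expert -> capped at 64
```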
@@ -162,51 +182,60 @@ class EPMoE(torch.nn.Module):
         routed_scaling_factor: Optional[float] = None,
         use_per_token_if_dynamic: bool = True,
     ):
-        super().__init__()
+        super().__init__(
+            num_experts=num_experts,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            top_k=top_k,
+            layer_id=layer_id,
+            params_dtype=params_dtype,
+            quant_config=quant_config,
+            tp_size=tp_size,
+            prefix=prefix,
+            activation=activation,
+            routed_scaling_factor=routed_scaling_factor,
+            enable_ep_moe=True,
+            skip_quant=True,
+        )
 
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
 
-        self.tp_size = (
-            tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
-        )
-        self.tp_rank = get_tensor_model_parallel_rank()
-
         self.layer_id = layer_id
-        self.num_experts = num_experts
-
-        self.num_experts_per_partition, self.expert_map = self.determine_expert_map()
-        self.start_expert_id = self.tp_rank * self.num_experts_per_partition
-        self.end_expert_id = self.start_expert_id + self.num_experts_per_partition - 1
+        self.num_local_experts, self.expert_map = self.determine_expert_map()
+        self.start_expert_id = self.ep_rank * self.num_local_experts
+        self.end_expert_id = self.start_expert_id + self.num_local_experts - 1
 
-        self.top_k = top_k
         self.intermediate_size = intermediate_size
-        self.activation = activation
-        self.routed_scaling_factor = routed_scaling_factor
         self.use_per_token_if_dynamic = use_per_token_if_dynamic
 
+        # TODO(ch-wan): move quant preparation to FusedMoE
         if quant_config is None:
-            self.quant_method: Optional[QuantizeMethodBase] =
+            self.quant_method: Optional[QuantizeMethodBase] = (
+                UnquantizedFusedMoEMethod()
+            )
             self.use_fp8_w8a8 = False
             self.use_block_quant = False
             self.block_shape = None
             self.activation_scheme = None
-            self.
+            self.w13_input_scale = None
+            self.w2_input_scale = None
+            self.w13_weight_scale = None
+            self.w2_weight_scale = None
         elif isinstance(quant_config, W4AFp8Config):
            self.quant_method: Optional[QuantizeMethodBase] = W4AFp8MoEMethod(
                quant_config
            )
-            self.use_w4afp8 = True
             self.use_fp8_w8a8 = False
             self.use_block_quant = False
             self.fp8_dtype = torch.float8_e4m3fn
+            self.w13_input_scale = None
+            self.w2_input_scale = None
             self.w13_weight_scale = None
             self.w2_weight_scale = None
             self.activation_scheme = quant_config.moe_activation_scheme
-        else:
-            self.quant_method: Optional[QuantizeMethodBase] = Fp8MoEMethod(
-                quant_config
-            )
+        elif isinstance(quant_config, Fp8Config):
+            self.quant_method: Optional[QuantizeMethodBase] = Fp8MoEMethod(quant_config)
            self.use_fp8_w8a8 = True
            self.use_block_quant = getattr(self.quant_method, "block_quant", False)
            self.block_shape = (
@@ -216,11 +245,13 @@ class EPMoE(torch.nn.Module):
             )
             self.fp8_dtype = torch.float8_e4m3fn
             self.activation_scheme = quant_config.activation_scheme
-
+        else:
+            raise ValueError(f"Unsupported quant_config: {quant_config}")
 
+        self.quant_config = quant_config
         self.quant_method.create_weights(
             layer=self,
-            num_experts=self.num_experts_per_partition,
+            num_experts=self.num_local_experts,
             hidden_size=hidden_size,
             intermediate_size=self.intermediate_size,
             params_dtype=params_dtype,
@@ -229,19 +260,6 @@ class EPMoE(torch.nn.Module):
 
         self.grouped_gemm_runner = None
 
-        self.w13_weight_fp8 = (
-            self.w13_weight,
-            (
-                self.w13_weight_scale_inv
-                if self.use_block_quant
-                else self.w13_weight_scale
-            ),
-        )
-        self.w2_weight_fp8 = (
-            self.w2_weight,
-            self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
-        )
-
     # Adapted from https://github.com/vllm-project/vllm/blob/9fb52e523abf7bdaf7e60cf2971edb5a1b13dc08/vllm/model_executor/layers/fused_moe/layer.py#L544C1-L586C43
     # Modifications: use determine_expert_map as a class internal function, set 'global_num_experts' rather than '-1' for experts not assigned to the current rank.
     def determine_expert_map(self) -> Tuple[int, Optional[torch.Tensor]]:
@@ -260,8 +278,8 @@ class EPMoE(torch.nn.Module):
             Contains global_num_experts for experts not assigned to the current rank.
             Returns None if ep_size is 1.
         """
-        ep_size = self.tp_size
-        ep_rank = self.tp_rank
+        ep_size = self.ep_size
+        ep_rank = self.ep_rank
         global_num_experts = self.num_experts
 
         assert ep_size > 0
@@ -271,7 +289,7 @@ class EPMoE(torch.nn.Module):
         local_num_experts = global_num_experts // ep_size
 
         expert_map = torch.full(
-            (global_num_experts,),
+            (global_num_experts,), global_num_experts, dtype=torch.int32
         )
         if ep_rank < (ep_size - 1):
             expert_map[
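To make the new sentinel concrete, here is a standalone sketch of the map `determine_expert_map` builds for a non-last rank, using made-up sizes (the real method also special-cases the last rank via the `if ep_rank < (ep_size - 1):` branch, which this sketch ignores):

```python
import torch

global_num_experts, ep_size, ep_rank = 8, 4, 1
local_num_experts = global_num_experts // ep_size  # 2 experts on this rank

# Slots owned by other ranks hold global_num_experts (8) rather than -1;
# slots owned by this rank hold local expert indices 0..local_num_experts-1.
expert_map = torch.full((global_num_experts,), global_num_experts, dtype=torch.int32)
start = ep_rank * local_num_experts
expert_map[start : start + local_num_experts] = torch.arange(
    local_num_experts, dtype=torch.int32
)
print(expert_map.tolist())  # [8, 8, 0, 1, 8, 8, 8, 8]
```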
@@ -296,6 +314,20 @@ class EPMoE(torch.nn.Module):
         hidden_states: torch.Tensor,
         topk_output: TopKOutput,
     ):
+
+        self.w13_weight_fp8 = (
+            self.w13_weight,
+            (
+                self.w13_weight_scale_inv
+                if self.use_block_quant
+                else self.w13_weight_scale
+            ),
+        )
+        self.w2_weight_fp8 = (
+            self.w2_weight,
+            self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
+        )
+
         assert self.quant_method is not None
         assert self.activation == "silu"
         hidden_states_shape = hidden_states.shape
@@ -435,7 +467,10 @@ class EPMoE(torch.nn.Module):
         return output
 
     def forward_normal(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
-
+        return self.quant_method.apply(self, hidden_states, topk_output)
+
+    def run_moe(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
+
         topk_weights, topk_ids, _ = topk_output
 
         hidden_states_shape = hidden_states.shape
@@ -448,53 +483,11 @@ class EPMoE(torch.nn.Module):
             use_per_token_if_dynamic=self.use_per_token_if_dynamic,
         )
 
-        if self.use_w4afp8:
-            local_topk_ids = topk_ids
-            if self.expert_map is not None:
-                "Translate info from expert_map to topk_ids"
-                local_topk_ids = torch.where(
-                    self.expert_map[topk_ids] != self.num_experts,
-                    self.expert_map[topk_ids],
-                    self.num_experts,
-                )
-
-            output = cutlass_w4a8_moe(
-                self.start_expert_id,
-                self.end_expert_id,
-                self.num_experts,
-                hidden_states,
-                self.w13_weight,
-                self.w2_weight,
-                self.w13_weight_scale_inv,
-                self.w2_weight_scale_inv,
-                topk_weights,
-                topk_ids,
-                local_topk_ids,
-                self.quant_method.a_strides1,
-                self.quant_method.b_strides1,
-                self.quant_method.c_strides1,
-                self.quant_method.a_strides2,
-                self.quant_method.b_strides2,
-                self.quant_method.c_strides2,
-                self.quant_method.s_strides13,
-                self.quant_method.s_strides2,
-                self.quant_method.expert_offsets,
-                self.quant_method.problem_sizes1,
-                self.quant_method.problem_sizes2,
-                self.w13_input_scale,
-                self.w2_input_scale,
-            )
-            return output
-
-        if self.grouped_gemm_runner is None:
-            self.grouped_gemm_runner = GroupedGemmRunner(
-                hidden_states.device,
-                use_flashinfer=False,  # TODO: use flashinfer
-                use_per_token_if_dynamic=self.use_per_token_if_dynamic,
-            )
+        num_experts = self.num_experts
 
         reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
-            topk_ids, self.num_experts
+            topk_ids,
+            num_experts,
         )
 
         gateup_input = torch.empty(
@@ -502,7 +495,7 @@ class EPMoE(torch.nn.Module):
             device=hidden_states.device,
             dtype=(
                 self.fp8_dtype
-                if
+                if self.use_fp8_w8a8 and not self.use_block_quant
                 else hidden_states.dtype
             ),
         )
@@ -513,7 +506,7 @@ class EPMoE(torch.nn.Module):
         else:
             max_value = (
                 torch.max(hidden_states)
-                .repeat(self.num_experts_per_partition)
+                .repeat(self.num_local_experts)
                 .to(torch.float32)
             )
             self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
@@ -554,7 +547,7 @@ class EPMoE(torch.nn.Module):
         seg_indptr_cur_rank = seg_indptr[self.start_expert_id : self.end_expert_id + 2]
         weight_indices_cur_rank = torch.arange(
             0,
-            self.num_experts_per_partition,
+            self.num_local_experts,
             device=hidden_states_device,
             dtype=torch.int64,
         )
@@ -564,17 +557,13 @@ class EPMoE(torch.nn.Module):
             b=self.w13_weight,
             c=None,
             c_dtype=hidden_states_dtype,
-            batch_size=self.num_experts_per_partition,
+            batch_size=self.num_local_experts,
             weight_column_major=True,
             seg_indptr=seg_indptr_cur_rank,
             weight_indices=weight_indices_cur_rank,
             use_fp8_w8a8=self.use_fp8_w8a8,
             scale_a=self.w13_input_scale,
-            scale_b=(
-                self.w13_weight_scale_inv
-                if self.use_block_quant
-                else self.w13_weight_scale
-            ),
+            scale_b=self.w13_weight_scale,
             block_shape=self.block_shape,
         )
         del gateup_input
@@ -631,7 +620,7 @@ class EPMoE(torch.nn.Module):
             down_input, self.w2_input_scale = sglang_per_token_quant_fp8(down_input)
         else:
             self.w2_input_scale = torch.ones(
-                self.num_experts_per_partition,
+                self.num_local_experts,
                 dtype=torch.float32,
                 device=hidden_states_device,
             )
@@ -647,17 +636,13 @@ class EPMoE(torch.nn.Module):
             a=down_input,
             b=self.w2_weight,
             c=down_output,
-            batch_size=self.num_experts_per_partition,
+            batch_size=self.num_local_experts,
             weight_column_major=True,
             seg_indptr=seg_indptr_cur_rank,
             weight_indices=weight_indices_cur_rank,
             use_fp8_w8a8=self.use_fp8_w8a8,
             scale_a=self.w2_input_scale,
-            scale_b=(
-                self.w2_weight_scale_inv
-                if self.use_block_quant
-                else self.w2_weight_scale
-            ),
+            scale_b=self.w2_weight_scale,
             block_shape=self.block_shape,
         )
         del down_input
@@ -760,95 +745,14 @@ class EPMoE(torch.nn.Module):
             return
         expert_id = expert_id - self.start_expert_id
 
-
-
-
-
-
-
-
-            self._load_fp8_scale(
-                param.data,
-                loaded_weight,
-                weight_name,
-                shard_id,
-                expert_id,
-            )
-            return
-
-        if shard_id == "w2":
-            param.data[expert_id] = loaded_weight
-        elif shard_id == "w1":
-            param.data[expert_id][: self.intermediate_size, :] = loaded_weight
-        elif shard_id == "w3":
-            param.data[expert_id][self.intermediate_size :, :] = loaded_weight
-        else:
-            raise ValueError(f"Expected shard_id w1,w2 or w3 but got {shard_id}")
-
-    def _load_fp8_scale(
-        self,
-        param: torch.nn.Parameter,
-        loaded_weight: torch.Tensor,
-        weight_name: str,
-        shard_id: str,
-        expert_id: int,
-    ) -> None:
-        param_data = param.data
-
-        # Input scales can be loaded directly and should be equal.
-        if "input_scale" in weight_name:
-            if self.use_w4afp8:
-                if shard_id == "w1":
-                    param_data[expert_id][0] = loaded_weight
-                elif shard_id == "w3":
-                    param_data[expert_id][1] = loaded_weight
-                else:
-                    param_data[expert_id] = loaded_weight
-                return
-
-            if (
-                (shard_id == "w1" or shard_id == "w3")
-                and param_data[expert_id] != 1
-                and (param_data[expert_id] - loaded_weight).abs() > 1e-5
-            ):
-                raise ValueError(
-                    "input_scales of w1 and w3 of a layer "
-                    f"must be equal. But got {param_data[expert_id]} "
-                    f"vs. {loaded_weight}"
-                )
-            param_data[expert_id] = loaded_weight
-        # Weight scales
-        elif "weight_scale" in weight_name:
-            if self.use_block_quant:
-                block_n, block_k = self.block_shape[0], self.block_shape[1]
-                if shard_id == "w1":
-                    param_data[expert_id][
-                        : (self.intermediate_size + block_n - 1) // block_n, :
-                    ] = loaded_weight
-                elif shard_id == "w3":
-                    param_data[expert_id][
-                        (self.intermediate_size + block_n - 1) // block_n :, :
-                    ] = loaded_weight
-                else:  # w2
-                    param_data[expert_id] = loaded_weight
-            elif self.use_w4afp8:
-                if shard_id == "w1":
-                    param_data[expert_id][: self.intermediate_size, :] = loaded_weight
-                elif shard_id == "w3":
-                    param_data[expert_id][self.intermediate_size :, :] = loaded_weight
-                else:
-                    param_data[expert_id] = loaded_weight
-            # If we are in merged column case (gate_up_proj)
-            else:
-                if shard_id in ("w1", "w3"):
-                    # We have to keep the weight scales of w1 and w3 because
-                    # we need to re-quantize w1/w3 weights after weight loading.
-                    idx = 0 if shard_id == "w1" else 1
-                    param_data[expert_id][idx] = loaded_weight
-
-                # If we are in the row parallel case (down_proj)
-                else:
-                    param_data[expert_id] = loaded_weight
+        self._weight_loader_impl(
+            param=param,
+            loaded_weight=loaded_weight,
+            weight_name=weight_name,
+            shard_id=shard_id,
+            expert_id=expert_id,
+        )
+        return
 
 
 class DeepEPMoE(EPMoE):
@@ -898,13 +802,13 @@ class DeepEPMoE(EPMoE):
             deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
         ), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
         if _use_aiter:
-            # expert_mask is of size (self.num_experts_per_partition + 1),
+            # expert_mask is of size (self.num_local_experts + 1),
             # the extra 1 is for invalid rank_id (in original deepep, the invalid rank_id is -1, but aiter does not allow -1, we use a mask to make those ids invalid)
             # for instance, if we have 4 experts on this rank, we would have a expert_mask like:
             #     self.expert_mask = [1, 1, 1, 1, 0]
             # idx from 0-3 is valid and will be processed, while idx == 4 will be masked out
             self.expert_mask = torch.zeros(
-                (self.num_experts_per_partition + 1),
+                (self.num_local_experts + 1),
                 device=torch.cuda.current_device(),
                 dtype=torch.int,
             )
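The comment block in this hunk explains aiter's mask workaround for DeepEP's `-1` sentinel; here is a minimal standalone sketch of that remapping with illustrative values (not sglang code):

```python
import torch

num_local_experts = 4
# One extra mask slot: index num_local_experts is the "invalid" bucket.
expert_mask = torch.zeros(num_local_experts + 1, dtype=torch.int)
expert_mask[:num_local_experts] = 1               # -> [1, 1, 1, 1, 0]

topk_idx = torch.tensor([0, 3, -1], dtype=torch.int32)
topk_idx[topk_idx == -1] = num_local_experts      # aiter rejects -1, so remap to 4
print(expert_mask.tolist(), topk_idx.tolist())    # [1, 1, 1, 1, 0] [0, 3, 4]
```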
@@ -977,13 +881,13 @@ class DeepEPMoE(EPMoE):
         if self.activation_scheme == "dynamic" and not self.use_block_quant:
             max_value = (
                 torch.max(hidden_states)
-                .repeat(self.num_experts_per_partition)
+                .repeat(self.num_local_experts)
                 .to(torch.float32)
             )
             self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
         weight_indices_cur_rank = torch.arange(
             0,
-            self.num_experts_per_partition,
+            self.num_local_experts,
             device=hidden_states.device,
             dtype=torch.int64,
         )
@@ -995,7 +899,7 @@ class DeepEPMoE(EPMoE):
             b=self.w13_weight,
             c=None,
             c_dtype=hidden_states.dtype,
-            batch_size=self.num_experts_per_partition,
+            batch_size=self.num_local_experts,
             weight_column_major=True,
             seg_indptr=seg_indptr,
             weight_indices=weight_indices_cur_rank,
@@ -1029,7 +933,7 @@ class DeepEPMoE(EPMoE):
         )
         if self.w2_input_scale is None and not self.use_block_quant:
             self.w2_input_scale = torch.ones(
-                self.num_experts_per_partition,
+                self.num_local_experts,
                 dtype=torch.float32,
                 device=hidden_states_device,
             )
@@ -1042,7 +946,7 @@ class DeepEPMoE(EPMoE):
             reorder_topk_ids,
             self.w2_input_scale,
             0,
-            self.num_experts_per_partition - 1,
+            self.num_local_experts - 1,
             BLOCK_SIZE=512,
         )
     else:
@@ -1062,7 +966,7 @@ class DeepEPMoE(EPMoE):
             a=down_input,
             b=self.w2_weight,
             c=down_output,
-            batch_size=self.num_experts_per_partition,
+            batch_size=self.num_local_experts,
             weight_column_major=True,
             seg_indptr=seg_indptr,
             weight_indices=weight_indices_cur_rank,
@@ -1087,9 +991,9 @@ class DeepEPMoE(EPMoE):
             return hidden_states
         # in original deepep, idx == -1 meaning invalid and will not be processed.
        # aiter does not accept -1, we use a expert mask to make these idx invalid
-        # (idx == num_experts_per_partition) meaning not used in aiter fused_moe
+        # (idx == num_local_experts) meaning not used in aiter fused_moe
         topk_idx_copy = topk_idx.to(torch.int32)
-        topk_idx_copy[topk_idx_copy == -1] = self.num_experts_per_partition
+        topk_idx_copy[topk_idx_copy == -1] = self.num_local_experts
 
         return fused_moe(
             hidden_states,
@@ -1315,12 +1219,74 @@ class DeepEPMoE(EPMoE):
         return down_output
 
 
+class FlashInferEPMoE(EPMoE):
+    def __init__(self, *args, **kwargs):
+        renormalize = kwargs.pop("renormalize", True)
+        num_fused_shared_experts = kwargs.pop("num_fused_shared_experts", 0)
+        use_grouped_topk = kwargs.pop("use_grouped_topk", False)
+        num_expert_group = kwargs.pop("num_expert_group", None)
+        topk_group = kwargs.pop("topk_group", None)
+        correction_bias = kwargs.pop("correction_bias", None)
+        super().__init__(*args, **kwargs)
+        self.renormalize = renormalize
+        self.num_fused_shared_experts = num_fused_shared_experts
+        self.use_grouped_topk = use_grouped_topk
+        if self.use_grouped_topk:
+            assert num_expert_group is not None and topk_group is not None
+        self.num_expert_group = num_expert_group
+        self.topk_group = topk_group
+        self.correction_bias = correction_bias
+        self.use_flashinfer_trtllm_moe = use_flashinfer_trtllm_moe
+
+    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
+        assert use_flashinfer_trtllm_moe
+        assert (
+            self.activation == "silu"
+        ), "Only silu is supported for flashinfer blockscale fp8 moe"
+        assert (
+            self.renormalize
+        ), "Renormalize is required for flashinfer blockscale fp8 moe"
+        assert (
+            self.num_fused_shared_experts == 0
+        ), "Fused shared experts are not supported for flashinfer blockscale fp8 moe"
+        a_q, a_sf = sglang_per_token_group_quant_fp8(hidden_states, self.block_shape[1])
+        # NOTE: scales of hidden states have to be transposed!
+        a_sf_t = a_sf.t().contiguous()
+        assert fi_fused_moe is not None
+        return fi_fused_moe.trtllm_fp8_block_scale_moe(
+            routing_logits=router_logits.to(torch.float32),
+            routing_bias=self.correction_bias.to(hidden_states.dtype),
+            hidden_states=a_q,
+            hidden_states_scale=a_sf_t,
+            gemm1_weights=self.w13_weight,
+            gemm1_weights_scale=self.w13_weight_scale_inv,
+            gemm2_weights=self.w2_weight,
+            gemm2_weights_scale=self.w2_weight_scale_inv,
+            num_experts=self.num_experts,
+            top_k=self.top_k,
+            n_group=self.num_expert_group,
+            topk_group=self.topk_group,
+            intermediate_size=self.w2_weight.shape[2],
+            local_expert_offset=self.start_expert_id,
+            local_num_experts=self.num_experts_per_partition,
+            routed_scaling_factor=self.routed_scaling_factor,
+            tile_tokens_dim=_get_tile_tokens_dim(
+                hidden_states.shape[0], self.top_k, self.num_experts
+            ),
+            routing_method_type=2,  # DeepSeek-styled routing method
+            use_shuffled_weight=False,
+        )
+
+
 def get_moe_impl_class():
     if global_server_args_dict["enable_deepep_moe"]:
         return DeepEPMoE
-    if global_server_args_dict["enable_flashinfer_moe"]:
+    if global_server_args_dict["enable_flashinfer_cutlass_moe"]:
         # Must come before EPMoE because FusedMoE also supports enable_ep_moe
         return FusedMoE
+    if use_flashinfer_trtllm_moe:
+        # Must come before EPMoE because FusedMoE also supports enable_ep_moe
+        return FlashInferEPMoE
     if global_server_args_dict["enable_ep_moe"]:
         return EPMoE
     return FusedMoE
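Note that the order of checks in `get_moe_impl_class` is significant: DeepEP wins first, then the FlashInfer cutlass path, then the new TRT-LLM path, and only then plain EP, with `FusedMoE` as the default. A hypothetical paraphrase of that priority chain (`pick_impl` and its arguments are illustrative, not sglang's API):

```python
def pick_impl(flags: dict, trtllm_moe_usable: bool) -> str:
    if flags.get("enable_deepep_moe"):
        return "DeepEPMoE"
    if flags.get("enable_flashinfer_cutlass_moe"):
        # Must be checked before EPMoE, since FusedMoE also honors enable_ep_moe.
        return "FusedMoE"
    if trtllm_moe_usable:
        # Stands for: enable_flashinfer_trtllm_moe + enable_ep_moe + flashinfer importable.
        return "FlashInferEPMoE"
    if flags.get("enable_ep_moe"):
        return "EPMoE"
    return "FusedMoE"

assert pick_impl({"enable_ep_moe": True}, trtllm_moe_usable=True) == "FlashInferEPMoE"
assert pick_impl({"enable_ep_moe": True}, trtllm_moe_usable=False) == "EPMoE"
```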