sglang 0.4.7.post1__py3-none-any.whl → 0.4.8__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
- sglang/bench_one_batch.py +8 -6
- sglang/srt/_custom_ops.py +2 -2
- sglang/srt/code_completion_parser.py +2 -44
- sglang/srt/constants.py +3 -0
- sglang/srt/conversation.py +13 -3
- sglang/srt/custom_op.py +5 -1
- sglang/srt/disaggregation/decode.py +22 -28
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
- sglang/srt/disaggregation/mini_lb.py +34 -4
- sglang/srt/disaggregation/mooncake/conn.py +12 -16
- sglang/srt/disaggregation/prefill.py +17 -13
- sglang/srt/disaggregation/utils.py +46 -18
- sglang/srt/distributed/parallel_state.py +12 -4
- sglang/srt/entrypoints/engine.py +22 -28
- sglang/srt/entrypoints/http_server.py +149 -79
- sglang/srt/entrypoints/http_server_engine.py +0 -3
- sglang/srt/entrypoints/openai/__init__.py +0 -0
- sglang/srt/{openai_api → entrypoints/openai}/protocol.py +67 -29
- sglang/srt/entrypoints/openai/serving_base.py +149 -0
- sglang/srt/entrypoints/openai/serving_chat.py +921 -0
- sglang/srt/entrypoints/openai/serving_completions.py +424 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
- sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
- sglang/srt/entrypoints/openai/serving_score.py +61 -0
- sglang/srt/entrypoints/openai/usage_processor.py +81 -0
- sglang/srt/entrypoints/openai/utils.py +72 -0
- sglang/srt/function_call/base_format_detector.py +7 -4
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +64 -10
- sglang/srt/function_call/function_call_parser.py +6 -6
- sglang/srt/function_call/llama32_detector.py +1 -1
- sglang/srt/function_call/mistral_detector.py +1 -1
- sglang/srt/function_call/pythonic_detector.py +1 -1
- sglang/srt/function_call/qwen25_detector.py +1 -1
- sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
- sglang/srt/layers/activation.py +21 -3
- sglang/srt/layers/attention/aiter_backend.py +5 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +1 -0
- sglang/srt/layers/attention/flashattention_backend.py +19 -9
- sglang/srt/layers/attention/flashinfer_backend.py +9 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +7 -4
- sglang/srt/layers/attention/flashmla_backend.py +5 -2
- sglang/srt/layers/attention/tbo_backend.py +3 -3
- sglang/srt/layers/attention/triton_backend.py +19 -11
- sglang/srt/layers/communicator.py +5 -5
- sglang/srt/layers/dp_attention.py +11 -2
- sglang/srt/layers/layernorm.py +29 -2
- sglang/srt/layers/logits_processor.py +2 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +159 -2
- sglang/srt/layers/moe/ep_moe/layer.py +207 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +6 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
- sglang/srt/layers/moe/topk.py +91 -4
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
- sglang/srt/layers/quantization/fp8.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +62 -8
- sglang/srt/layers/quantization/utils.py +5 -2
- sglang/srt/layers/rotary_embedding.py +42 -2
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/lora_manager.py +173 -74
- sglang/srt/lora/mem_pool.py +49 -45
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +33 -15
- sglang/srt/managers/io_struct.py +9 -12
- sglang/srt/managers/schedule_batch.py +40 -31
- sglang/srt/managers/schedule_policy.py +70 -56
- sglang/srt/managers/scheduler.py +147 -62
- sglang/srt/managers/template_manager.py +226 -0
- sglang/srt/managers/tokenizer_manager.py +11 -8
- sglang/srt/managers/tp_worker.py +12 -2
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
- sglang/srt/mem_cache/base_prefix_cache.py +52 -8
- sglang/srt/mem_cache/chunk_cache.py +11 -16
- sglang/srt/mem_cache/hiradix_cache.py +34 -23
- sglang/srt/mem_cache/memory_pool.py +118 -114
- sglang/srt/mem_cache/radix_cache.py +20 -16
- sglang/srt/model_executor/cuda_graph_runner.py +76 -45
- sglang/srt/model_executor/forward_batch_info.py +18 -5
- sglang/srt/model_executor/model_runner.py +22 -6
- sglang/srt/model_loader/loader.py +8 -1
- sglang/srt/model_loader/weight_utils.py +11 -2
- sglang/srt/models/deepseek_nextn.py +29 -27
- sglang/srt/models/deepseek_v2.py +108 -26
- sglang/srt/models/glm4.py +312 -0
- sglang/srt/models/mimo_mtp.py +2 -18
- sglang/srt/reasoning_parser.py +21 -11
- sglang/srt/server_args.py +36 -8
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -10
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +125 -12
- sglang/srt/speculative/eagle_utils.py +80 -8
- sglang/srt/speculative/eagle_worker.py +124 -41
- sglang/srt/torch_memory_saver_adapter.py +19 -15
- sglang/srt/utils.py +177 -11
- sglang/test/test_block_fp8_ep.py +1 -0
- sglang/test/test_utils.py +1 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/METADATA +4 -10
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/RECORD +104 -93
- sglang/srt/entrypoints/verl_engine.py +0 -179
- sglang/srt/openai_api/adapter.py +0 -2148
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/ep_moe/layer.py

@@ -16,6 +16,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
     ep_scatter,
     gelu_and_mul_triton_kernel,
     grouped_gemm_triton,
+    moe_ep_deepgemm_preprocess,
     post_reorder_triton_kernel,
     pre_reorder_triton_kernel,
     run_moe_ep_preproess,
@@ -33,10 +34,12 @@ from sglang.srt.layers.quantization.base_config import (
 )
 from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
 from sglang.srt.layers.quantization.fp8_kernel import (
+    is_fp8_fnuz,
     scaled_fp8_quant,
     sglang_per_token_group_quant_fp8,
     sglang_per_token_quant_fp8,
 )
+from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
 from sglang.srt.managers.expert_location import get_global_expert_location_metadata
 from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -50,6 +53,7 @@ from sglang.srt.utils import (
 )
 
 _is_hip = is_hip()
+_is_fp8_fnuz = is_fp8_fnuz()
 
 if _is_hip:
     from vllm._custom_ops import scaled_fp8_quant
@@ -175,6 +179,7 @@ class EPMoE(torch.nn.Module):
         assert (
             num_fused_shared_experts == 0
         ), "num_fused_shared_experts is not supported in EP"
+        self.num_fused_shared_experts = num_fused_shared_experts
         self.num_experts_per_partition = self.num_experts // self.tp_size
         self.start_expert_id = self.tp_rank * self.num_experts_per_partition
         self.end_expert_id = self.start_expert_id + self.num_experts_per_partition - 1
@@ -224,13 +229,182 @@ class EPMoE(torch.nn.Module):
 
         self.grouped_gemm_runner = None
 
+        self.w13_weight_fp8 = (
+            self.w13_weight,
+            (
+                self.w13_weight_scale_inv
+                if self.use_block_quant
+                else self.w13_weight_scale
+            ),
+        )
+        self.w2_weight_fp8 = (
+            self.w2_weight,
+            self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
+        )
+
     def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
+        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
+            return self.forward_deepgemm(hidden_states, router_logits)
+        else:
+            return self.forward_normal(hidden_states, router_logits)
+
+    def forward_deepgemm(
+        self, hidden_states: torch.Tensor, router_logits: torch.Tensor
+    ):
+        assert self.quant_method is not None
+        assert self.activation == "silu"
         hidden_states_shape = hidden_states.shape
         hidden_states_dtype = hidden_states.dtype
         hidden_states_device = hidden_states.device
+        topk_weights, topk_ids = select_experts(
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+            top_k=self.top_k,
+            use_grouped_topk=self.use_grouped_topk,
+            renormalize=self.renormalize,
+            topk_group=self.topk_group,
+            num_expert_group=self.num_expert_group,
+            num_fused_shared_experts=self.num_fused_shared_experts,
+            correction_bias=self.correction_bias,
+            custom_routing_function=self.custom_routing_function,
+            routed_scaling_factor=self.routed_scaling_factor,
+        )
 
-
+        if not self.use_block_quant:
+            # Convert per-tensor quant to per-block quant by repeating scales for forward_deepgemm
+            scale_block_size = 128
+            w13_weight_scale_n = 2 * (
+                (self.intermediate_size + scale_block_size - 1) // scale_block_size
+            )
+            w13_weight_scale_k = (
+                hidden_states_shape[-1] + scale_block_size - 1
+            ) // scale_block_size
+            w13_weight_scale = (
+                self.w13_weight_scale.unsqueeze(1)
+                .repeat_interleave(w13_weight_scale_n, dim=1)
+                .unsqueeze(2)
+                .repeat_interleave(w13_weight_scale_k, dim=2)
+            )
+            self.w13_weight_fp8 = (
+                self.w13_weight,
+                w13_weight_scale,
+            )
+            w2_weight_scale_n = (
+                hidden_states_shape[-1] + scale_block_size - 1
+            ) // scale_block_size
+            w2_weight_scale_k = (
+                self.intermediate_size + scale_block_size - 1
+            ) // scale_block_size
+            w2_weight_scale = (
+                self.w2_weight_scale.unsqueeze(1)
+                .repeat_interleave(w2_weight_scale_n, dim=1)
+                .unsqueeze(2)
+                .repeat_interleave(w2_weight_scale_k, dim=2)
+            )
+            self.w2_weight_fp8 = (
+                self.w2_weight,
+                w2_weight_scale,
+            )
+
+        # PreReorder
+        m_max, masked_m, expected_m, src2dst, gateup_input, gateup_input_scale = (
+            moe_ep_deepgemm_preprocess(
+                topk_ids,
+                self.num_experts,
+                hidden_states,
+                self.top_k,
+                self.start_expert_id,
+                self.end_expert_id,
+                self.block_shape,
+            )
+        )
+
+        dispose_tensor(hidden_states)
+
+        # GroupGemm-0
+        gateup_input_fp8 = (
+            gateup_input,
+            deep_gemm_wrapper.get_col_major_tma_aligned_tensor(gateup_input_scale),
+        )
+        num_groups, m, k = gateup_input_fp8[0].size()
+        n = self.w13_weight.size(1)
+        gateup_output = torch.empty(
+            (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
+        )
+        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
+            gateup_input_fp8, self.w13_weight_fp8, gateup_output, masked_m, expected_m
+        )
+        del gateup_input
+        del gateup_input_fp8
+
+        # Act
+        down_input = torch.empty(
+            (
+                gateup_output.shape[0],
+                gateup_output.shape[1],
+                gateup_output.shape[2] // 2,
+            ),
+            device=hidden_states_device,
+            dtype=self.fp8_dtype,
+        )
+        scale_block_size = 128
+        down_input_scale = torch.empty(
+            (
+                gateup_output.shape[0],
+                gateup_output.shape[1],
+                gateup_output.shape[2] // 2 // scale_block_size,
+            ),
+            device=hidden_states_device,
+            dtype=torch.float32,
+        )
+        silu_and_mul_masked_post_quant_fwd(
+            gateup_output,
+            down_input,
+            down_input_scale,
+            scale_block_size,
+            masked_m,
+        )
+        del gateup_output
 
+        # GroupGemm-1
+        n = self.w2_weight.size(1)
+        down_input_fp8 = (
+            down_input,
+            deep_gemm_wrapper.get_col_major_tma_aligned_tensor(down_input_scale),
+        )
+        down_output = torch.empty(
+            (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
+        )
+        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
+            down_input_fp8, self.w2_weight_fp8, down_output, masked_m, expected_m
+        )
+        del down_input
+        del down_input_fp8
+
+        # PostReorder
+        output = torch.empty(
+            hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
+        )
+        post_reorder_triton_kernel[(hidden_states_shape[0],)](
+            down_output,
+            output,
+            src2dst,
+            topk_ids,
+            topk_weights,
+            self.start_expert_id,
+            self.end_expert_id,
+            self.top_k,
+            hidden_states_shape[1],
+            m_max * self.start_expert_id,
+            BLOCK_SIZE=512,
+        )
+        return output
+
+    def forward_normal(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
+        assert self.quant_method is not None
+        hidden_states_shape = hidden_states.shape
+        hidden_states_dtype = hidden_states.dtype
+        hidden_states_device = hidden_states.device
         if self.grouped_gemm_runner is None:
             self.grouped_gemm_runner = GroupedGemmRunner(
                 hidden_states.device,
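The per-tensor to per-block scale expansion in forward_deepgemm above is the least obvious step, so here is a standalone sketch of the same repeat_interleave pattern. The sizes and variable names are invented for illustration; only the 128 block size and the factor of 2 for the fused gate/up projection come from the diff.

```python
import torch

# Toy version of the scale expansion in forward_deepgemm: a per-tensor scale
# of shape (num_experts,) is tiled into a per-block grid of shape
# (num_experts, n_blocks, k_blocks) so the block-wise grouped GEMM can use it.
num_experts, intermediate_size, hidden_size, block = 2, 256, 384, 128

per_tensor_scale = torch.rand(num_experts)                 # one scale per expert
n_blocks = 2 * ((intermediate_size + block - 1) // block)  # w13 fuses gate and up
k_blocks = (hidden_size + block - 1) // block

per_block_scale = (
    per_tensor_scale.unsqueeze(1)
    .repeat_interleave(n_blocks, dim=1)
    .unsqueeze(2)
    .repeat_interleave(k_blocks, dim=2)
)
print(per_block_scale.shape)  # torch.Size([2, 4, 3])
```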
@@ -246,6 +420,7 @@ class EPMoE(torch.nn.Module):
             renormalize=self.renormalize,
             topk_group=self.topk_group,
             num_expert_group=self.num_expert_group,
+            num_fused_shared_experts=self.num_fused_shared_experts,
             correction_bias=self.correction_bias,
             custom_routing_function=self.custom_routing_function,
             routed_scaling_factor=self.routed_scaling_factor,
@@ -437,6 +612,7 @@ class EPMoE(torch.nn.Module):
             self.end_expert_id,
             self.top_k,
             hidden_states_shape[1],
+            0,
             BLOCK_SIZE=512,
         )
         return output
@@ -843,6 +1019,33 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
                 torch.max(layer.w13_weight_scale, dim=1).values,
                 requires_grad=False,
             )
+            if self.block_quant:
+                # If ROCm, normalize the weights and scales to e4m3fnuz
+                if _is_fp8_fnuz:
+                    # activation_scheme: dynamic
+                    w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                        weight=layer.w13_weight,
+                        weight_scale=layer.w13_weight_scale_inv,
+                        input_scale=None,
+                    )
+                    w2_weight, w2_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                        weight=layer.w2_weight,
+                        weight_scale=layer.w2_weight_scale_inv,
+                        input_scale=None,
+                    )
+                    # Reset the parameter
+                    layer.w13_weight = torch.nn.Parameter(
+                        w13_weight, requires_grad=False
+                    )
+                    layer.w13_weight_scale_inv = torch.nn.Parameter(
+                        w13_weight_scale, requires_grad=False
+                    )
+                    layer.w13_input_scale = None
+                    layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+                    layer.w2_weight_scale_inv = torch.nn.Parameter(
+                        w2_weight_scale, requires_grad=False
+                    )
+                    layer.w2_input_scale = None
             return
 
     def apply(
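For context on the normalize_e4m3fn_to_e4m3fnuz call added above: ROCm uses the e4m3fnuz encoding, whose exponent bias differs from e4m3fn by one, so the same bit pattern decodes to half the value and the weight scale has to be doubled to compensate. The snippet below only illustrates that bias difference; it is not the sglang helper and assumes a recent PyTorch build with float8 dtypes.

```python
import torch

# Illustration only (not normalize_e4m3fn_to_e4m3fnuz): reinterpreting an
# e4m3fn bit pattern as e4m3fnuz halves its value, which is why a converted
# checkpoint can keep its weight bits but must double the associated scale.
x = torch.tensor([1.0], dtype=torch.float8_e4m3fn)
y = x.view(torch.int8).view(torch.float8_e4m3fnuz)  # same bits, new encoding
print(x.float().item(), y.float().item())  # 1.0 0.5 -> compensate with scale * 2
```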
@@ -1265,6 +1468,9 @@ class DeepEPMoE(EPMoE):
 def get_moe_impl_class():
     if global_server_args_dict["enable_deepep_moe"]:
         return DeepEPMoE
+    if global_server_args_dict["enable_flashinfer_moe"]:
+        # Must come before EPMoE because FusedMoE also supports enable_ep_moe
+        return FusedMoE
     if global_server_args_dict["enable_ep_moe"]:
         return EPMoE
     return FusedMoE
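A quick sketch of the selection order get_moe_impl_class now implements, since the ordering is the point of the change: the FlashInfer check must run before the plain EP check so that a configuration with both enable_flashinfer_moe and enable_ep_moe still routes to FusedMoE. The dict-based stand-in below is illustrative only; the real function reads global_server_args_dict and returns classes.

```python
# Stand-in for get_moe_impl_class showing the precedence after this diff.
def pick_moe_impl(args: dict) -> str:
    if args.get("enable_deepep_moe"):
        return "DeepEPMoE"
    if args.get("enable_flashinfer_moe"):
        # checked before enable_ep_moe: FusedMoE also handles EP in this mode
        return "FusedMoE"
    if args.get("enable_ep_moe"):
        return "EPMoE"
    return "FusedMoE"

print(pick_moe_impl({"enable_flashinfer_moe": True, "enable_ep_moe": True}))  # FusedMoE
```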
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json

@@ -0,0 +1,146 @@
+{
+  "1": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 128,
+    "GROUP_SIZE_M": 1,
+    "num_warps": 4,
+    "num_stages": 4
+  },
+  "2": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 64,
+    "BLOCK_SIZE_K": 128,
+    "GROUP_SIZE_M": 1,
+    "num_warps": 4,
+    "num_stages": 5
+  },
+  "4": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 128,
+    "GROUP_SIZE_M": 1,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "8": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 128,
+    "GROUP_SIZE_M": 1,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "16": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 128,
+    "GROUP_SIZE_M": 32,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "24": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 64,
+    "BLOCK_SIZE_K": 128,
+    "GROUP_SIZE_M": 1,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "32": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 128,
+    "GROUP_SIZE_M": 32,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "48": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 128,
+    "GROUP_SIZE_M": 32,
+    "num_warps": 4,
+    "num_stages": 5
+  },
+  "64": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 128,
+    "GROUP_SIZE_M": 64,
+    "num_warps": 4,
+    "num_stages": 5
+  },
+  "96": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 128,
+    "GROUP_SIZE_M": 16,
+    "num_warps": 4,
+    "num_stages": 5
+  },
+  "128": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 128,
+    "GROUP_SIZE_M": 16,
+    "num_warps": 4,
+    "num_stages": 5
+  },
+  "256": {
+    "BLOCK_SIZE_M": 64,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 128,
+    "GROUP_SIZE_M": 16,
+    "num_warps": 4,
+    "num_stages": 4
+  },
+  "512": {
+    "BLOCK_SIZE_M": 64,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 128,
+    "GROUP_SIZE_M": 32,
+    "num_warps": 4,
+    "num_stages": 4
+  },
+  "1024": {
+    "BLOCK_SIZE_M": 64,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 128,
+    "GROUP_SIZE_M": 16,
+    "num_warps": 4,
+    "num_stages": 4
+  },
+  "1536": {
+    "BLOCK_SIZE_M": 64,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 128,
+    "GROUP_SIZE_M": 16,
+    "num_warps": 4,
+    "num_stages": 4
+  },
+  "2048": {
+    "BLOCK_SIZE_M": 64,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 128,
+    "GROUP_SIZE_M": 16,
+    "num_warps": 4,
+    "num_stages": 4
+  },
+  "3072": {
+    "BLOCK_SIZE_M": 64,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 128,
+    "GROUP_SIZE_M": 16,
+    "num_warps": 4,
+    "num_stages": 4
+  },
+  "4096": {
+    "BLOCK_SIZE_M": 64,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 128,
+    "GROUP_SIZE_M": 16,
+    "num_warps": 4,
+    "num_stages": 4
+  }
+}
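The keys in this new tuning file are token counts (M); the other dimensions (E=128, N=384, H100, fp8_w8a8, block_shape=[128, 128]) are encoded in the filename. Presumably the fused Triton MoE path loads the file matching those dimensions and picks the entry whose key is nearest to the current batch size, as the existing config machinery does; the lookup below is a hedged sketch of that convention, not the sglang function.

```python
import json

def pick_moe_config(path: str, num_tokens: int) -> dict:
    # Load the tuning table and return the kernel configuration whose
    # integer key is closest to the current number of tokens.
    with open(path) as f:
        configs = {int(k): v for k, v in json.load(f).items()}
    return configs[min(configs, key=lambda m: abs(m - num_tokens))]

# e.g. 200 tokens would fall back to the "256" entry of the file above
```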
sglang/srt/layers/moe/fused_moe_triton/fused_moe.py

@@ -25,9 +25,11 @@ from sglang.srt.layers.quantization.int8_kernel import (
     sglang_per_token_group_quant_int8,
 )
 from sglang.srt.utils import (
+    cpu_has_amx_support,
     direct_register_custom_op,
     get_bool_env_var,
     get_device_name,
+    is_cpu,
     is_cuda,
     is_hip,
     log_info_on_rank0,
@@ -36,9 +38,13 @@ from sglang.srt.utils import (
 
 _is_hip = is_hip()
 _is_cuda = is_cuda()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
 
 if _is_cuda:
     from sgl_kernel import gelu_and_mul, silu_and_mul
+elif _is_cpu and _is_cpu_amx_available:
+    pass
 else:
     from vllm import _custom_ops as vllm_ops
     from vllm._custom_ops import scaled_fp8_quant
sglang/srt/layers/moe/fused_moe_triton/layer.py

@@ -32,6 +32,7 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 
 if _use_aiter:
     from aiter import ActivationType
+    from aiter.fused_moe import fused_moe
     from aiter.fused_moe_bf16_asm import ck_moe_2stages
     from aiter.ops.shuffle import shuffle_weight
 
@@ -204,7 +205,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             topk_weights, dtype=torch.float32
         )  # topk_weights must be FP32 (float32)
 
-        return
+        return fused_moe(
             x,
             layer.w13_weight,
             layer.w2_weight,
@@ -241,7 +242,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         num_fused_shared_experts: int = 0,
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
         inplace: bool = True,
+        no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         return moe_forward_native(
             layer,
@@ -260,7 +265,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
         raise NotImplementedError("The TPU backend currently does not support MoE.")
 
-    forward_native =
+    forward_native = forward_cpu
 
 
 class FusedMoE(torch.nn.Module):
@@ -310,6 +315,8 @@ class FusedMoE(torch.nn.Module):
         inplace: bool = True,
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
+        enable_flashinfer_moe: Optional[bool] = False,
+        enable_ep_moe: Optional[bool] = False,
     ):
         super().__init__()
 
|
|
320
327
|
self.tp_size = (
|
321
328
|
tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
|
322
329
|
)
|
330
|
+
self.tp_rank = get_tensor_model_parallel_rank()
|
331
|
+
self.num_experts = num_experts
|
332
|
+
self.expert_map = None
|
333
|
+
|
334
|
+
if enable_flashinfer_moe and quant_config is None:
|
335
|
+
logger.warning("Disable flashinfer MoE when quantization config is None.")
|
336
|
+
enable_flashinfer_moe = False
|
337
|
+
enable_ep_moe = False
|
338
|
+
|
339
|
+
self.enable_flashinfer_moe = enable_flashinfer_moe
|
340
|
+
if enable_ep_moe:
|
341
|
+
assert (
|
342
|
+
self.enable_flashinfer_moe
|
343
|
+
), "FusedMoE only supports EP with --enable-flashinfer-moe"
|
344
|
+
self.ep_size = self.tp_size
|
345
|
+
self.ep_rank = self.tp_rank
|
346
|
+
self.tp_size = 1
|
347
|
+
self.tp_rank = 0
|
348
|
+
# Create a tensor of size num_experts filled with -1
|
349
|
+
self.expert_map = torch.full((self.num_experts,), -1, dtype=torch.int32)
|
350
|
+
# Create a expert map for the local experts
|
351
|
+
assert num_experts % self.ep_size == 0
|
352
|
+
self.local_num_experts = num_experts // self.ep_size
|
353
|
+
self.expert_map[
|
354
|
+
self.ep_rank
|
355
|
+
* self.local_num_experts : (self.ep_rank + 1)
|
356
|
+
* self.local_num_experts
|
357
|
+
] = torch.arange(0, self.local_num_experts, dtype=torch.int32, device="cpu")
|
358
|
+
else:
|
359
|
+
self.ep_size = 1
|
360
|
+
self.ep_rank = 0
|
361
|
+
self.local_num_experts = num_experts
|
323
362
|
self.routed_scaling_factor = routed_scaling_factor
|
324
363
|
self.top_k = top_k
|
325
|
-
self.num_experts = num_experts
|
326
364
|
assert intermediate_size % self.tp_size == 0
|
327
365
|
self.intermediate_size_per_partition = intermediate_size // self.tp_size
|
328
366
|
self.reduce_results = reduce_results
|
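The expert_map built above is what later lets weight_loader skip experts that live on other EP ranks. A small worked example of the same construction, with invented sizes:

```python
import torch

# With 8 experts over ep_size=4, rank 1 owns global experts 2-3; they map to
# local ids 0-1 and every other global expert maps to -1 (skipped on load).
num_experts, ep_size, ep_rank = 8, 4, 1
local_num_experts = num_experts // ep_size

expert_map = torch.full((num_experts,), -1, dtype=torch.int32)
expert_map[ep_rank * local_num_experts : (ep_rank + 1) * local_num_experts] = torch.arange(
    local_num_experts, dtype=torch.int32
)
print(expert_map.tolist())  # [-1, -1, 0, 1, -1, -1, -1, -1]
```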
@@ -340,7 +378,6 @@ class FusedMoE(torch.nn.Module):
         self.use_presharded_weights = use_presharded_weights
         self.inplace = inplace
         self.no_combine = no_combine
-        self.local_num_experts = num_experts
 
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = (
@@ -348,11 +385,13 @@ class FusedMoE(torch.nn.Module):
             )
         else:
             self.quant_method = quant_config.get_quant_method(self, prefix)
+            if self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod":
+                self.quant_method.enable_flashinfer_moe = self.enable_flashinfer_moe
         assert self.quant_method is not None
 
         self.quant_method.create_weights(
             layer=self,
-            num_experts=
+            num_experts=self.local_num_experts,
             hidden_size=hidden_size,
             # FIXME: figure out which intermediate_size to use
             intermediate_size=self.intermediate_size_per_partition,
@@ -446,12 +485,15 @@ class FusedMoE(torch.nn.Module):
 
         # Narrow parameter and load.
         # w1, gate_proj: Load into first logical weight of w13.
-        if shard_id == "w1":
-            expert_data = expert_data.narrow(shard_dim, 0, shard_size)
         # w3, up_proj: Load into second logical weight of w13.
+        # trtllm cutlass kernel assumes differently
+        assert shard_id in ("w1", "w3")
+        switch_w13 = getattr(self.quant_method, "load_up_proj_weight_first", False)
+        if (switch_w13 and shard_id == "w1") or (not switch_w13 and shard_id == "w3"):
+            start = shard_size
         else:
-
-
+            start = 0
+        expert_data = expert_data.narrow(shard_dim, start, shard_size)
         expert_data.copy_(loaded_weight)
 
     def _load_w2(
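The rewritten w13 loader above reduces to choosing an offset into the fused gate/up weight and optionally swapping the halves when the quant method sets load_up_proj_weight_first (the comment points at the TensorRT-LLM cutlass kernel layout). A toy version of that logic, with invented shapes:

```python
import torch

# Toy version of the w13 narrow-and-copy logic: "w1" (gate_proj) and "w3"
# (up_proj) fill the two halves of w13; switch_w13 swaps which half gets which.
shard_size, hidden = 4, 3
w13 = torch.zeros(2 * shard_size, hidden)

def load_w13(shard_id: str, weight: torch.Tensor, switch_w13: bool = False) -> None:
    assert shard_id in ("w1", "w3")
    second_half = (switch_w13 and shard_id == "w1") or (not switch_w13 and shard_id == "w3")
    start = shard_size if second_half else 0
    w13.narrow(0, start, shard_size).copy_(weight)

load_w13("w1", torch.ones(shard_size, hidden))          # rows 0-3 by default
load_w13("w3", torch.full((shard_size, hidden), 2.0))   # rows 4-7 by default
```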
@@ -505,6 +547,11 @@ class FusedMoE(torch.nn.Module):
         assert shard_id in ("w1", "w3")
         expert_data.copy_(loaded_weight)
 
+    def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int:
+        if self.expert_map is None:
+            return expert_id
+        return self.expert_map[expert_id].item()
+
     def weight_loader(
         self,
         param: torch.nn.Parameter,
@@ -513,6 +560,13 @@ class FusedMoE(torch.nn.Module):
         shard_id: str,
         expert_id: int,
     ) -> None:
+        expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
+        if expert_id == -1:
+            return
+
+        # TP rank is set to 0 if EP is enabled
+        tp_rank = 0 if self.ep_size > 1 else get_tensor_model_parallel_rank()
+
         # compressed-tensors checkpoints with packed weights are stored flipped
         # TODO (mgoin): check self.quant_method.quant_config.quant_format
         # against known CompressionFormat enum values that have this quality
@@ -537,7 +591,6 @@ class FusedMoE(torch.nn.Module):
         SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}
 
         expert_data = param.data[expert_id]
-        tp_rank = get_tensor_model_parallel_rank()
 
         # is_transposed: if the dim to shard the weight
         # should be flipped. Required by GPTQ, compressed-tensors
@@ -545,7 +598,7 @@ class FusedMoE(torch.nn.Module):
         is_transposed = getattr(param, "is_transposed", False)
         shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
         if is_transposed:
-            shard_dim =
+            shard_dim = int(not shard_dim)
 
         # Case input scale: input_scale loading is only supported for fp8
         if "input_scale" in weight_name:
@@ -686,9 +739,19 @@ class FusedMoE(torch.nn.Module):
             activation=self.activation,
             apply_router_weight_on_input=self.apply_router_weight_on_input,
             routed_scaling_factor=self.routed_scaling_factor,
+            **(
+                dict(
+                    tp_rank=self.tp_rank,
+                    tp_size=self.tp_size,
+                    ep_rank=self.ep_rank,
+                    ep_size=self.ep_size,
+                )
+                if self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod"
+                else {}
+            ),
         )
 
-        if self.reduce_results and self.tp_size > 1:
+        if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
 
         return final_hidden_states