sglang 0.4.9.post4__py3-none-any.whl → 0.4.9.post6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/lang/chat_template.py +21 -0
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/model_config.py +7 -0
- sglang/srt/constrained/base_grammar_backend.py +10 -2
- sglang/srt/constrained/xgrammar_backend.py +7 -5
- sglang/srt/conversation.py +16 -1
- sglang/srt/debug_utils/__init__.py +0 -0
- sglang/srt/debug_utils/dump_comparator.py +131 -0
- sglang/srt/debug_utils/dumper.py +108 -0
- sglang/srt/debug_utils/text_comparator.py +172 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +13 -1
- sglang/srt/disaggregation/mooncake/conn.py +16 -0
- sglang/srt/disaggregation/prefill.py +13 -1
- sglang/srt/entrypoints/engine.py +4 -2
- sglang/srt/entrypoints/http_server.py +13 -1
- sglang/srt/entrypoints/openai/protocol.py +3 -1
- sglang/srt/entrypoints/openai/serving_base.py +5 -2
- sglang/srt/entrypoints/openai/serving_chat.py +132 -79
- sglang/srt/function_call/ebnf_composer.py +10 -3
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +164 -0
- sglang/srt/function_call/qwen3_coder_detector.py +1 -0
- sglang/srt/layers/attention/hybrid_attn_backend.py +100 -0
- sglang/srt/layers/attention/vision.py +56 -8
- sglang/srt/layers/layernorm.py +26 -1
- sglang/srt/layers/logits_processor.py +14 -3
- sglang/srt/layers/moe/ep_moe/layer.py +323 -242
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +83 -118
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=160,N=320,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -48
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +11 -8
- sglang/srt/layers/moe/token_dispatcher/__init__.py +0 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +48 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +19 -0
- sglang/srt/layers/moe/topk.py +90 -24
- sglang/srt/layers/multimodal.py +11 -8
- sglang/srt/layers/quantization/fp8.py +25 -247
- sglang/srt/layers/quantization/fp8_kernel.py +78 -48
- sglang/srt/layers/quantization/modelopt_quant.py +27 -10
- sglang/srt/layers/quantization/unquant.py +24 -76
- sglang/srt/layers/quantization/w4afp8.py +68 -17
- sglang/srt/lora/lora_registry.py +93 -29
- sglang/srt/managers/cache_controller.py +9 -7
- sglang/srt/managers/data_parallel_controller.py +4 -0
- sglang/srt/managers/io_struct.py +12 -0
- sglang/srt/managers/mm_utils.py +154 -35
- sglang/srt/managers/multimodal_processor.py +3 -14
- sglang/srt/managers/schedule_batch.py +14 -8
- sglang/srt/managers/scheduler.py +64 -1
- sglang/srt/managers/scheduler_input_blocker.py +106 -0
- sglang/srt/managers/tokenizer_manager.py +80 -15
- sglang/srt/managers/tp_worker.py +8 -0
- sglang/srt/mem_cache/hiradix_cache.py +5 -2
- sglang/srt/model_executor/model_runner.py +83 -27
- sglang/srt/models/deepseek_v2.py +75 -84
- sglang/srt/models/glm4_moe.py +1035 -0
- sglang/srt/models/glm4_moe_nextn.py +167 -0
- sglang/srt/models/interns1.py +328 -0
- sglang/srt/models/internvl.py +143 -47
- sglang/srt/models/llava.py +9 -5
- sglang/srt/models/minicpmo.py +4 -1
- sglang/srt/models/qwen2_moe.py +2 -2
- sglang/srt/models/qwen3_moe.py +17 -71
- sglang/srt/multimodal/processors/base_processor.py +20 -6
- sglang/srt/multimodal/processors/clip.py +2 -2
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +2 -2
- sglang/srt/multimodal/processors/gemma3.py +2 -2
- sglang/srt/multimodal/processors/gemma3n.py +2 -2
- sglang/srt/multimodal/processors/internvl.py +21 -8
- sglang/srt/multimodal/processors/janus_pro.py +2 -2
- sglang/srt/multimodal/processors/kimi_vl.py +2 -2
- sglang/srt/multimodal/processors/llava.py +4 -4
- sglang/srt/multimodal/processors/minicpm.py +2 -3
- sglang/srt/multimodal/processors/mlama.py +2 -2
- sglang/srt/multimodal/processors/mllama4.py +18 -111
- sglang/srt/multimodal/processors/phi4mm.py +2 -2
- sglang/srt/multimodal/processors/pixtral.py +2 -2
- sglang/srt/multimodal/processors/qwen_audio.py +2 -2
- sglang/srt/multimodal/processors/qwen_vl.py +2 -2
- sglang/srt/multimodal/processors/vila.py +3 -1
- sglang/srt/poll_based_barrier.py +31 -0
- sglang/srt/reasoning_parser.py +2 -1
- sglang/srt/server_args.py +65 -6
- sglang/srt/two_batch_overlap.py +8 -3
- sglang/srt/utils.py +96 -1
- sglang/srt/weight_sync/utils.py +119 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_utils.py +118 -5
- sglang/utils.py +19 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/METADATA +5 -4
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/RECORD +97 -80
- sglang/srt/debug_utils.py +0 -74
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,146 @@
+{
+  "1": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 128,
+    "GROUP_SIZE_M": 1,
+    "num_warps": 4,
+    "num_stages": 4
+  },
+  "2": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 128,
+    "GROUP_SIZE_M": 64,
+    "num_warps": 4,
+    "num_stages": 4
+  },
+  "4": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 128,
+    "GROUP_SIZE_M": 1,
+    "num_warps": 4,
+    "num_stages": 4
+  },
+  "8": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 128,
+    "GROUP_SIZE_M": 1,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "16": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 128,
+    "GROUP_SIZE_M": 1,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "24": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 256,
+    "GROUP_SIZE_M": 1,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "32": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 256,
+    "GROUP_SIZE_M": 1,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "48": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 256,
+    "GROUP_SIZE_M": 1,
+    "num_warps": 4,
+    "num_stages": 2
+  },
+  "64": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 256,
+    "GROUP_SIZE_M": 1,
+    "num_warps": 4,
+    "num_stages": 2
+  },
+  "96": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 256,
+    "GROUP_SIZE_M": 16,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "128": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 256,
+    "GROUP_SIZE_M": 16,
+    "num_warps": 4,
+    "num_stages": 2
+  },
+  "256": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 128,
+    "GROUP_SIZE_M": 32,
+    "num_warps": 4,
+    "num_stages": 4
+  },
+  "512": {
+    "BLOCK_SIZE_M": 64,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 128,
+    "GROUP_SIZE_M": 16,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "1024": {
+    "BLOCK_SIZE_M": 64,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 128,
+    "GROUP_SIZE_M": 16,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "1536": {
+    "BLOCK_SIZE_M": 128,
+    "BLOCK_SIZE_N": 256,
+    "BLOCK_SIZE_K": 128,
+    "GROUP_SIZE_M": 16,
+    "num_warps": 8,
+    "num_stages": 4
+  },
+  "2048": {
+    "BLOCK_SIZE_M": 128,
+    "BLOCK_SIZE_N": 256,
+    "BLOCK_SIZE_K": 128,
+    "GROUP_SIZE_M": 16,
+    "num_warps": 8,
+    "num_stages": 4
+  },
+  "3072": {
+    "BLOCK_SIZE_M": 64,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 128,
+    "GROUP_SIZE_M": 1,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "4096": {
+    "BLOCK_SIZE_M": 128,
+    "BLOCK_SIZE_N": 256,
+    "BLOCK_SIZE_K": 128,
+    "GROUP_SIZE_M": 1,
+    "num_warps": 8,
+    "num_stages": 4
+  }
+}
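These JSON files are Triton tuning tables for the fused MoE kernels: each top-level key is a token-count bucket and each value is a set of launch parameters (tile sizes, group size, warps, pipeline stages). As a rough illustration of how such a table is consumed, the sketch below picks the entry whose bucket is closest to the current token count; the helper name is hypothetical, and sglang's actual selection logic in fused_moe_triton may round differently.

import json


def pick_moe_kernel_config(config_path: str, num_tokens: int) -> dict:
    """Hypothetical helper: return the tuning entry whose token-count bucket
    is closest to num_tokens. Illustrative only; not the sglang API."""
    with open(config_path) as f:
        table = json.load(f)  # {"1": {...}, "2": {...}, ..., "4096": {...}}
    best_bucket = min(table, key=lambda k: abs(int(k) - num_tokens))
    return table[best_bucket]


# e.g. a 300-token batch would map to the "256" entry of the config above.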
@@ -75,8 +75,9 @@ class FusedMoE(torch.nn.Module):
         inplace: bool = True,
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
-
+        enable_flashinfer_cutlass_moe: Optional[bool] = False,
         enable_ep_moe: Optional[bool] = False,
+        skip_quant: Optional[bool] = False,
     ):
         super().__init__()

@@ -92,16 +93,13 @@ class FusedMoE(torch.nn.Module):
         self.num_experts = num_experts
         self.expert_map = None

-        if
+        if enable_flashinfer_cutlass_moe and quant_config is None:
             logger.warning("Disable flashinfer MoE when quantization config is None.")
-
+            enable_flashinfer_cutlass_moe = False
             enable_ep_moe = False

-        self.
+        self.enable_flashinfer_cutlass_moe = enable_flashinfer_cutlass_moe
         if enable_ep_moe:
-            assert (
-                self.enable_flashinfer_moe
-            ), "FusedMoE only supports EP with --enable-flashinfer-moe"
             self.ep_size = self.tp_size
             self.ep_rank = self.tp_rank
             self.tp_size = 1
@@ -110,16 +108,16 @@ class FusedMoE(torch.nn.Module):
             self.expert_map = torch.full((self.num_experts,), -1, dtype=torch.int32)
             # Create a expert map for the local experts
             assert num_experts % self.ep_size == 0
-            self.
+            self.num_local_experts = num_experts // self.ep_size
             self.expert_map[
                 self.ep_rank
-                * self.
-                * self.
-            ] = torch.arange(0, self.
+                * self.num_local_experts : (self.ep_rank + 1)
+                * self.num_local_experts
+            ] = torch.arange(0, self.num_local_experts, dtype=torch.int32, device="cpu")
         else:
             self.ep_size = 1
             self.ep_rank = 0
-            self.
+            self.num_local_experts = num_experts
         self.routed_scaling_factor = routed_scaling_factor
         assert intermediate_size % self.tp_size == 0
         self.intermediate_size_per_partition = intermediate_size // self.tp_size
@@ -134,6 +132,9 @@ class FusedMoE(torch.nn.Module):
             not _is_cpu and global_server_args_dict["enable_triton_kernel_moe"]
         )

+        if skip_quant:
+            return
+
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod(
                 self.use_triton_kernels
@@ -141,13 +142,15 @@ class FusedMoE(torch.nn.Module):
         else:
             self.quant_method = quant_config.get_quant_method(self, prefix)
             if self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod":
-                self.quant_method.
+                self.quant_method.enable_flashinfer_cutlass_moe = (
+                    self.enable_flashinfer_cutlass_moe
+                )
         assert self.quant_method is not None

         self.quant_config = quant_config
         self.quant_method.create_weights(
             layer=self,
-            num_experts=self.
+            num_experts=self.num_local_experts,
             hidden_size=hidden_size,
             # FIXME: figure out which intermediate_size to use
             intermediate_size=self.intermediate_size_per_partition,
@@ -376,6 +379,23 @@ class FusedMoE(torch.nn.Module):
         if expert_id == -1:
             return

+        self._weight_loader_impl(
+            param=param,
+            loaded_weight=loaded_weight,
+            weight_name=weight_name,
+            shard_id=shard_id,
+            expert_id=expert_id,
+        )
+
+    def _weight_loader_impl(
+        self,
+        param: torch.nn.Parameter,
+        loaded_weight: torch.Tensor,
+        weight_name: str,
+        shard_id: str,
+        expert_id: int,
+    ) -> None:
+
         # TP rank is set to 0 if EP is enabled
         tp_rank = 0 if self.ep_size > 1 else get_tensor_model_parallel_rank()

@@ -396,6 +416,10 @@ class FusedMoE(torch.nn.Module):
                 f"shard_id must be ['w1','w2','w3'] but " f"got {shard_id}."
             )

+        # Flashinfer assumes w31 format for w13_weight. Same for the scales.
+        if getattr(self, "use_flashinfer_trtllm_moe", False):
+            shard_id = {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id]
+
         WEIGHT_SCALE_SUPPORTED = [e.value for e in FusedMoeWeightScaleSupported]
         # Fetch the dim to shard the parameter/loaded weight
         # based on the shard id. This will be whatever
@@ -603,37 +627,3 @@ class FusedMoE(torch.nn.Module):
                 ("w3", ckpt_up_proj_name),
             ]
         ]
-
-    def _load_fp8_scale(
-        self,
-        param: torch.nn.Parameter,
-        loaded_weight: torch.Tensor,
-        weight_name: str,
-        shard_id: str,
-        expert_id: int,
-    ) -> None:
-        param_data = param.data
-
-        # Input scales can be loaded directly and should be equal.
-        if "input_scale" in weight_name:
-            if (
-                param_data[expert_id] != 1
-                and (param_data[expert_id] - loaded_weight).abs() > 1e-5
-            ):
-                raise ValueError(
-                    "input_scales of w1 and w3 of a layer "
-                    f"must be equal. But got {param_data[expert_id]} "
-                    f"vs. {loaded_weight}"
-                )
-            param_data[expert_id] = loaded_weight
-        # Weight scales
-        elif "weight_scale" in weight_name:
-            # If we are in merged column case (gate_up_proj)
-            if shard_id in ("w1", "w3"):
-                # We have to keep the weight scales of w1 and w3 because
-                # we need to re-quantize w1/w3 weights after weight loading.
-                idx = 0 if shard_id == "w1" else 1
-                param_data[expert_id][idx] = loaded_weight
-            # If we are in the row parallel case (down_proj)
-            else:
-                param_data[expert_id] = loaded_weight
@@ -1,21 +1,25 @@
 # Adapted from https://github.com/vllm-project/vllm/pull/18595/files#diff-f426a6de78c82ffec568eff6811bfbf0043dab5f87f1a8c0cffdbdcb8a81e035
-
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Optional

 import torch
 from sgl_kernel import gelu_and_mul, silu_and_mul
 from triton_kernels.matmul_ogs import matmul_ogs
-from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx
+from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx

 from sglang.srt.utils import direct_register_custom_op

+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.topk import TopKOutput
+

 def triton_kernel_moe_forward(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
     w2: torch.Tensor,
-
-    topk: int,
-    renormalize: bool,
+    topk_output: TopKOutput,
     inplace: bool = False,
     activation: str = "silu",
     apply_router_weight_on_input: bool = False,
@@ -30,9 +34,8 @@ def triton_kernel_moe_forward(
     block_shape: Optional[list[int]] = None,
 ) -> torch.Tensor:

-
-
-    routing_data, gather_idx, scatter_idx = routing(gating_output, topk, renormalize)
+    assert topk_output.format.is_triton_kernel()
+    routing_data, gather_idx, scatter_idx = topk_output

     return triton_kernel_fused_experts(
         hidden_states,
File without changes
@@ -0,0 +1,48 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from enum import Enum, auto
+from typing import TYPE_CHECKING, NamedTuple, Protocol, runtime_checkable
+
+import torch
+
+
+class DispatchOutputFormat(Enum):
+    standard = auto()
+    deepep_normal = auto()
+    deepep_ll = auto()
+
+    def is_standard(self) -> bool:
+        return self == DispatchOutputFormat.standard
+
+    def is_deepep_normal(self) -> bool:
+        return self == DispatchOutputFormat.deepep_normal
+
+    def is_deepep_ll(self) -> bool:
+        return self == DispatchOutputFormat.deepep_ll
+
+
+@runtime_checkable
+class DispatchOutput(Protocol):
+    """Protocol for dispatch outputs in different formats."""
+
+    @property
+    def format(self) -> DispatchOutputFormat: ...
+
+
+class BaseDispatcherConfig(ABC):
+    """Base class for dispatcher configs."""
+
+    pass
+
+
+class BaseDispatcher(ABC):
+    """Base class for dispatchers."""
+
+    @abstractmethod
+    def dispatch(self, *args, **kwargs) -> DispatchOutput:
+        pass
+
+    @abstractmethod
+    def combine(self, *args, **kwargs) -> torch.Tensor:
+        pass
@@ -0,0 +1,19 @@
+from __future__ import annotations
+
+from typing import NamedTuple
+
+from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
+    DispatchOutput,
+    DispatchOutputFormat,
+)
+
+
+class StandardDispatchOutput(NamedTuple):
+    """Standard dispatch output."""
+
+    @property
+    def format(self) -> DispatchOutputFormat:
+        return DispatchOutputFormat.standard
+
+
+assert isinstance(StandardDispatchOutput, DispatchOutput)
sglang/srt/layers/moe/topk.py
CHANGED
@@ -15,7 +15,8 @@
 from __future__ import annotations

 import math
-from
+from enum import Enum, auto
+from typing import Callable, NamedTuple, Optional, Protocol, runtime_checkable

 import torch
 import torch.nn.functional as F
@@ -27,6 +28,7 @@ from sglang.srt.eplb.expert_location_dispatch import (
     ExpertLocationDispatchInfo,
     topk_ids_logical_to_physical,
 )
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import (
     cpu_has_amx_support,
     get_bool_env_var,
@@ -37,6 +39,12 @@ from sglang.srt.utils import (
     is_npu,
 )

+try:
+    from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, routing
+except ImportError:
+    pass
+
+
 _is_cuda = is_cuda()
 _is_hip = is_hip()
 _is_cpu = is_cpu()
@@ -58,15 +66,58 @@ if _is_npu:
     import torch_npu


-
+# -------------------------------- TopKOutput ---------------------------------------
+
+
+class TopKOutputFormat(Enum):
+    STANDARD = auto()
+    TRITON_KERNEL = auto()
+
+    def is_standard(self) -> bool:
+        return self == TopKOutputFormat.STANDARD
+
+    def is_triton_kernel(self) -> bool:
+        return self == TopKOutputFormat.TRITON_KERNEL
+
+
+@runtime_checkable
+class TopKOutput(Protocol):
+    """Protocol for top-k outputs in different formats."""
+
+    @property
+    def format(self) -> TopKOutputFormat:
+        """The format of the output."""
+        ...
+
+
+class StandardTopKOutput(NamedTuple):
+    """Standard top-k output format."""
+
     topk_weights: torch.Tensor
     topk_ids: torch.Tensor
     router_logits: torch.Tensor

+    @property
+    def format(self) -> TopKOutputFormat:
+        return TopKOutputFormat.STANDARD

-class TopK(CustomOp):

-
+class TritonKernelTopKOutput(NamedTuple):
+    """Triton kernel top-k output format."""
+
+    routing_data: RoutingData
+    gather_indx: GatherIndx
+    scatter_indx: ScatterIndx
+
+    @property
+    def format(self) -> TopKOutputFormat:
+        return TopKOutputFormat.TRITON_KERNEL
+
+
+# -------------------------------- TopK ---------------------------------------
+
+
+class TopK(CustomOp):

     def __init__(
         self,
@@ -97,6 +148,8 @@ class TopK(CustomOp):
         self.correction_bias = correction_bias
         self.routed_scaling_factor = routed_scaling_factor

+        self.use_triton_kernels = global_server_args_dict["enable_triton_kernel_moe"]
+
     def forward_native(
         self,
         hidden_states: torch.Tensor,
@@ -131,23 +184,29 @@ class TopK(CustomOp):
         num_token_non_padded: Optional[torch.Tensor] = None,
         expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
     ) -> TopKOutput:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if self.use_triton_kernels:
+            routing_data, gather_idx, scatter_idx = routing(
+                router_logits, self.top_k, self.renormalize
+            )
+            return TritonKernelTopKOutput(routing_data, gather_idx, scatter_idx)
+        else:
+            torch_native = False
+            return select_experts(
+                hidden_states=hidden_states,
+                router_logits=router_logits,
+                top_k=self.top_k,
+                use_grouped_topk=self.use_grouped_topk,
+                renormalize=self.renormalize,
+                topk_group=self.topk_group,
+                num_expert_group=self.num_expert_group,
+                num_fused_shared_experts=self.num_fused_shared_experts,
+                custom_routing_function=self.custom_routing_function,
+                correction_bias=self.correction_bias,
+                torch_native=torch_native,
+                routed_scaling_factor=self.routed_scaling_factor,
+                num_token_non_padded=num_token_non_padded,
+                expert_location_dispatch_info=expert_location_dispatch_info,
+            )

     def forward_cpu(
         self,
@@ -217,6 +276,9 @@ class TopK(CustomOp):
         )


+# ------------------------------- TopK implementation -------------------------------------
+
+
 def fused_topk_torch_native(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
@@ -335,7 +397,9 @@ def grouped_topk_gpu(
         .reshape(num_token, -1)
     ) # [n, e]
     tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
-    topk_weights, topk_ids = torch.topk(
+    topk_weights, topk_ids = torch.topk(
+        tmp_scores, k=topk, dim=-1, sorted=num_fused_shared_experts > 0
+    )
     if num_fused_shared_experts:
         topk_ids[:, -1] = torch.randint(
             low=num_experts,
@@ -424,7 +488,9 @@ def biased_grouped_topk_impl(
     tmp_scores = scores_for_choice.masked_fill(
         ~score_mask.bool(), float("-inf")
     ) # [n, e]
-    _, topk_ids = torch.topk(
+    _, topk_ids = torch.topk(
+        tmp_scores, k=topk, dim=-1, sorted=num_fused_shared_experts > 0
+    )
     topk_weights = scores.gather(1, topk_ids)

     if num_fused_shared_experts:
@@ -680,4 +746,4 @@ def select_experts(

     get_global_expert_distribution_recorder().on_select_experts(topk_ids=topk_ids)

-    return
+    return StandardTopKOutput(topk_weights, topk_ids, router_logits)
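With TopKOutput now a protocol, MoE layers can branch on the routing format instead of receiving raw (gating_output, topk, renormalize) arguments, as the triton_kernels_moe.py hunk above already does. A hedged sketch of a consumer follows; TopKOutput is the protocol added in this diff, while the surrounding function is purely illustrative.

from sglang.srt.layers.moe.topk import TopKOutput


def unpack_topk_output(topk_output: TopKOutput):
    """Illustrative consumer that branches on the top-k output format."""
    if topk_output.format.is_triton_kernel():
        # TritonKernelTopKOutput: routing metadata from triton_kernels.routing
        routing_data, gather_indx, scatter_indx = topk_output
        return routing_data, gather_indx, scatter_indx
    # StandardTopKOutput: plain tensors produced by select_experts
    topk_weights, topk_ids, router_logits = topk_output
    return topk_weights, topk_ids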
sglang/srt/layers/multimodal.py
CHANGED
@@ -55,14 +55,17 @@ def gpu_tensor_hash(tensor: torch.Tensor) -> int:

     intermediate_hashes = torch.empty(n, dtype=torch.int64, device=tensor.device)

-
-
-
-
-
-
-
-
+    # Set cuda device to prevent ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
+    # Solution from Tri: https://github.com/Dao-AILab/flash-attention/issues/523#issuecomment-1707611579
+    with torch.cuda.device(tensor.device):
+        hash_kernel[grid](
+            tensor,
+            intermediate_hashes,
+            n,
+            BLOCK_SIZE=BLOCK_SIZE,
+            PRIME=PRIME_1,
+            XCONST=PRIME_2,
+        )

     # TODO: threads can't be synced on triton kernel
     final_hash = intermediate_hashes.sum().item()
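The multimodal.py fix wraps the Triton hash-kernel launch in torch.cuda.device(...) so pointer arguments resolve on the tensor's own GPU even when a different device is current. The same guard can be applied to any Triton launch; the helper below is a generic, hypothetical illustration rather than sglang code.

import torch


def launch_on_tensor_device(kernel, grid, tensor: torch.Tensor, **kwargs):
    """Make the tensor's CUDA device current before launching a Triton kernel,
    mirroring the guard added above. `kernel` is any @triton.jit kernel."""
    with torch.cuda.device(tensor.device):
        kernel[grid](tensor, **kwargs)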
|