sglang 0.4.1.post7__py3-none-any.whl → 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +17 -11
- sglang/bench_one_batch.py +14 -6
- sglang/bench_serving.py +47 -44
- sglang/lang/chat_template.py +31 -0
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +5 -2
- sglang/srt/entrypoints/engine.py +5 -2
- sglang/srt/entrypoints/http_server.py +24 -0
- sglang/srt/function_call_parser.py +494 -0
- sglang/srt/layers/activation.py +5 -5
- sglang/srt/layers/dp_attention.py +3 -1
- sglang/srt/layers/layernorm.py +5 -5
- sglang/srt/layers/linear.py +24 -9
- sglang/srt/layers/logits_processor.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +20 -12
- sglang/srt/layers/moe/fused_moe_native.py +17 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +18 -1
- sglang/srt/layers/moe/fused_moe_triton/layer.py +9 -0
- sglang/srt/layers/parameter.py +16 -7
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/fp8.py +4 -1
- sglang/srt/layers/rotary_embedding.py +6 -1
- sglang/srt/layers/sampler.py +28 -8
- sglang/srt/layers/torchao_utils.py +12 -6
- sglang/srt/managers/detokenizer_manager.py +1 -0
- sglang/srt/managers/io_struct.py +36 -5
- sglang/srt/managers/schedule_batch.py +31 -25
- sglang/srt/managers/scheduler.py +61 -35
- sglang/srt/managers/tokenizer_manager.py +4 -0
- sglang/srt/model_executor/cuda_graph_runner.py +23 -25
- sglang/srt/model_executor/forward_batch_info.py +5 -7
- sglang/srt/model_executor/model_runner.py +7 -4
- sglang/srt/model_loader/loader.py +75 -0
- sglang/srt/model_loader/weight_utils.py +91 -5
- sglang/srt/models/commandr.py +14 -2
- sglang/srt/models/dbrx.py +9 -1
- sglang/srt/models/deepseek_v2.py +3 -3
- sglang/srt/models/gemma2.py +9 -1
- sglang/srt/models/grok.py +1 -0
- sglang/srt/models/minicpm3.py +3 -3
- sglang/srt/models/torch_native_llama.py +17 -4
- sglang/srt/openai_api/adapter.py +139 -37
- sglang/srt/openai_api/protocol.py +5 -4
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +11 -14
- sglang/srt/sampling/sampling_batch_info.py +4 -14
- sglang/srt/server.py +2 -2
- sglang/srt/server_args.py +20 -1
- sglang/srt/speculative/eagle_utils.py +37 -15
- sglang/srt/speculative/eagle_worker.py +11 -13
- sglang/srt/utils.py +62 -65
- sglang/test/test_programs.py +1 -0
- sglang/test/test_utils.py +81 -22
- sglang/version.py +1 -1
- {sglang-0.4.1.post7.dist-info → sglang-0.4.2.dist-info}/METADATA +7 -7
- {sglang-0.4.1.post7.dist-info → sglang-0.4.2.dist-info}/RECORD +67 -56
- {sglang-0.4.1.post7.dist-info → sglang-0.4.2.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post7.dist-info → sglang-0.4.2.dist-info}/WHEEL +0 -0
- {sglang-0.4.1.post7.dist-info → sglang-0.4.2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/ep_moe/layer.py
CHANGED
@@ -114,6 +114,8 @@ class EPMoE(torch.nn.Module):
         tp_size: Optional[int] = None,
         prefix: str = "",
         correction_bias: Optional[torch.Tensor] = None,
+        custom_routing_function: Optional[Callable] = None,
+        activation: str = "silu",
     ):
         super().__init__()
 
@@ -140,6 +142,8 @@ class EPMoE(torch.nn.Module):
         self.num_expert_group = num_expert_group
         self.topk_group = topk_group
         self.correction_bias = correction_bias
+        self.custom_routing_function = custom_routing_function
+        self.activation = activation
 
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod()
@@ -166,6 +170,7 @@ class EPMoE(torch.nn.Module):
 
     def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
         assert self.quant_method is not None
+        assert self.activation == "silu"
 
         if self.grouped_gemm_runner is None:
             self.grouped_gemm_runner = GroupedGemmRunner(
@@ -181,6 +186,7 @@ class EPMoE(torch.nn.Module):
             topk_group=self.topk_group,
             num_expert_group=self.num_expert_group,
             correction_bias=self.correction_bias,
+            custom_routing_function=self.custom_routing_function,
         )
 
         reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
@@ -254,16 +260,20 @@ class EPMoE(torch.nn.Module):
             dtype=torch.float32,
             device=hidden_states.device,
         )
-        silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
-            gateup_output,
-            down_input,
-            gateup_output.shape[1],
-            reorder_topk_ids,
-            self.w2_input_scale,
-            self.start_expert_id,
-            self.end_expert_id,
-            BLOCK_SIZE=512,
-        )
+
+        if self.activation == "silu":
+            silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
+                gateup_output,
+                down_input,
+                gateup_output.shape[1],
+                reorder_topk_ids,
+                self.w2_input_scale,
+                self.start_expert_id,
+                self.end_expert_id,
+                BLOCK_SIZE=512,
+            )
+        else:
+            raise ValueError(f"Unsupported activation: {self.activation=}")
 
         # GroupGemm-1
         down_output = torch.empty(
@@ -309,7 +319,6 @@ class EPMoE(torch.nn.Module):
         ckpt_up_proj_name: str,
         num_experts: int,
     ) -> List[Tuple[str, str, int, str]]:
-
         return [
             # (param_name, weight_name, expert_id, shard_id)
             (
@@ -354,7 +363,6 @@ class EPMoE(torch.nn.Module):
             )
             return
 
-        expert_data = param.data[expert_id]
         if shard_id == "w2":
             param.data[expert_id] = loaded_weight
         elif shard_id == "w1":
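The EPMoE change above threads two new constructor arguments (`custom_routing_function`, `activation`) through the expert-parallel path, but the Triton kernel still only implements SiLU gating, hence the `assert self.activation == "silu"` in `forward`. For reference, the math that `silu_and_mul_triton_kernel` applies to each expert's gate/up projection is the usual SwiGLU gating; the sketch below is a plain-PyTorch illustration of that math, not sglang code, and the helper names are made up for the example.

import torch
import torch.nn.functional as F

def silu_and_mul_reference(gateup: torch.Tensor) -> torch.Tensor:
    # gateup holds the gate and up projections concatenated on the last dim.
    gate, up = gateup.chunk(2, dim=-1)
    return F.silu(gate) * up

def ep_moe_activation(gateup_output: torch.Tensor, activation: str = "silu") -> torch.Tensor:
    # Mirrors the new guard: only "silu" is accepted on the expert-parallel path for now.
    if activation == "silu":
        return silu_and_mul_reference(gateup_output)
    raise ValueError(f"Unsupported activation: {activation=}")

tokens = torch.randn(4, 256)              # per-token [gate | up] projections
print(ep_moe_activation(tokens).shape)    # torch.Size([4, 128])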
sglang/srt/layers/moe/fused_moe_native.py
CHANGED
@@ -8,7 +8,7 @@ from typing import Callable, Optional
 import torch
 from torch.nn import functional as F
 
-from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
 from sglang.srt.layers.moe.topk import select_experts
 
 
@@ -23,6 +23,7 @@ def fused_moe_forward_native(
     num_expert_group: Optional[int] = None,
     custom_routing_function: Optional[Callable] = None,
     correction_bias: Optional[torch.Tensor] = None,
+    activation: str = "silu",
 ) -> torch.Tensor:
     topk_weights, topk_ids = select_experts(
         hidden_states=x,
@@ -41,7 +42,12 @@ def fused_moe_forward_native(
     w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
     w2_weights = layer.w2_weight[topk_ids]
     x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
-    x1 = F.silu(x1)
+    if activation == "silu":
+        x1 = F.silu(x1)
+    elif activation == "gelu":
+        x1 = F.gelu(x1)
+    else:
+        raise ValueError(f"Unsupported activation: {activation=}")
     x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
     expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
     return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
@@ -58,6 +64,7 @@ def moe_forward_native(
     num_expert_group: Optional[int] = None,
     custom_routing_function: Optional[Callable] = None,
     correction_bias: Optional[torch.Tensor] = None,
+    activation: str = "silu",
 ) -> torch.Tensor:
 
     topk_weights, topk_ids = select_experts(
@@ -84,6 +91,13 @@ def moe_forward_native(
     sorted_tokens = x[idxs // topk_ids.shape[1]]
     tokens_per_expert = tokens_per_expert.cpu().numpy()
 
+    if activation == "silu":
+        act = SiluAndMul()
+    elif activation == "gelu":
+        act = GeluAndMul()
+    else:
+        raise ValueError(f"Unsupported activation: {activation=}")
+
     outputs = []
     start_idx = 0
     for i, num_tokens in enumerate(tokens_per_expert):
@@ -96,7 +110,7 @@ def moe_forward_native(
         layer_w2_weight = layer.w2_weight[i]
 
         gate_up = F.linear(tokens_for_this_expert, layer_w13_weight)
-        gate_up = SiluAndMul()(gate_up)
+        gate_up = act(gate_up)
         expert_out = F.linear(gate_up, layer_w2_weight)
         outputs.append(expert_out)
         start_idx = end_idx
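For reference, the dense einsum formulation used by `fused_moe_forward_native` (now extended to GELU gating) combines each token's top-k experts as below. The snippet is a self-contained toy with random tensors standing in for the gathered `layer.w13_weight` / `layer.w2_weight` slices; shapes and names are illustrative, not sglang's.

import torch
from torch.nn import functional as F

# t tokens, a = top-k experts per token, o = intermediate dim, i = hidden dim.
t, a, o, i = 4, 2, 16, 8
x = torch.randn(t, i)
w1 = torch.randn(t, a, o, i)    # gate-projection weights gathered per selected expert
w3 = torch.randn(t, a, o, i)    # up-projection weights gathered per selected expert
w2 = torch.randn(t, a, i, o)    # down-projection weights gathered per selected expert
topk_weights = torch.softmax(torch.randn(t, a), dim=-1)

x1 = F.silu(torch.einsum("ti,taoi -> tao", x, w1))   # gate path (F.gelu after this change)
x3 = torch.einsum("ti,taoi -> tao", x, w3)           # up path
expert_outs = torch.einsum("tao,taio -> tai", x1 * x3, w2)
out = torch.einsum("tai,ta -> ti", expert_outs, topk_weights)
print(out.shape)   # torch.Size([4, 8])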
sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json
ADDED
@@ -0,0 +1,164 @@
+{
+    "1": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 2, "waves_per_eu": 0},
+    "2": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 2, "waves_per_eu": 0},
+    "4": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 2, "waves_per_eu": 0},
+    "8": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 2, "waves_per_eu": 0},
+    "16": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 2, "waves_per_eu": 0},
+    "24": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 2, "waves_per_eu": 0},
+    "32": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8, "num_warps": 4, "num_stages": 2, "waves_per_eu": 0},
+    "48": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 2, "waves_per_eu": 0},
+    "64": {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 2, "waves_per_eu": 0},
+    "96": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8, "num_warps": 4, "num_stages": 2, "waves_per_eu": 0},
+    "128": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 2, "waves_per_eu": 0},
+    "256": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 2, "waves_per_eu": 0},
+    "512": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 2, "waves_per_eu": 0},
+    "1024": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 8, "num_stages": 2, "waves_per_eu": 0},
+    "1536": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8, "num_warps": 4, "num_stages": 2, "waves_per_eu": 0},
+    "2048": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 2, "waves_per_eu": 0},
+    "3072": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 2, "waves_per_eu": 0},
+    "4096": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 2, "waves_per_eu": 0}
+}
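The new JSON file ships pre-tuned Triton launch parameters for the fused MoE kernel on AMD Instinct MI300X (fp8_w8a8, block shape [128, 128]), keyed by batch size. One plausible way to consume such a file is to pick the entry whose key is closest to the current token count; the snippet below is an illustration under that assumption, not sglang's actual lookup code.

import json

def pick_tuned_config(path: str, num_tokens: int) -> dict:
    # Illustrative nearest-key lookup over a batch-size-keyed tuning file.
    with open(path) as f:
        configs = {int(k): v for k, v in json.load(f).items()}
    return configs[min(configs, key=lambda k: abs(k - num_tokens))]

# Hypothetical usage against the config added above (path shortened):
# cfg = pick_tuned_config("configs/E=256,N=256,device_name=AMD_Instinct_MI300X,"
#                         "dtype=fp8_w8a8,block_shape=[128, 128].json", num_tokens=100)
# -> the "96" entry: BLOCK_SIZE_M=32, BLOCK_SIZE_N=128, BLOCK_SIZE_K=128, GROUP_SIZE_M=8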
sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
CHANGED
@@ -711,6 +711,7 @@ def inplace_fused_experts(
     w2: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    activation: str = "silu",
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
@@ -726,6 +727,7 @@ def inplace_fused_experts(
         topk_weights,
         topk_ids,
         True,
+        activation,
         use_fp8_w8a8,
         use_int8_w8a16,
         w1_scale,
@@ -742,6 +744,7 @@ def inplace_fused_experts_fake(
     w2: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    activation: str = "silu",
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
@@ -767,6 +770,7 @@ def outplace_fused_experts(
     w2: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    activation: str = "silu",
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
@@ -782,6 +786,7 @@ def outplace_fused_experts(
         topk_weights,
         topk_ids,
         False,
+        activation,
         use_fp8_w8a8,
         use_int8_w8a16,
         w1_scale,
@@ -798,6 +803,7 @@ def outplace_fused_experts_fake(
     w2: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    activation: str = "silu",
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
@@ -824,6 +830,7 @@ def fused_experts(
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
     inplace: bool = False,
+    activation: str = "silu",
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
@@ -839,6 +846,7 @@ def fused_experts(
         w2,
         topk_weights,
         topk_ids,
+        activation,
         use_fp8_w8a8,
         use_int8_w8a16,
         w1_scale,
@@ -855,6 +863,7 @@ def fused_experts(
         w2,
         topk_weights,
         topk_ids,
+        activation,
         use_fp8_w8a8,
         use_int8_w8a16,
         w1_scale,
@@ -872,6 +881,7 @@ def fused_experts_impl(
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
     inplace: bool = False,
+    activation: str = "silu",
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
@@ -986,7 +996,12 @@ def fused_experts_impl(
             block_shape=block_shape,
         )
 
-        ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+        if activation == "silu":
+            ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+        elif activation == "gelu":
+            ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+        else:
+            raise ValueError(f"Unsupported activation: {activation=}")
 
         invoke_fused_moe_kernel(
             intermediate_cache2,
@@ -1042,6 +1057,7 @@ def fused_moe(
     topk: int,
     renormalize: bool,
     inplace: bool = False,
+    activation: str = "silu",
     use_grouped_topk: bool = False,
     num_expert_group: Optional[int] = None,
     topk_group: Optional[int] = None,
@@ -1111,6 +1127,7 @@ def fused_moe(
         topk_weights,
         topk_ids,
         inplace=inplace,
+        activation=activation,
         use_fp8_w8a8=use_fp8_w8a8,
         use_int8_w8a16=use_int8_w8a16,
         w1_scale=w1_scale,
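With `activation` threaded through `fused_moe` → `fused_experts` → `fused_experts_impl`, callers can now request GELU gating (`ops.gelu_and_mul`) instead of the default SiLU gating (`ops.silu_and_mul`), which matters for models whose expert MLPs were trained with GELU-gated projections. The two nonlinearities are close but not interchangeable, as a quick check shows (plain PyTorch, independent of sglang internals):

import torch
import torch.nn.functional as F

x = torch.linspace(-4.0, 4.0, steps=9)
print(F.silu(x))   # x * sigmoid(x)
print(F.gelu(x))   # x * Phi(x)
# The largest gap on this grid (~0.19) occurs near |x| = 2.
print((F.silu(x) - F.gelu(x)).abs().max())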
sglang/srt/layers/moe/fused_moe_triton/layer.py
CHANGED
@@ -126,6 +126,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
     ) -> torch.Tensor:
         return self.forward(
             x=x,
@@ -138,6 +139,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             correction_bias=correction_bias,
+            activation=activation,
         )
 
     def forward_cuda(
@@ -152,6 +154,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
     ) -> torch.Tensor:
         topk_weights, topk_ids = select_experts(
             hidden_states=x,
@@ -169,6 +172,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             import ater
             from ater.fused_moe import fused_experts_ck
 
+            assert activation == "silu", f"{activation=} is not supported."
+
             return fused_experts_ck(
                 hidden_states=x,
                 w1=layer.w13_weight,
@@ -184,6 +189,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                 topk_weights=topk_weights,
                 topk_ids=topk_ids,
                 inplace=True,
+                activation=activation,
             )
 
     def forward_cpu(
@@ -256,6 +262,7 @@ class FusedMoE(torch.nn.Module):
         prefix: str = "",
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
         use_presharded_weights: bool = False,
     ):
         super().__init__()
@@ -279,6 +286,7 @@ class FusedMoE(torch.nn.Module):
         self.topk_group = topk_group
         self.custom_routing_function = custom_routing_function
         self.correction_bias = correction_bias
+        self.activation = activation
 
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = (
@@ -589,6 +597,7 @@ class FusedMoE(torch.nn.Module):
             num_expert_group=self.num_expert_group,
             custom_routing_function=self.custom_routing_function,
             correction_bias=self.correction_bias,
+            activation=self.activation,
         )
 
         if self.reduce_results and self.tp_size > 1:
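The pattern here is that `FusedMoE` stores the activation name once at construction and forwards it to the quant method on every call, so backends that cannot honour it (such as the ROCm `ater` path above, which only supports SiLU) can reject it up front. A minimal toy of that layering, with made-up class names standing in for sglang's:

import torch
import torch.nn.functional as F

class ToyMoEMethod:
    # Stand-in for UnquantizedFusedMoEMethod.apply: the backend receives the
    # activation name and dispatches or rejects it.
    def apply(self, x: torch.Tensor, activation: str = "silu") -> torch.Tensor:
        if activation == "silu":
            return F.silu(x)
        if activation == "gelu":
            return F.gelu(x)
        raise ValueError(f"Unsupported activation: {activation=}")

class ToyMoE(torch.nn.Module):
    def __init__(self, activation: str = "silu"):
        super().__init__()
        self.activation = activation        # mirrors FusedMoE.__init__
        self.quant_method = ToyMoEMethod()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.quant_method.apply(x, activation=self.activation)

print(ToyMoE(activation="gelu")(torch.randn(2, 4)).shape)   # torch.Size([2, 4])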
sglang/srt/layers/parameter.py
CHANGED
@@ -124,7 +124,13 @@ class _ColumnvLLMParameter(BasevLLMParameter):
         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)
 
-    def load_qkv_weight(
+    def load_qkv_weight(
+        self,
+        loaded_weight: torch.Tensor,
+        tp_rank: int,
+        use_presharded_weights: bool = False,
+        **kwargs,
+    ):
 
         shard_offset = kwargs.get("shard_offset")
         shard_size = kwargs.get("shard_size")
@@ -142,11 +148,14 @@ class _ColumnvLLMParameter(BasevLLMParameter):
         param_data = self.data
         shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads
         param_data = param_data.narrow(self.output_dim, shard_offset, shard_size)
-        loaded_weight = loaded_weight.narrow(
-            self.output_dim, shard_id * shard_size, shard_size
-        )
+        if not use_presharded_weights:
+            loaded_weight = loaded_weight.narrow(
+                self.output_dim, shard_id * shard_size, shard_size
+            )
 
-        assert param_data.shape == loaded_weight.shape
+        assert (
+            param_data.shape == loaded_weight.shape
+        ), f"{param_data.shape=}, {loaded_weight.shape=}"
         param_data.copy_(loaded_weight)
 
 
@@ -292,7 +301,7 @@ class PackedColumnParameter(_ColumnvLLMParameter):
         packed_factor: Union[int, Fraction],
         packed_dim: int,
         marlin_tile_size: Optional[int] = None,
-        **kwargs
+        **kwargs,
     ):
         self._packed_factor = packed_factor
         self._packed_dim = packed_dim
@@ -336,7 +345,7 @@ class PackedvLLMParameter(ModelWeightParameter):
         packed_factor: Union[int, Fraction],
         packed_dim: int,
         marlin_tile_size: Optional[int] = None,
-        **kwargs
+        **kwargs,
     ):
         self._packed_factor = packed_factor
         self._packed_dim = packed_dim
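The `load_qkv_weight` change adds a `use_presharded_weights` switch: when a checkpoint is already sharded per tensor-parallel rank, the loaded tensor is this rank's slice and must not be narrowed again. A minimal sketch of the two paths, assuming a simplified standalone helper rather than the real `_ColumnvLLMParameter` method:

import torch

def load_qkv_shard(
    param_shard: torch.Tensor,
    loaded_weight: torch.Tensor,
    shard_id: int,
    shard_size: int,
    output_dim: int = 0,
    use_presharded_weights: bool = False,
) -> None:
    # With presharded checkpoints the file already contains only this rank's slice,
    # so the narrow() of the loaded weight is skipped.
    if not use_presharded_weights:
        loaded_weight = loaded_weight.narrow(output_dim, shard_id * shard_size, shard_size)
    assert param_shard.shape == loaded_weight.shape, f"{param_shard.shape=}, {loaded_weight.shape=}"
    param_shard.copy_(loaded_weight)

full = torch.arange(8.0).reshape(8, 1)   # full (unsharded) checkpoint weight
dest = torch.empty(4, 1)                 # this rank's slice of the parameter
load_qkv_shard(dest, full, shard_id=1, shard_size=4)                                  # narrows rows 4..7
load_qkv_shard(dest, full[4:8], shard_id=1, shard_size=4, use_presharded_weights=True)  # already sharded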
sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json
ADDED
@@ -0,0 +1,164 @@
+{
+    "1": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 2, "waves_per_eu": 0},
+    "2": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8, "num_warps": 4, "num_stages": 2, "waves_per_eu": 0},
+    "4": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 2, "waves_per_eu": 0},
+    "8": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 2, "waves_per_eu": 0},
+    "16": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 2, "waves_per_eu": 0},
+    "24": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 2, "waves_per_eu": 0},
+    "32": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 2, "waves_per_eu": 0},
+    "48": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 2, "waves_per_eu": 0},
+    "64": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 2, "waves_per_eu": 0},
+    "96": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 2, "waves_per_eu": 0},
+    "128": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 2, "waves_per_eu": 0},
+    "256": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 2, "waves_per_eu": 0},
+    "512": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 2, "waves_per_eu": 0},
+    "1024": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 2, "waves_per_eu": 0},
+    "1536": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8, "num_warps": 4, "num_stages": 2, "waves_per_eu": 0},
+    "2048": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8, "num_warps": 4, "num_stages": 2, "waves_per_eu": 0},
+    "3072": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 2, "waves_per_eu": 0},
+    "4096": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 2, "waves_per_eu": 0}
+}
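This second new config covers a block-quantized fp8 GEMM shape (N=1536, K=7168) on MI300X; the release adds nine such files under sglang/srt/layers/quantization/configs/ (see the file list above), with the GEMM shape, device, dtype, and quantization block shape encoded in the filename. A throwaway sketch of that naming pattern, assuming a hypothetical helper that is not part of sglang:

def block_fp8_config_filename(N: int, K: int, device_name: str) -> str:
    # Illustrative only: reproduces the naming pattern of the files listed above.
    return (
        f"N={N},K={K},device_name={device_name},"
        f"dtype=fp8_w8a8,block_shape=[128, 128].json"
    )

print(block_fp8_config_filename(1536, 7168, "AMD_Instinct_MI300X"))
# -> N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json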