sglang 0.4.1.post6__py3-none-any.whl → 0.4.2__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- sglang/__init__.py +21 -23
- sglang/api.py +2 -7
- sglang/bench_offline_throughput.py +41 -27
- sglang/bench_one_batch.py +60 -4
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +83 -71
- sglang/lang/backend/runtime_endpoint.py +183 -4
- sglang/lang/chat_template.py +46 -4
- sglang/launch_server.py +1 -1
- sglang/srt/_custom_ops.py +80 -42
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/constrained/base_grammar_backend.py +21 -0
- sglang/srt/constrained/xgrammar_backend.py +8 -4
- sglang/srt/conversation.py +14 -1
- sglang/srt/distributed/__init__.py +3 -3
- sglang/srt/distributed/communication_op.py +2 -1
- sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +112 -42
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
- sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
- sglang/srt/distributed/device_communicators/pynccl.py +80 -1
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
- sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
- sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
- sglang/srt/distributed/parallel_state.py +1 -1
- sglang/srt/distributed/utils.py +2 -1
- sglang/srt/entrypoints/engine.py +452 -0
- sglang/srt/entrypoints/http_server.py +603 -0
- sglang/srt/function_call_parser.py +494 -0
- sglang/srt/layers/activation.py +8 -8
- sglang/srt/layers/attention/flashinfer_backend.py +10 -9
- sglang/srt/layers/attention/triton_backend.py +4 -6
- sglang/srt/layers/attention/vision.py +204 -0
- sglang/srt/layers/dp_attention.py +71 -0
- sglang/srt/layers/layernorm.py +5 -5
- sglang/srt/layers/linear.py +65 -14
- sglang/srt/layers/logits_processor.py +49 -64
- sglang/srt/layers/moe/ep_moe/layer.py +24 -16
- sglang/srt/layers/moe/fused_moe_native.py +84 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +27 -7
- sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -5
- sglang/srt/layers/parameter.py +18 -8
- sglang/srt/layers/quantization/__init__.py +20 -23
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/fp8.py +10 -4
- sglang/srt/layers/quantization/modelopt_quant.py +1 -2
- sglang/srt/layers/quantization/w8a8_int8.py +1 -1
- sglang/srt/layers/radix_attention.py +2 -2
- sglang/srt/layers/rotary_embedding.py +1184 -31
- sglang/srt/layers/sampler.py +64 -6
- sglang/srt/layers/torchao_utils.py +12 -6
- sglang/srt/layers/vocab_parallel_embedding.py +2 -2
- sglang/srt/lora/lora.py +1 -9
- sglang/srt/managers/configure_logging.py +3 -0
- sglang/srt/managers/data_parallel_controller.py +79 -72
- sglang/srt/managers/detokenizer_manager.py +24 -6
- sglang/srt/managers/image_processor.py +158 -2
- sglang/srt/managers/io_struct.py +57 -3
- sglang/srt/managers/schedule_batch.py +78 -45
- sglang/srt/managers/schedule_policy.py +26 -12
- sglang/srt/managers/scheduler.py +326 -201
- sglang/srt/managers/session_controller.py +1 -0
- sglang/srt/managers/tokenizer_manager.py +210 -121
- sglang/srt/managers/tp_worker.py +6 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
- sglang/srt/managers/utils.py +44 -0
- sglang/srt/mem_cache/memory_pool.py +10 -32
- sglang/srt/metrics/collector.py +15 -6
- sglang/srt/model_executor/cuda_graph_runner.py +26 -30
- sglang/srt/model_executor/forward_batch_info.py +5 -7
- sglang/srt/model_executor/model_runner.py +44 -19
- sglang/srt/model_loader/loader.py +83 -6
- sglang/srt/model_loader/weight_utils.py +145 -6
- sglang/srt/models/baichuan.py +6 -6
- sglang/srt/models/chatglm.py +2 -2
- sglang/srt/models/commandr.py +17 -5
- sglang/srt/models/dbrx.py +13 -5
- sglang/srt/models/deepseek.py +3 -3
- sglang/srt/models/deepseek_v2.py +11 -11
- sglang/srt/models/exaone.py +2 -2
- sglang/srt/models/gemma.py +2 -2
- sglang/srt/models/gemma2.py +15 -25
- sglang/srt/models/gpt2.py +3 -5
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/granite.py +2 -2
- sglang/srt/models/grok.py +4 -3
- sglang/srt/models/internlm2.py +2 -2
- sglang/srt/models/llama.py +7 -5
- sglang/srt/models/minicpm.py +2 -2
- sglang/srt/models/minicpm3.py +9 -9
- sglang/srt/models/minicpmv.py +1238 -0
- sglang/srt/models/mixtral.py +3 -3
- sglang/srt/models/mixtral_quant.py +3 -3
- sglang/srt/models/mllama.py +2 -2
- sglang/srt/models/olmo.py +3 -3
- sglang/srt/models/olmo2.py +4 -4
- sglang/srt/models/olmoe.py +7 -13
- sglang/srt/models/phi3_small.py +2 -2
- sglang/srt/models/qwen.py +2 -2
- sglang/srt/models/qwen2.py +41 -4
- sglang/srt/models/qwen2_moe.py +3 -3
- sglang/srt/models/qwen2_vl.py +22 -122
- sglang/srt/models/stablelm.py +2 -2
- sglang/srt/models/torch_native_llama.py +20 -7
- sglang/srt/models/xverse.py +6 -6
- sglang/srt/models/xverse_moe.py +6 -6
- sglang/srt/openai_api/adapter.py +139 -37
- sglang/srt/openai_api/protocol.py +7 -4
- sglang/srt/sampling/custom_logit_processor.py +38 -0
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +11 -14
- sglang/srt/sampling/sampling_batch_info.py +143 -18
- sglang/srt/sampling/sampling_params.py +3 -1
- sglang/srt/server.py +4 -1090
- sglang/srt/server_args.py +77 -15
- sglang/srt/speculative/eagle_utils.py +37 -15
- sglang/srt/speculative/eagle_worker.py +11 -13
- sglang/srt/utils.py +164 -129
- sglang/test/runners.py +8 -13
- sglang/test/test_programs.py +2 -1
- sglang/test/test_utils.py +83 -22
- sglang/utils.py +12 -2
- sglang/version.py +1 -1
- {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/METADATA +21 -10
- {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/RECORD +138 -123
- sglang/launch_server_llavavid.py +0 -25
- sglang/srt/constrained/__init__.py +0 -16
- sglang/srt/distributed/device_communicators/__init__.py +0 -0
- {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/WHEEL +0 -0
- {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/ep_moe/layer.py
CHANGED
@@ -4,13 +4,12 @@ from typing import Callable, List, Optional, Tuple
 import torch
 from torch.nn import Module
 from vllm import _custom_ops as ops
-from vllm.distributed import (
+from vllm.model_executor.custom_op import CustomOp
+
+from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
-from vllm.model_executor.custom_op import CustomOp
-from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
-
 from sglang.srt.layers.custom_op_util import register_custom_op
 from sglang.srt.layers.moe.ep_moe.kernels import (
     grouped_gemm_triton,
@@ -25,6 +24,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
 from sglang.srt.utils import is_hip, set_weight_attrs
 
 logger = logging.getLogger(__name__)
@@ -114,6 +114,8 @@ class EPMoE(torch.nn.Module):
         tp_size: Optional[int] = None,
         prefix: str = "",
         correction_bias: Optional[torch.Tensor] = None,
+        custom_routing_function: Optional[Callable] = None,
+        activation: str = "silu",
     ):
         super().__init__()
 
@@ -140,6 +142,8 @@ class EPMoE(torch.nn.Module):
         self.num_expert_group = num_expert_group
         self.topk_group = topk_group
         self.correction_bias = correction_bias
+        self.custom_routing_function = custom_routing_function
+        self.activation = activation
 
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod()
@@ -166,6 +170,7 @@ class EPMoE(torch.nn.Module):
 
     def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
         assert self.quant_method is not None
+        assert self.activation == "silu"
 
         if self.grouped_gemm_runner is None:
             self.grouped_gemm_runner = GroupedGemmRunner(
@@ -181,6 +186,7 @@ class EPMoE(torch.nn.Module):
             topk_group=self.topk_group,
             num_expert_group=self.num_expert_group,
             correction_bias=self.correction_bias,
+            custom_routing_function=self.custom_routing_function,
         )
 
         reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
@@ -254,16 +260,20 @@ class EPMoE(torch.nn.Module):
             dtype=torch.float32,
             device=hidden_states.device,
         )
-        silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
-            gateup_output,
-            down_input,
-            gateup_output.shape[1],
-            reorder_topk_ids,
-            self.w2_input_scale,
-            self.start_expert_id,
-            self.end_expert_id,
-            BLOCK_SIZE=512,
-        )
+
+        if self.activation == "silu":
+            silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
+                gateup_output,
+                down_input,
+                gateup_output.shape[1],
+                reorder_topk_ids,
+                self.w2_input_scale,
+                self.start_expert_id,
+                self.end_expert_id,
+                BLOCK_SIZE=512,
+            )
+        else:
+            raise ValueError(f"Unsupported activation: {self.activation=}")
 
         # GroupGemm-1
         down_output = torch.empty(
@@ -309,7 +319,6 @@ class EPMoE(torch.nn.Module):
         ckpt_up_proj_name: str,
         num_experts: int,
     ) -> List[Tuple[str, str, int, str]]:
-
         return [
             # (param_name, weight_name, expert_id, shard_id)
             (
@@ -354,7 +363,6 @@ class EPMoE(torch.nn.Module):
             )
             return
 
-        expert_data = param.data[expert_id]
         if shard_id == "w2":
            param.data[expert_id] = loaded_weight
        elif shard_id == "w1":
sglang/srt/layers/moe/fused_moe_native.py
CHANGED
@@ -8,6 +8,7 @@ from typing import Callable, Optional
 import torch
 from torch.nn import functional as F
 
+from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
 from sglang.srt.layers.moe.topk import select_experts
 
 
@@ -22,6 +23,7 @@ def fused_moe_forward_native(
     num_expert_group: Optional[int] = None,
     custom_routing_function: Optional[Callable] = None,
     correction_bias: Optional[torch.Tensor] = None,
+    activation: str = "silu",
 ) -> torch.Tensor:
     topk_weights, topk_ids = select_experts(
         hidden_states=x,
@@ -40,7 +42,88 @@ def fused_moe_forward_native(
     w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
     w2_weights = layer.w2_weight[topk_ids]
     x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
-    x1 = F.silu(x1)
+    if activation == "silu":
+        x1 = F.silu(x1)
+    elif activation == "gelu":
+        x1 = F.gelu(x1)
+    else:
+        raise ValueError(f"Unsupported activation: {activation=}")
     x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
     expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
     return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
+
+
+def moe_forward_native(
+    layer: torch.nn.Module,
+    x: torch.Tensor,
+    use_grouped_topk: bool,
+    top_k: int,
+    router_logits: torch.Tensor,
+    renormalize: bool,
+    topk_group: Optional[int] = None,
+    num_expert_group: Optional[int] = None,
+    custom_routing_function: Optional[Callable] = None,
+    correction_bias: Optional[torch.Tensor] = None,
+    activation: str = "silu",
+) -> torch.Tensor:
+
+    topk_weights, topk_ids = select_experts(
+        hidden_states=x,
+        router_logits=router_logits,
+        use_grouped_topk=use_grouped_topk,
+        top_k=top_k,
+        renormalize=renormalize,
+        topk_group=topk_group,
+        num_expert_group=num_expert_group,
+        custom_routing_function=custom_routing_function,
+        correction_bias=correction_bias,
+        torch_native=True,
+    )
+
+    # Ref code from https://huggingface.co/deepseek-ai/DeepSeek-V2/blob/e0828e3cc0a03408724b80c3cc92c8e072db8d01/modeling_deepseek.py#L589
+    len_experts = layer.num_experts
+
+    cnts = topk_ids.new_zeros((topk_ids.shape[0], len_experts))
+    cnts.scatter_(1, topk_ids.to(torch.int64), 1)
+    tokens_per_expert = cnts.sum(dim=0)
+    idxs = topk_ids.view(-1).argsort()
+
+    sorted_tokens = x[idxs // topk_ids.shape[1]]
+    tokens_per_expert = tokens_per_expert.cpu().numpy()
+
+    if activation == "silu":
+        act = SiluAndMul()
+    elif activation == "gelu":
+        act = GeluAndMul()
+    else:
+        raise ValueError(f"Unsupported activation: {activation=}")
+
+    outputs = []
+    start_idx = 0
+    for i, num_tokens in enumerate(tokens_per_expert):
+        end_idx = start_idx + num_tokens
+        if num_tokens == 0:
+            continue
+        tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
+
+        layer_w13_weight = layer.w13_weight[i]
+        layer_w2_weight = layer.w2_weight[i]
+
+        gate_up = F.linear(tokens_for_this_expert, layer_w13_weight)
+        gate_up = act(gate_up)
+        expert_out = F.linear(gate_up, layer_w2_weight)
+        outputs.append(expert_out)
+        start_idx = end_idx
+
+    outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
+    new_x = torch.empty_like(outs)
+
+    new_x[idxs] = outs
+    final_out = (
+        new_x.view(*topk_ids.shape, -1)
+        .type(topk_weights.dtype)
+        .mul_(topk_weights.unsqueeze(dim=-1))
+        .sum(dim=1)
+        .type(new_x.dtype)
+    )
+    return final_out
sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json
ADDED
@@ -0,0 +1,164 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "2": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "4": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "8": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "16": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "24": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 4,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "32": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 8,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "48": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 4,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "64": {
+        "BLOCK_SIZE_M": 256,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "96": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 8,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "128": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 16,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 4,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 16,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 4,
+        "num_warps": 8,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 8,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 4,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    }
+}
sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
CHANGED
@@ -15,15 +15,18 @@ from vllm import _custom_ops as ops
 
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
-from sglang.srt.utils import direct_register_custom_op, get_device_name, is_hip
+from sglang.srt.utils import (
+    direct_register_custom_op,
+    get_device_name,
+    is_cuda_available,
+    is_hip,
+)
 
-
-if not is_hip():
+is_cuda = is_cuda_available()
+is_hip_flag = is_hip()
+if is_cuda:
     from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
 
-    is_hip_flag = False
-else:
-    is_hip_flag = True
 
 logger = logging.getLogger(__name__)
 padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
@@ -708,6 +711,7 @@ def inplace_fused_experts(
     w2: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    activation: str = "silu",
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
@@ -723,6 +727,7 @@ def inplace_fused_experts(
         topk_weights,
         topk_ids,
         True,
+        activation,
         use_fp8_w8a8,
         use_int8_w8a16,
         w1_scale,
@@ -739,6 +744,7 @@ def inplace_fused_experts_fake(
     w2: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    activation: str = "silu",
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
@@ -764,6 +770,7 @@ def outplace_fused_experts(
     w2: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    activation: str = "silu",
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
@@ -779,6 +786,7 @@ def outplace_fused_experts(
         topk_weights,
         topk_ids,
         False,
+        activation,
         use_fp8_w8a8,
         use_int8_w8a16,
         w1_scale,
@@ -795,6 +803,7 @@ def outplace_fused_experts_fake(
     w2: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    activation: str = "silu",
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
@@ -821,6 +830,7 @@ def fused_experts(
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
     inplace: bool = False,
+    activation: str = "silu",
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
@@ -836,6 +846,7 @@ def fused_experts(
             w2,
             topk_weights,
             topk_ids,
+            activation,
             use_fp8_w8a8,
             use_int8_w8a16,
             w1_scale,
@@ -852,6 +863,7 @@ def fused_experts(
             w2,
             topk_weights,
             topk_ids,
+            activation,
             use_fp8_w8a8,
             use_int8_w8a16,
             w1_scale,
@@ -869,6 +881,7 @@ def fused_experts_impl(
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
     inplace: bool = False,
+    activation: str = "silu",
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
@@ -983,7 +996,12 @@ def fused_experts_impl(
             block_shape=block_shape,
         )
 
-        ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+        if activation == "silu":
+            ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+        elif activation == "gelu":
+            ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+        else:
+            raise ValueError(f"Unsupported activation: {activation=}")
 
         invoke_fused_moe_kernel(
             intermediate_cache2,
@@ -1039,6 +1057,7 @@ def fused_moe(
     topk: int,
     renormalize: bool,
     inplace: bool = False,
+    activation: str = "silu",
     use_grouped_topk: bool = False,
     num_expert_group: Optional[int] = None,
     topk_group: Optional[int] = None,
@@ -1108,6 +1127,7 @@ def fused_moe(
         topk_weights,
         topk_ids,
         inplace=inplace,
+        activation=activation,
         use_fp8_w8a8=use_fp8_w8a8,
         use_int8_w8a16=use_int8_w8a16,
         w1_scale=w1_scale,
sglang/srt/layers/moe/fused_moe_triton/layer.py
CHANGED
@@ -5,14 +5,15 @@ from enum import Enum
 from typing import Callable, List, Optional, Tuple
 
 import torch
-from vllm.distributed import (
+from vllm.model_executor.custom_op import CustomOp
+
+from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.custom_op import CustomOp
-
 from sglang.srt.layers.custom_op_util import register_custom_op
+from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
@@ -125,6 +126,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
     ) -> torch.Tensor:
         return self.forward(
             x=x,
@@ -137,6 +139,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             correction_bias=correction_bias,
+            activation=activation,
         )
 
     def forward_cuda(
@@ -151,6 +154,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
     ) -> torch.Tensor:
         topk_weights, topk_ids = select_experts(
             hidden_states=x,
@@ -168,6 +172,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             import ater
             from ater.fused_moe import fused_experts_ck
 
+            assert activation == "silu", f"{activation=} is not supported."
+
             return fused_experts_ck(
                 hidden_states=x,
                 w1=layer.w13_weight,
@@ -183,10 +189,34 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             inplace=True,
+            activation=activation,
         )
 
-    def forward_cpu(self, *args, **kwargs) -> torch.Tensor:
-        raise NotImplementedError("The CPU backend currently does not support MoE.")
+    def forward_cpu(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        use_grouped_topk: bool,
+        top_k: int,
+        router_logits: torch.Tensor,
+        renormalize: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+        correction_bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        return moe_forward_native(
+            layer,
+            x,
+            use_grouped_topk,
+            top_k,
+            router_logits,
+            renormalize,
+            topk_group,
+            num_expert_group,
+            custom_routing_function,
+            correction_bias,
+        )
 
     def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
         raise NotImplementedError("The TPU backend currently does not support MoE.")
@@ -232,6 +262,7 @@ class FusedMoE(torch.nn.Module):
         prefix: str = "",
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
         use_presharded_weights: bool = False,
     ):
         super().__init__()
@@ -255,6 +286,7 @@ class FusedMoE(torch.nn.Module):
         self.topk_group = topk_group
         self.custom_routing_function = custom_routing_function
         self.correction_bias = correction_bias
+        self.activation = activation
 
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = (
@@ -565,6 +597,7 @@ class FusedMoE(torch.nn.Module):
             num_expert_group=self.num_expert_group,
             custom_routing_function=self.custom_routing_function,
             correction_bias=self.correction_bias,
+            activation=self.activation,
         )
 
         if self.reduce_results and self.tp_size > 1:
sglang/srt/layers/parameter.py
CHANGED
@@ -6,7 +6,8 @@ from typing import Callable, Optional, Union
 
 import torch
 from torch.nn import Parameter
-from vllm.distributed import get_tensor_model_parallel_rank
+
+from sglang.srt.distributed import get_tensor_model_parallel_rank
 
 __all__ = [
     "BasevLLMParameter",
@@ -123,7 +124,13 @@ class _ColumnvLLMParameter(BasevLLMParameter):
         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)
 
-    def load_qkv_weight(self, loaded_weight: torch.Tensor, tp_rank: int, **kwargs):
+    def load_qkv_weight(
+        self,
+        loaded_weight: torch.Tensor,
+        tp_rank: int,
+        use_presharded_weights: bool = False,
+        **kwargs,
+    ):
 
         shard_offset = kwargs.get("shard_offset")
         shard_size = kwargs.get("shard_size")
@@ -141,11 +148,14 @@ class _ColumnvLLMParameter(BasevLLMParameter):
         param_data = self.data
         shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads
         param_data = param_data.narrow(self.output_dim, shard_offset, shard_size)
-        loaded_weight = loaded_weight.narrow(
-            self.output_dim, shard_id * shard_size, shard_size
-        )
+        if not use_presharded_weights:
+            loaded_weight = loaded_weight.narrow(
+                self.output_dim, shard_id * shard_size, shard_size
+            )
 
-        assert param_data.shape == loaded_weight.shape
+        assert (
+            param_data.shape == loaded_weight.shape
+        ), f"{param_data.shape=}, {loaded_weight.shape=}"
         param_data.copy_(loaded_weight)
 
 
@@ -291,7 +301,7 @@ class PackedColumnParameter(_ColumnvLLMParameter):
         packed_factor: Union[int, Fraction],
         packed_dim: int,
         marlin_tile_size: Optional[int] = None,
-        **kwargs
+        **kwargs,
     ):
         self._packed_factor = packed_factor
         self._packed_dim = packed_dim
@@ -335,7 +345,7 @@ class PackedvLLMParameter(ModelWeightParameter):
         packed_factor: Union[int, Fraction],
         packed_dim: int,
         marlin_tile_size: Optional[int] = None,
-        **kwargs
+        **kwargs,
     ):
         self._packed_factor = packed_factor
         self._packed_dim = packed_dim