sglang 0.4.0.post1__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- sglang/bench_offline_throughput.py +6 -6
- sglang/bench_one_batch.py +1 -0
- sglang/bench_serving.py +9 -1
- sglang/check_env.py +140 -48
- sglang/lang/backend/runtime_endpoint.py +1 -0
- sglang/lang/chat_template.py +32 -0
- sglang/llama3_eval.py +316 -0
- sglang/srt/aio_rwlock.py +100 -0
- sglang/srt/configs/model_config.py +8 -1
- sglang/srt/constrained/xgrammar_backend.py +4 -1
- sglang/srt/layers/attention/flashinfer_backend.py +51 -5
- sglang/srt/layers/attention/triton_backend.py +16 -25
- sglang/srt/layers/attention/triton_ops/decode_attention.py +305 -350
- sglang/srt/layers/linear.py +20 -2
- sglang/srt/layers/logits_processor.py +133 -95
- sglang/srt/layers/{ep_moe → moe/ep_moe}/layer.py +18 -39
- sglang/srt/layers/moe/fused_moe_native.py +46 -0
- sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/__init__.py +3 -7
- sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/fused_moe.py +174 -119
- sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/layer.py +17 -49
- sglang/srt/layers/moe/topk.py +191 -0
- sglang/srt/layers/quantization/__init__.py +5 -50
- sglang/srt/layers/quantization/fp8.py +221 -36
- sglang/srt/layers/quantization/fp8_kernel.py +278 -0
- sglang/srt/layers/quantization/fp8_utils.py +90 -1
- sglang/srt/layers/radix_attention.py +8 -1
- sglang/srt/layers/sampler.py +27 -5
- sglang/srt/layers/torchao_utils.py +31 -0
- sglang/srt/managers/detokenizer_manager.py +37 -17
- sglang/srt/managers/io_struct.py +39 -10
- sglang/srt/managers/schedule_batch.py +54 -34
- sglang/srt/managers/schedule_policy.py +64 -5
- sglang/srt/managers/scheduler.py +171 -136
- sglang/srt/managers/tokenizer_manager.py +184 -133
- sglang/srt/mem_cache/base_prefix_cache.py +2 -2
- sglang/srt/mem_cache/chunk_cache.py +2 -2
- sglang/srt/mem_cache/memory_pool.py +15 -8
- sglang/srt/mem_cache/radix_cache.py +12 -2
- sglang/srt/model_executor/cuda_graph_runner.py +25 -11
- sglang/srt/model_executor/model_runner.py +28 -14
- sglang/srt/model_parallel.py +66 -5
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/deepseek.py +1 -1
- sglang/srt/models/deepseek_v2.py +67 -18
- sglang/srt/models/gemma2.py +34 -0
- sglang/srt/models/gemma2_reward.py +0 -1
- sglang/srt/models/granite.py +517 -0
- sglang/srt/models/grok.py +73 -9
- sglang/srt/models/llama.py +22 -0
- sglang/srt/models/llama_classification.py +11 -23
- sglang/srt/models/llama_reward.py +0 -2
- sglang/srt/models/llava.py +37 -14
- sglang/srt/models/mixtral.py +2 -2
- sglang/srt/models/olmoe.py +1 -1
- sglang/srt/models/qwen2.py +20 -0
- sglang/srt/models/qwen2_moe.py +1 -1
- sglang/srt/models/xverse_moe.py +1 -1
- sglang/srt/openai_api/adapter.py +8 -0
- sglang/srt/openai_api/protocol.py +9 -4
- sglang/srt/server.py +2 -1
- sglang/srt/server_args.py +19 -9
- sglang/srt/utils.py +40 -54
- sglang/test/test_block_fp8.py +341 -0
- sglang/test/test_utils.py +3 -2
- sglang/utils.py +10 -3
- sglang/version.py +1 -1
- {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/METADATA +12 -7
- {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/RECORD +73 -67
- sglang/srt/layers/fused_moe_patch.py +0 -133
- /sglang/srt/layers/{ep_moe → moe/ep_moe}/__init__.py +0 -0
- /sglang/srt/layers/{ep_moe → moe/ep_moe}/kernels.py +0 -0
- {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/LICENSE +0 -0
- {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/WHEEL +0 -0
- {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/topk.py (new file)

```diff
@@ -0,0 +1,191 @@
+from typing import Callable, Optional
+
+import torch
+import torch.nn.functional as F
+
+
+def fused_topk_native(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+):
+    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
+    M, _ = hidden_states.shape
+    topk_weights = torch.empty(
+        M, topk, dtype=torch.float32, device=hidden_states.device
+    )
+    topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
+    topk_weights = F.softmax(gating_output.float(), dim=-1)
+    topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+    return topk_weights, topk_ids
+
+
+def fused_topk(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+):
+    from vllm import _custom_ops as ops
+
+    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
+
+    M, _ = hidden_states.shape
+
+    topk_weights = torch.empty(
+        M, topk, dtype=torch.float32, device=hidden_states.device
+    )
+    topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
+    token_expert_indicies = torch.empty(
+        M, topk, dtype=torch.int32, device=hidden_states.device
+    )
+
+    ops.topk_softmax(
+        topk_weights,
+        topk_ids,
+        token_expert_indicies,
+        gating_output.float(),
+    )
+    del token_expert_indicies
+
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+
+    return topk_weights, topk_ids
+
+
+# This is used by the Deepseek-V2 model
+def grouped_topk(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    num_expert_group: int = 0,
+    topk_group: int = 0,
+):
+    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
+
+    scores = torch.softmax(gating_output, dim=-1)
+    num_token = scores.shape[0]
+    group_scores = (
+        scores.view(num_token, num_expert_group, -1).max(dim=-1).values
+    )  # [n, n_group]
+    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
+        1
+    ]  # [n, top_k_group]
+    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
+    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
+    score_mask = (
+        group_mask.unsqueeze(-1)
+        .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
+        .reshape(num_token, -1)
+    )  # [n, e]
+    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
+    topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
+
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+
+    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
+
+
+def biased_grouped_topk(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    correction_bias: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    num_expert_group: int = 0,
+    topk_group: int = 0,
+):
+    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
+
+    scores = gating_output.sigmoid()
+    num_token = scores.shape[0]
+    scores_for_choice = scores.view(num_token, -1) + correction_bias.unsqueeze(0)
+    group_scores = (
+        scores_for_choice.view(num_token, num_expert_group, -1)
+        .topk(2, dim=-1)[0]
+        .sum(dim=-1)
+    )  # [n, n_group]
+    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
+        1
+    ]  # [n, top_k_group]
+    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
+    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
+    score_mask = (
+        group_mask.unsqueeze(-1)
+        .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
+        .reshape(num_token, -1)
+    )  # [n, e]
+    tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
+    _, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
+    topk_weights = scores.gather(1, topk_ids)
+
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+
+    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
+
+
+def select_experts(
+    hidden_states: torch.Tensor,
+    router_logits: torch.Tensor,
+    top_k: int,
+    use_grouped_topk: bool,
+    renormalize: bool,
+    topk_group: Optional[int] = None,
+    num_expert_group: Optional[int] = None,
+    custom_routing_function: Optional[Callable] = None,
+    correction_bias: Optional[torch.Tensor] = None,
+    torch_native: bool = False,
+):
+    # DeekSeekv2 uses grouped_top_k
+    if use_grouped_topk:
+        assert topk_group is not None
+        assert num_expert_group is not None
+        if correction_bias is None:
+            topk_weights, topk_ids = grouped_topk(
+                hidden_states=hidden_states,
+                gating_output=router_logits,
+                topk=top_k,
+                renormalize=renormalize,
+                num_expert_group=num_expert_group,
+                topk_group=topk_group,
+            )
+        else:
+            topk_weights, topk_ids = biased_grouped_topk(
+                hidden_states=hidden_states,
+                gating_output=router_logits,
+                correction_bias=correction_bias,
+                topk=top_k,
+                renormalize=renormalize,
+                num_expert_group=num_expert_group,
+                topk_group=topk_group,
+            )
+    elif torch_native:
+        topk_weights, topk_ids = fused_topk_native(
+            hidden_states=hidden_states,
+            gating_output=router_logits,
+            topk=top_k,
+            renormalize=renormalize,
+        )
+    elif custom_routing_function is None:
+        topk_weights, topk_ids = fused_topk(
+            hidden_states=hidden_states,
+            gating_output=router_logits,
+            topk=top_k,
+            renormalize=renormalize,
+        )
+    else:
+        topk_weights, topk_ids = custom_routing_function(
+            hidden_states=hidden_states,
+            gating_output=router_logits,
+            topk=top_k,
+            renormalize=renormalize,
+        )

+    return topk_weights, topk_ids
```
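For orientation, here is a minimal sketch, not part of the diff, of how the new routing helpers in `sglang.srt.layers.moe.topk` might be exercised on dummy tensors; the token count, expert count, and group sizes below are illustrative assumptions:

```python
# Hypothetical usage sketch of the topk.py helpers added in 0.4.1; all tensor
# sizes are made up for illustration.
import torch

from sglang.srt.layers.moe.topk import fused_topk_native, grouped_topk

num_tokens, num_experts, top_k = 4, 8, 2
hidden_states = torch.randn(num_tokens, 16)           # [tokens, hidden_dim]
router_logits = torch.randn(num_tokens, num_experts)  # [tokens, num_experts]

# Pure-PyTorch path: softmax over the router logits, then top-k per token.
weights, ids = fused_topk_native(hidden_states, router_logits, top_k, renormalize=True)
print(weights.shape, ids.shape)  # torch.Size([4, 2]) torch.Size([4, 2])

# Group-limited routing (the DeepSeek-V2 path): experts are partitioned into
# num_expert_group groups, only the best topk_group groups survive, and the
# final top-k is taken among the surviving experts.
weights, ids = grouped_topk(
    hidden_states,
    router_logits,
    top_k,
    renormalize=True,
    num_expert_group=4,
    topk_group=2,
)
```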
sglang/srt/layers/quantization/__init__.py

```diff
@@ -22,7 +22,7 @@ from vllm.model_executor.layers.quantization.qqq import QQQConfig
 from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
 
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.layers.quantization.fp8 import Fp8Config
+from sglang.srt.layers.quantization.fp8 import Fp8Config
 
 QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "aqlm": AQLMConfig,
@@ -53,50 +53,6 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     return QUANTIZATION_METHODS[quantization]
 
 
-def fp8_moe_apply(
-    self,
-    layer: torch.nn.Module,
-    x: torch.Tensor,
-    router_logits: torch.Tensor,
-    top_k: int,
-    renormalize: bool,
-    use_grouped_topk: bool,
-    topk_group: Optional[int] = None,
-    num_expert_group: Optional[int] = None,
-    custom_routing_function: Optional[Callable] = None,
-) -> torch.Tensor:
-    """Enhanced apply method for FP8 MoE."""
-    from sglang.srt.layers.fused_moe_triton import FusedMoE
-    from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts
-
-    # Expert selection
-    topk_weights, topk_ids = FusedMoE.select_experts(
-        hidden_states=x,
-        router_logits=router_logits,
-        use_grouped_topk=use_grouped_topk,
-        top_k=top_k,
-        renormalize=renormalize,
-        topk_group=topk_group,
-        num_expert_group=num_expert_group,
-        custom_routing_function=custom_routing_function,
-    )
-
-    # Expert fusion with FP8 quantization
-    return fused_experts(
-        x,
-        layer.w13_weight,
-        layer.w2_weight,
-        topk_weights=topk_weights,
-        topk_ids=topk_ids,
-        inplace=True,
-        use_fp8_w8a8=True,
-        w1_scale=layer.w13_weight_scale,
-        w2_scale=layer.w2_weight_scale,
-        a1_scale=layer.w13_input_scale,
-        a2_scale=layer.w2_input_scale,
-    )
-
-
 def fp8_get_quant_method(self, layer, prefix):
     """Enhanced get_quant_method for FP8 config."""
     from vllm.model_executor.layers.linear import LinearBase
@@ -104,9 +60,9 @@ def fp8_get_quant_method(self, layer, prefix):
         is_layer_skipped,
     )
 
-    from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
     from sglang.srt.layers.linear import UnquantizedLinearMethod
-    from sglang.srt.layers.
+    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
+    from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod, Fp8MoEMethod
 
     if isinstance(layer, LinearBase):
         if is_layer_skipped(prefix, self.ignored_layers):
@@ -124,7 +80,7 @@ def gptq_get_quant_method(self, layer, prefix):
         GPTQMarlinMoEMethod,
     )
 
-    from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
+    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 
     if isinstance(layer, LinearBase):
         return GPTQMarlinLinearMethod(self)
@@ -140,7 +96,7 @@ def awq_get_quant_method(self, layer, prefix):
         AWQMoEMethod,
     )
 
-    from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
+    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 
     if isinstance(layer, LinearBase):
         return AWQMarlinLinearMethod(self)
@@ -151,7 +107,6 @@ def awq_get_quant_method(self, layer, prefix):
 
 def apply_monkey_patches():
     """Apply all monkey patches in one place."""
-    setattr(Fp8MoEMethod, "apply", fp8_moe_apply)
     setattr(Fp8Config, "get_quant_method", fp8_get_quant_method)
     setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
     setattr(AWQMarlinConfig, "get_quant_method", awq_get_quant_method)
```