sglang 0.3.1.post2-py3-none-any.whl → 0.3.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_latency.py +12 -11
- sglang/bench_server_latency.py +0 -6
- sglang/srt/hf_transformers_utils.py +1 -0
- sglang/srt/layers/activation.py +3 -2
- sglang/srt/layers/attention_backend.py +6 -12
- sglang/srt/layers/fused_moe/patch.py +117 -0
- sglang/srt/layers/linear.py +1133 -0
- sglang/srt/layers/quantization/__init__.py +76 -0
- sglang/srt/layers/quantization/base_config.py +122 -0
- sglang/srt/managers/schedule_batch.py +3 -5
- sglang/srt/managers/tokenizer_manager.py +1 -0
- sglang/srt/managers/tp_worker.py +1 -1
- sglang/srt/mem_cache/radix_cache.py +5 -5
- sglang/srt/model_executor/cuda_graph_runner.py +10 -6
- sglang/srt/model_executor/forward_batch_info.py +2 -4
- sglang/srt/model_executor/model_runner.py +0 -3
- sglang/srt/models/baichuan.py +1 -1
- sglang/srt/models/chatglm.py +6 -6
- sglang/srt/models/commandr.py +7 -7
- sglang/srt/models/dbrx.py +7 -7
- sglang/srt/models/deepseek.py +7 -7
- sglang/srt/models/deepseek_v2.py +7 -7
- sglang/srt/models/exaone.py +6 -6
- sglang/srt/models/gemma.py +6 -6
- sglang/srt/models/gemma2.py +6 -6
- sglang/srt/models/gpt_bigcode.py +6 -6
- sglang/srt/models/grok.py +6 -6
- sglang/srt/models/internlm2.py +6 -6
- sglang/srt/models/llama.py +14 -6
- sglang/srt/models/llama_classification.py +1 -1
- sglang/srt/models/llava.py +1 -1
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpm.py +6 -6
- sglang/srt/models/minicpm3.py +1 -1
- sglang/srt/models/mixtral.py +6 -6
- sglang/srt/models/mixtral_quant.py +6 -6
- sglang/srt/models/olmoe.py +1 -1
- sglang/srt/models/qwen.py +6 -6
- sglang/srt/models/qwen2.py +6 -6
- sglang/srt/models/qwen2_moe.py +7 -7
- sglang/srt/models/stablelm.py +6 -6
- sglang/srt/models/xverse.py +1 -1
- sglang/srt/models/xverse_moe.py +1 -1
- sglang/srt/models/yivl.py +1 -1
- sglang/srt/openai_api/adapter.py +7 -0
- sglang/srt/utils.py +21 -1
- sglang/test/runners.py +7 -9
- sglang/test/test_utils.py +39 -2
- sglang/version.py +1 -1
- {sglang-0.3.1.post2.dist-info → sglang-0.3.2.dist-info}/METADATA +8 -6
- {sglang-0.3.1.post2.dist-info → sglang-0.3.2.dist-info}/RECORD +54 -50
- {sglang-0.3.1.post2.dist-info → sglang-0.3.2.dist-info}/LICENSE +0 -0
- {sglang-0.3.1.post2.dist-info → sglang-0.3.2.dist-info}/WHEEL +0 -0
- {sglang-0.3.1.post2.dist-info → sglang-0.3.2.dist-info}/top_level.txt +0 -0
sglang/bench_latency.py
CHANGED
@@ -64,8 +64,13 @@ from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.sampling.sampling_params import SamplingParams
+from sglang.srt.server import _set_envs_and_config
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    configure_logger,
+    kill_child_process,
+    suppress_other_loggers,
+)
 
 
 @dataclasses.dataclass
@@ -255,7 +260,7 @@ def correctness_test(
 
     # Decode
     output_ids = [input_ids[i] + [next_token_ids[i]] for i in range(len(input_ids))]
-    for _ in range(bench_args.output_len[0]):
+    for _ in range(bench_args.output_len[0] - 1):
         next_token_ids, _ = decode(next_token_ids, batch, model_runner)
         for i in range(len(reqs)):
             output_ids[i].append(next_token_ids[i])
@@ -306,7 +311,7 @@ def latency_test_run_once(
 
     # Decode
     decode_latencies = []
-    for i in range(output_len):
+    for i in range(output_len - 1):
         torch.cuda.synchronize()
         tic = time.time()
         next_token_ids, _ = decode(next_token_ids, batch, model_runner)
@@ -341,6 +346,8 @@ def latency_test(
     bench_args,
     tp_rank,
 ):
+    configure_logger(server_args, prefix=f" TP{tp_rank}")
+    _set_envs_and_config(server_args)
     rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
 
     # Load the model
@@ -484,18 +491,10 @@ def main(server_args, bench_args):
 
 
 if __name__ == "__main__":
-    multiprocessing.set_start_method("spawn", force=True)
-
     parser = argparse.ArgumentParser()
     ServerArgs.add_cli_args(parser)
    BenchArgs.add_cli_args(parser)
-    # For this script, model-path is not required
-    assert (
-        parser._actions[1].option_strings[0] == "--model-path"
-    ), "options changed, this code need to be updated"
-    parser._actions[1].required = False
     args = parser.parse_args()
-
     server_args = ServerArgs.from_cli_args(args)
     bench_args = BenchArgs.from_cli_args(args)
 
@@ -504,6 +503,8 @@ if __name__ == "__main__":
         format="%(message)s",
     )
 
+    multiprocessing.set_start_method("spawn", force=True)
+
     try:
         main(server_args, bench_args)
     except Exception as e:
sglang/bench_server_latency.py
CHANGED
@@ -174,13 +174,7 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     ServerArgs.add_cli_args(parser)
     BenchArgs.add_cli_args(parser)
-    # For this script, model-path is not required
-    assert (
-        parser._actions[1].option_strings[0] == "--model-path"
-    ), "options changed, this code need to be updated"
-    parser._actions[1].required = False
     args = parser.parse_args()
-
     server_args = ServerArgs.from_cli_args(args)
     bench_args = BenchArgs.from_cli_args(args)
 
sglang/srt/layers/activation.py
CHANGED
@@ -31,8 +31,9 @@ from vllm.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from vllm.model_executor.custom_op import CustomOp
-
-from
+
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.utils import set_weight_attrs
 
 logger = logging.getLogger(__name__)
 
sglang/srt/layers/attention_backend.py
CHANGED
@@ -86,17 +86,9 @@ class FlashInferAttnBackend(AttentionBackend):
         super().__init__()
         self.model_runner = model_runner
 
-        local_num_qo_heads = (
-            model_runner.model_config.num_attention_heads // model_runner.tp_size
-        )
-        local_num_kv_heads = model_runner.model_config.get_num_kv_heads(
-            model_runner.tp_size
-        )
-        if (
-            not _grouped_size_compiled_for_decode_kernels(
-                local_num_qo_heads, local_num_kv_heads
-            )
-            or local_num_qo_heads // local_num_kv_heads > 4
+        if not _grouped_size_compiled_for_decode_kernels(
+            model_runner.model_config.num_attention_heads // model_runner.tp_size,
+            model_runner.model_config.get_num_kv_heads(model_runner.tp_size),
         ):
             self.decode_use_tensor_cores = True
         else:
@@ -346,7 +338,9 @@ class TritonAttnBackend(AttentionBackend):
 
         self.decode_attention_fwd = decode_attention_fwd
         self.extend_attention_fwd = extend_attention_fwd
-        self.num_head =
+        self.num_head = (
+            model_runner.model_config.num_attention_heads // model_runner.tp_size
+        )
 
         if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
             self.reduce_dtype = torch.float32
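
For reference, the FlashInfer hunk drops the extra `local_num_qo_heads // local_num_kv_heads > 4` trigger, leaving only the `_grouped_size_compiled_for_decode_kernels` check to decide `decode_use_tensor_cores`. Below is a small sketch of the per-rank head arithmetic those arguments encode; the concrete numbers and the `max(1, ...)` KV-head floor are assumptions for illustration, not taken from the diff.

# Hypothetical numbers; only the division by tp_size appears in the diff.
num_attention_heads = 32   # total query/output heads in the model config
num_kv_heads = 8           # total key/value heads (grouped-query attention)
tp_size = 2                # tensor-parallel world size

local_num_qo_heads = num_attention_heads // tp_size         # 16 per rank
local_num_kv_heads = max(1, num_kv_heads // tp_size)        # assumed floor of 1
gqa_group_size = local_num_qo_heads // local_num_kv_heads   # 4 query heads per KV head

# 0.3.1.post2 also forced tensor-core decode when gqa_group_size > 4;
# 0.3.2 leaves the decision to _grouped_size_compiled_for_decode_kernels alone.
print(local_num_qo_heads, local_num_kv_heads, gqa_group_size)
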
sglang/srt/layers/fused_moe/patch.py
ADDED
@@ -0,0 +1,117 @@
+from typing import Optional
+
+import torch
+from torch.nn import functional as F
+
+
+def fused_topk_native(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+):
+    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
+    M, _ = hidden_states.shape
+    topk_weights = torch.empty(
+        M, topk, dtype=torch.float32, device=hidden_states.device
+    )
+    topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
+    topk_weights = F.softmax(gating_output.float(), dim=-1)
+    topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+    return topk_weights, topk_ids
+
+
+# This is used by the Deepseek-V2 model
+def grouped_topk(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    num_expert_group: int = 0,
+    topk_group: int = 0,
+):
+
+    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
+
+    scores = torch.softmax(gating_output, dim=-1)
+    num_token = scores.shape[0]
+    group_scores = (
+        scores.view(num_token, num_expert_group, -1).max(dim=-1).values
+    )  # [n, n_group]
+    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
+        1
+    ]  # [n, top_k_group]
+    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
+    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
+    score_mask = (
+        group_mask.unsqueeze(-1)
+        .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
+        .reshape(num_token, -1)
+    )  # [n, e]
+    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
+    topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
+
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+    return topk_weights, topk_ids
+
+
+def select_experts_native(
+    hidden_states: torch.Tensor,
+    router_logits: torch.Tensor,
+    top_k: int,
+    use_grouped_topk: bool,
+    renormalize: bool,
+    topk_group: Optional[int] = None,
+    num_expert_group: Optional[int] = None,
+):
+    # DeekSeekv2 uses grouped_top_k
+    if use_grouped_topk:
+        assert topk_group is not None
+        assert num_expert_group is not None
+        topk_weights, topk_ids = grouped_topk(
+            hidden_states=hidden_states,
+            gating_output=router_logits,
+            topk=top_k,
+            renormalize=renormalize,
+            num_expert_group=num_expert_group,
+            topk_group=topk_group,
+        )
+    else:
+        topk_weights, topk_ids = fused_topk_native(
+            hidden_states=hidden_states,
+            gating_output=router_logits,
+            topk=top_k,
+            renormalize=renormalize,
+        )
+    return topk_weights, topk_ids
+
+
+def fused_moe_forward_native(
+    layer: torch.nn.Module,
+    x: torch.Tensor,
+    use_grouped_topk: bool,
+    top_k: int,
+    router_logits: torch.Tensor,
+    renormalize: bool,
+    topk_group: Optional[int] = None,
+    num_expert_group: Optional[int] = None,
+) -> torch.Tensor:
+    topk_weights, topk_ids = select_experts_native(
+        hidden_states=x,
+        router_logits=router_logits,
+        use_grouped_topk=use_grouped_topk,
+        top_k=top_k,
+        renormalize=renormalize,
+        topk_group=topk_group,
+        num_expert_group=num_expert_group,
+    )
+    w13_weights = layer.w13_weight[topk_ids]
+    w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
+    w2_weights = layer.w2_weight[topk_ids]
+    x1 = F.silu(torch.einsum("ti,taoi -> tao", x, w1_weights))
+    x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
+    expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
+    return torch.einsum("tai,ta -> ti", expert_outs, topk_weights)
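
The new patch.py module provides a pure-PyTorch (non-fused) MoE forward path. A minimal usage sketch follows; ToyMoELayer and its sizes are invented for illustration, and the weight layout (w13_weight as [num_experts, 2 * intermediate, hidden], w2_weight as [num_experts, hidden, intermediate]) is inferred from the einsum subscripts in fused_moe_forward_native, not stated elsewhere in the diff.

import torch

from sglang.srt.layers.fused_moe.patch import fused_moe_forward_native

num_experts, top_k, hidden, intermediate = 8, 2, 16, 32

class ToyMoELayer(torch.nn.Module):
    """Hypothetical stand-in carrying only the weights the native path reads."""

    def __init__(self):
        super().__init__()
        # Gate (w1) and up (w3) projections stacked along dim 1.
        self.w13_weight = torch.nn.Parameter(
            torch.randn(num_experts, 2 * intermediate, hidden)
        )
        # Down projection back to the hidden size.
        self.w2_weight = torch.nn.Parameter(
            torch.randn(num_experts, hidden, intermediate)
        )

layer = ToyMoELayer()
x = torch.randn(4, hidden)                   # 4 tokens
router_logits = torch.randn(4, num_experts)  # one logit per expert and token

out = fused_moe_forward_native(
    layer,
    x,
    use_grouped_topk=False,
    top_k=top_k,
    router_logits=router_logits,
    renormalize=True,
)
assert out.shape == x.shape  # [tokens, hidden]
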