sglang 0.3.1.post2__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. sglang/bench_latency.py +12 -11
  2. sglang/bench_server_latency.py +0 -6
  3. sglang/srt/hf_transformers_utils.py +1 -0
  4. sglang/srt/layers/activation.py +3 -2
  5. sglang/srt/layers/attention_backend.py +6 -12
  6. sglang/srt/layers/fused_moe/patch.py +117 -0
  7. sglang/srt/layers/linear.py +1133 -0
  8. sglang/srt/layers/quantization/__init__.py +76 -0
  9. sglang/srt/layers/quantization/base_config.py +122 -0
  10. sglang/srt/managers/schedule_batch.py +3 -5
  11. sglang/srt/managers/tokenizer_manager.py +1 -0
  12. sglang/srt/managers/tp_worker.py +1 -1
  13. sglang/srt/mem_cache/radix_cache.py +5 -5
  14. sglang/srt/model_executor/cuda_graph_runner.py +10 -6
  15. sglang/srt/model_executor/forward_batch_info.py +2 -4
  16. sglang/srt/model_executor/model_runner.py +0 -3
  17. sglang/srt/models/baichuan.py +1 -1
  18. sglang/srt/models/chatglm.py +6 -6
  19. sglang/srt/models/commandr.py +7 -7
  20. sglang/srt/models/dbrx.py +7 -7
  21. sglang/srt/models/deepseek.py +7 -7
  22. sglang/srt/models/deepseek_v2.py +7 -7
  23. sglang/srt/models/exaone.py +6 -6
  24. sglang/srt/models/gemma.py +6 -6
  25. sglang/srt/models/gemma2.py +6 -6
  26. sglang/srt/models/gpt_bigcode.py +6 -6
  27. sglang/srt/models/grok.py +6 -6
  28. sglang/srt/models/internlm2.py +6 -6
  29. sglang/srt/models/llama.py +14 -6
  30. sglang/srt/models/llama_classification.py +1 -1
  31. sglang/srt/models/llava.py +1 -1
  32. sglang/srt/models/llavavid.py +1 -1
  33. sglang/srt/models/minicpm.py +6 -6
  34. sglang/srt/models/minicpm3.py +1 -1
  35. sglang/srt/models/mixtral.py +6 -6
  36. sglang/srt/models/mixtral_quant.py +6 -6
  37. sglang/srt/models/olmoe.py +1 -1
  38. sglang/srt/models/qwen.py +6 -6
  39. sglang/srt/models/qwen2.py +6 -6
  40. sglang/srt/models/qwen2_moe.py +7 -7
  41. sglang/srt/models/stablelm.py +6 -6
  42. sglang/srt/models/xverse.py +1 -1
  43. sglang/srt/models/xverse_moe.py +1 -1
  44. sglang/srt/models/yivl.py +1 -1
  45. sglang/srt/openai_api/adapter.py +7 -0
  46. sglang/srt/utils.py +21 -1
  47. sglang/test/runners.py +7 -9
  48. sglang/test/test_utils.py +39 -2
  49. sglang/version.py +1 -1
  50. {sglang-0.3.1.post2.dist-info → sglang-0.3.2.dist-info}/METADATA +8 -6
  51. {sglang-0.3.1.post2.dist-info → sglang-0.3.2.dist-info}/RECORD +54 -50
  52. {sglang-0.3.1.post2.dist-info → sglang-0.3.2.dist-info}/LICENSE +0 -0
  53. {sglang-0.3.1.post2.dist-info → sglang-0.3.2.dist-info}/WHEEL +0 -0
  54. {sglang-0.3.1.post2.dist-info → sglang-0.3.2.dist-info}/top_level.txt +0 -0
sglang/bench_latency.py CHANGED
@@ -64,8 +64,13 @@ from sglang.srt.hf_transformers_utils import get_tokenizer
  from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
  from sglang.srt.model_executor.model_runner import ModelRunner
  from sglang.srt.sampling.sampling_params import SamplingParams
+ from sglang.srt.server import _set_envs_and_config
  from sglang.srt.server_args import ServerArgs
- from sglang.srt.utils import kill_child_process, suppress_other_loggers
+ from sglang.srt.utils import (
+     configure_logger,
+     kill_child_process,
+     suppress_other_loggers,
+ )


  @dataclasses.dataclass
@@ -255,7 +260,7 @@ def correctness_test(

      # Decode
      output_ids = [input_ids[i] + [next_token_ids[i]] for i in range(len(input_ids))]
-     for _ in range(bench_args.output_len[0]):
+     for _ in range(bench_args.output_len[0] - 1):
          next_token_ids, _ = decode(next_token_ids, batch, model_runner)
          for i in range(len(reqs)):
              output_ids[i].append(next_token_ids[i])
@@ -306,7 +311,7 @@ def latency_test_run_once(

      # Decode
      decode_latencies = []
-     for i in range(output_len):
+     for i in range(output_len - 1):
          torch.cuda.synchronize()
          tic = time.time()
          next_token_ids, _ = decode(next_token_ids, batch, model_runner)
@@ -341,6 +346,8 @@ def latency_test(
      bench_args,
      tp_rank,
  ):
+     configure_logger(server_args, prefix=f" TP{tp_rank}")
+     _set_envs_and_config(server_args)
      rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

      # Load the model
@@ -484,18 +491,10 @@ def main(server_args, bench_args):


  if __name__ == "__main__":
-     multiprocessing.set_start_method("spawn", force=True)
-
      parser = argparse.ArgumentParser()
      ServerArgs.add_cli_args(parser)
      BenchArgs.add_cli_args(parser)
-     # For this script, model-path is not required
-     assert (
-         parser._actions[1].option_strings[0] == "--model-path"
-     ), "options changed, this code need to be updated"
-     parser._actions[1].required = False
      args = parser.parse_args()
-
      server_args = ServerArgs.from_cli_args(args)
      bench_args = BenchArgs.from_cli_args(args)

@@ -504,6 +503,8 @@ if __name__ == "__main__":
          format="%(message)s",
      )

+     multiprocessing.set_start_method("spawn", force=True)
+
      try:
          main(server_args, bench_args)
      except Exception as e:
sglang/bench_server_latency.py CHANGED
@@ -174,13 +174,7 @@ if __name__ == "__main__":
      parser = argparse.ArgumentParser()
      ServerArgs.add_cli_args(parser)
      BenchArgs.add_cli_args(parser)
-     # For this script, model-path is not required
-     assert (
-         parser._actions[1].option_strings[0] == "--model-path"
-     ), "options changed, this code need to be updated"
-     parser._actions[1].required = False
      args = parser.parse_args()
-
      server_args = ServerArgs.from_cli_args(args)
      bench_args = BenchArgs.from_cli_args(args)

sglang/srt/hf_transformers_utils.py CHANGED
@@ -129,6 +129,7 @@ def get_tokenizer(
              *args,
              trust_remote_code=trust_remote_code,
              tokenizer_revision=tokenizer_revision,
+             clean_up_tokenization_spaces=False,
              **kwargs,
          )
      except TypeError as e:
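
The get_tokenizer change passes clean_up_tokenization_spaces=False, which disables the Hugging Face post-decode cleanup that collapses spaces around punctuation. A minimal illustration of what that flag controls, assuming a locally available Hugging Face tokenizer ("gpt2" is only an example, not a model referenced by this diff):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2", clean_up_tokenization_spaces=False)
ids = tok.encode("Hello , world !")

# With cleanup enabled, decode collapses the space before punctuation;
# with it disabled, the round trip preserves the original spacing.
print(tok.decode(ids, clean_up_tokenization_spaces=True))   # Hello, world!
print(tok.decode(ids, clean_up_tokenization_spaces=False))  # Hello , world !
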
sglang/srt/layers/activation.py CHANGED
@@ -31,8 +31,9 @@ from vllm.distributed import (
      get_tensor_model_parallel_world_size,
  )
  from vllm.model_executor.custom_op import CustomOp
- from vllm.model_executor.layers.quantization import QuantizationConfig
- from vllm.model_executor.utils import set_weight_attrs
+
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
+ from sglang.srt.utils import set_weight_attrs

  logger = logging.getLogger(__name__)

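Both helpers now come from sglang itself instead of vLLM. A rough sketch of how set_weight_attrs is typically used, assuming the sglang.srt.utils version keeps the same shape as the vLLM utility it replaces (a tensor plus a dict of attributes to attach); the weight_loader below is purely hypothetical:

import torch

from sglang.srt.utils import set_weight_attrs


def my_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
    # Hypothetical loader that just copies the checkpoint tensor in place.
    param.data.copy_(loaded_weight)


weight = torch.nn.Parameter(torch.empty(8, 8))
# Attach extra metadata to the parameter so the model loader can find it later.
set_weight_attrs(weight, {"weight_loader": my_weight_loader})
assert weight.weight_loader is my_weight_loader
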
sglang/srt/layers/attention_backend.py CHANGED
@@ -86,17 +86,9 @@ class FlashInferAttnBackend(AttentionBackend):
          super().__init__()
          self.model_runner = model_runner

-         local_num_qo_heads = (
-             model_runner.model_config.num_attention_heads // model_runner.tp_size
-         )
-         local_num_kv_heads = model_runner.model_config.get_num_kv_heads(
-             model_runner.tp_size
-         )
-         if (
-             not _grouped_size_compiled_for_decode_kernels(
-                 local_num_qo_heads, local_num_kv_heads
-             )
-             or local_num_qo_heads // local_num_kv_heads > 4
+         if not _grouped_size_compiled_for_decode_kernels(
+             model_runner.model_config.num_attention_heads // model_runner.tp_size,
+             model_runner.model_config.get_num_kv_heads(model_runner.tp_size),
          ):
              self.decode_use_tensor_cores = True
          else:
@@ -346,7 +338,9 @@ class TritonAttnBackend(AttentionBackend):

          self.decode_attention_fwd = decode_attention_fwd
          self.extend_attention_fwd = extend_attention_fwd
-         self.num_head = model_runner.model_config.num_attention_heads
+         self.num_head = (
+             model_runner.model_config.num_attention_heads // model_runner.tp_size
+         )

          if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
              self.reduce_dtype = torch.float32
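
Both attention backends now size their heads per tensor-parallel rank instead of using the global head count. Illustrative arithmetic only, with a made-up GQA configuration rather than numbers taken from this diff:

# Hypothetical GQA configuration: 64 query heads, 8 KV heads, tensor parallel size 4.
num_attention_heads = 64
num_kv_heads = 8
tp_size = 4

local_num_qo_heads = num_attention_heads // tp_size     # 16 query heads per rank
local_num_kv_heads = num_kv_heads // tp_size            # 2 KV heads per rank
group_size = local_num_qo_heads // local_num_kv_heads   # 8 query heads share one KV head

print(local_num_qo_heads, local_num_kv_heads, group_size)  # 16 2 8
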
sglang/srt/layers/fused_moe/patch.py ADDED
@@ -0,0 +1,117 @@
+ from typing import Optional
+
+ import torch
+ from torch.nn import functional as F
+
+
+ def fused_topk_native(
+     hidden_states: torch.Tensor,
+     gating_output: torch.Tensor,
+     topk: int,
+     renormalize: bool,
+ ):
+     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
+     M, _ = hidden_states.shape
+     topk_weights = torch.empty(
+         M, topk, dtype=torch.float32, device=hidden_states.device
+     )
+     topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
+     topk_weights = F.softmax(gating_output.float(), dim=-1)
+     topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)
+     if renormalize:
+         topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+     return topk_weights, topk_ids
+
+
+ # This is used by the Deepseek-V2 model
+ def grouped_topk(
+     hidden_states: torch.Tensor,
+     gating_output: torch.Tensor,
+     topk: int,
+     renormalize: bool,
+     num_expert_group: int = 0,
+     topk_group: int = 0,
+ ):
+
+     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
+
+     scores = torch.softmax(gating_output, dim=-1)
+     num_token = scores.shape[0]
+     group_scores = (
+         scores.view(num_token, num_expert_group, -1).max(dim=-1).values
+     )  # [n, n_group]
+     group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
+         1
+     ]  # [n, top_k_group]
+     group_mask = torch.zeros_like(group_scores)  # [n, n_group]
+     group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
+     score_mask = (
+         group_mask.unsqueeze(-1)
+         .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
+         .reshape(num_token, -1)
+     )  # [n, e]
+     tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
+     topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
+
+     if renormalize:
+         topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+     return topk_weights, topk_ids
+
+
+ def select_experts_native(
+     hidden_states: torch.Tensor,
+     router_logits: torch.Tensor,
+     top_k: int,
+     use_grouped_topk: bool,
+     renormalize: bool,
+     topk_group: Optional[int] = None,
+     num_expert_group: Optional[int] = None,
+ ):
+     # DeekSeekv2 uses grouped_top_k
+     if use_grouped_topk:
+         assert topk_group is not None
+         assert num_expert_group is not None
+         topk_weights, topk_ids = grouped_topk(
+             hidden_states=hidden_states,
+             gating_output=router_logits,
+             topk=top_k,
+             renormalize=renormalize,
+             num_expert_group=num_expert_group,
+             topk_group=topk_group,
+         )
+     else:
+         topk_weights, topk_ids = fused_topk_native(
+             hidden_states=hidden_states,
+             gating_output=router_logits,
+             topk=top_k,
+             renormalize=renormalize,
+         )
+     return topk_weights, topk_ids
+
+
+ def fused_moe_forward_native(
+     layer: torch.nn.Module,
+     x: torch.Tensor,
+     use_grouped_topk: bool,
+     top_k: int,
+     router_logits: torch.Tensor,
+     renormalize: bool,
+     topk_group: Optional[int] = None,
+     num_expert_group: Optional[int] = None,
+ ) -> torch.Tensor:
+     topk_weights, topk_ids = select_experts_native(
+         hidden_states=x,
+         router_logits=router_logits,
+         use_grouped_topk=use_grouped_topk,
+         top_k=top_k,
+         renormalize=renormalize,
+         topk_group=topk_group,
+         num_expert_group=num_expert_group,
+     )
+     w13_weights = layer.w13_weight[topk_ids]
+     w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
+     w2_weights = layer.w2_weight[topk_ids]
+     x1 = F.silu(torch.einsum("ti,taoi -> tao", x, w1_weights))
+     x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
+     expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
+     return torch.einsum("tai,ta -> ti", expert_outs, topk_weights)
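
A quick, hypothetical smoke test for the new fused_moe_forward_native. The stand-in layer only provides the w13_weight and w2_weight attributes the function reads, and the shapes follow the usual fused-MoE layout (gate and up projections stacked in w13_weight, down projection in w2_weight), which is an assumption here rather than something stated in the diff:

import torch

from sglang.srt.layers.fused_moe.patch import fused_moe_forward_native

num_tokens, hidden, intermediate, num_experts, top_k = 4, 16, 32, 8, 2

# Stand-in module exposing just the attributes fused_moe_forward_native indexes.
layer = torch.nn.Module()
layer.w13_weight = torch.nn.Parameter(
    torch.randn(num_experts, 2 * intermediate, hidden)  # w1 (gate) and w3 (up) stacked
)
layer.w2_weight = torch.nn.Parameter(
    torch.randn(num_experts, hidden, intermediate)  # w2 (down) projection
)

x = torch.randn(num_tokens, hidden)
router_logits = torch.randn(num_tokens, num_experts)

out = fused_moe_forward_native(
    layer,
    x,
    use_grouped_topk=False,
    top_k=top_k,
    router_logits=router_logits,
    renormalize=True,
)
print(out.shape)  # torch.Size([4, 16]): one hidden-size vector per token

Because it gathers each selected expert's full weight matrices per token and relies on plain einsum, this native path favors simplicity over speed.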