sglang 0.4.0.post1__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. sglang/bench_offline_throughput.py +6 -6
  2. sglang/bench_one_batch.py +1 -0
  3. sglang/bench_serving.py +9 -1
  4. sglang/check_env.py +140 -48
  5. sglang/lang/backend/runtime_endpoint.py +1 -0
  6. sglang/lang/chat_template.py +32 -0
  7. sglang/llama3_eval.py +316 -0
  8. sglang/srt/aio_rwlock.py +100 -0
  9. sglang/srt/configs/model_config.py +8 -1
  10. sglang/srt/constrained/xgrammar_backend.py +4 -1
  11. sglang/srt/layers/attention/flashinfer_backend.py +51 -5
  12. sglang/srt/layers/attention/triton_backend.py +16 -25
  13. sglang/srt/layers/attention/triton_ops/decode_attention.py +305 -350
  14. sglang/srt/layers/linear.py +20 -2
  15. sglang/srt/layers/logits_processor.py +133 -95
  16. sglang/srt/layers/{ep_moe → moe/ep_moe}/layer.py +18 -39
  17. sglang/srt/layers/moe/fused_moe_native.py +46 -0
  18. sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/__init__.py +3 -7
  19. sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/fused_moe.py +174 -119
  20. sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/layer.py +17 -49
  21. sglang/srt/layers/moe/topk.py +191 -0
  22. sglang/srt/layers/quantization/__init__.py +5 -50
  23. sglang/srt/layers/quantization/fp8.py +221 -36
  24. sglang/srt/layers/quantization/fp8_kernel.py +278 -0
  25. sglang/srt/layers/quantization/fp8_utils.py +90 -1
  26. sglang/srt/layers/radix_attention.py +8 -1
  27. sglang/srt/layers/sampler.py +27 -5
  28. sglang/srt/layers/torchao_utils.py +31 -0
  29. sglang/srt/managers/detokenizer_manager.py +37 -17
  30. sglang/srt/managers/io_struct.py +39 -10
  31. sglang/srt/managers/schedule_batch.py +54 -34
  32. sglang/srt/managers/schedule_policy.py +64 -5
  33. sglang/srt/managers/scheduler.py +171 -136
  34. sglang/srt/managers/tokenizer_manager.py +184 -133
  35. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  36. sglang/srt/mem_cache/chunk_cache.py +2 -2
  37. sglang/srt/mem_cache/memory_pool.py +15 -8
  38. sglang/srt/mem_cache/radix_cache.py +12 -2
  39. sglang/srt/model_executor/cuda_graph_runner.py +25 -11
  40. sglang/srt/model_executor/model_runner.py +28 -14
  41. sglang/srt/model_parallel.py +66 -5
  42. sglang/srt/models/dbrx.py +1 -1
  43. sglang/srt/models/deepseek.py +1 -1
  44. sglang/srt/models/deepseek_v2.py +67 -18
  45. sglang/srt/models/gemma2.py +34 -0
  46. sglang/srt/models/gemma2_reward.py +0 -1
  47. sglang/srt/models/granite.py +517 -0
  48. sglang/srt/models/grok.py +73 -9
  49. sglang/srt/models/llama.py +22 -0
  50. sglang/srt/models/llama_classification.py +11 -23
  51. sglang/srt/models/llama_reward.py +0 -2
  52. sglang/srt/models/llava.py +37 -14
  53. sglang/srt/models/mixtral.py +2 -2
  54. sglang/srt/models/olmoe.py +1 -1
  55. sglang/srt/models/qwen2.py +20 -0
  56. sglang/srt/models/qwen2_moe.py +1 -1
  57. sglang/srt/models/xverse_moe.py +1 -1
  58. sglang/srt/openai_api/adapter.py +8 -0
  59. sglang/srt/openai_api/protocol.py +9 -4
  60. sglang/srt/server.py +2 -1
  61. sglang/srt/server_args.py +19 -9
  62. sglang/srt/utils.py +40 -54
  63. sglang/test/test_block_fp8.py +341 -0
  64. sglang/test/test_utils.py +3 -2
  65. sglang/utils.py +10 -3
  66. sglang/version.py +1 -1
  67. {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/METADATA +12 -7
  68. {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/RECORD +73 -67
  69. sglang/srt/layers/fused_moe_patch.py +0 -133
  70. /sglang/srt/layers/{ep_moe → moe/ep_moe}/__init__.py +0 -0
  71. /sglang/srt/layers/{ep_moe → moe/ep_moe}/kernels.py +0 -0
  72. {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/LICENSE +0 -0
  73. {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/WHEEL +0 -0
  74. {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/topk.py
@@ -0,0 +1,191 @@
+from typing import Callable, Optional
+
+import torch
+import torch.nn.functional as F
+
+
+def fused_topk_native(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+):
+    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
+    M, _ = hidden_states.shape
+    topk_weights = torch.empty(
+        M, topk, dtype=torch.float32, device=hidden_states.device
+    )
+    topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
+    topk_weights = F.softmax(gating_output.float(), dim=-1)
+    topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+    return topk_weights, topk_ids
+
+
+def fused_topk(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+):
+    from vllm import _custom_ops as ops
+
+    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
+
+    M, _ = hidden_states.shape
+
+    topk_weights = torch.empty(
+        M, topk, dtype=torch.float32, device=hidden_states.device
+    )
+    topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
+    token_expert_indicies = torch.empty(
+        M, topk, dtype=torch.int32, device=hidden_states.device
+    )
+
+    ops.topk_softmax(
+        topk_weights,
+        topk_ids,
+        token_expert_indicies,
+        gating_output.float(),
+    )
+    del token_expert_indicies
+
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+
+    return topk_weights, topk_ids
+
+
+# This is used by the Deepseek-V2 model
+def grouped_topk(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    num_expert_group: int = 0,
+    topk_group: int = 0,
+):
+    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
+
+    scores = torch.softmax(gating_output, dim=-1)
+    num_token = scores.shape[0]
+    group_scores = (
+        scores.view(num_token, num_expert_group, -1).max(dim=-1).values
+    )  # [n, n_group]
+    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
+        1
+    ]  # [n, top_k_group]
+    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
+    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
+    score_mask = (
+        group_mask.unsqueeze(-1)
+        .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
+        .reshape(num_token, -1)
+    )  # [n, e]
+    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
+    topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
+
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+
+    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
+
+
+def biased_grouped_topk(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    correction_bias: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    num_expert_group: int = 0,
+    topk_group: int = 0,
+):
+    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
+
+    scores = gating_output.sigmoid()
+    num_token = scores.shape[0]
+    scores_for_choice = scores.view(num_token, -1) + correction_bias.unsqueeze(0)
+    group_scores = (
+        scores_for_choice.view(num_token, num_expert_group, -1)
+        .topk(2, dim=-1)[0]
+        .sum(dim=-1)
+    )  # [n, n_group]
+    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
+        1
+    ]  # [n, top_k_group]
+    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
+    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
+    score_mask = (
+        group_mask.unsqueeze(-1)
+        .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
+        .reshape(num_token, -1)
+    )  # [n, e]
+    tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
+    _, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
+    topk_weights = scores.gather(1, topk_ids)
+
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+
+    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
+
+
+def select_experts(
+    hidden_states: torch.Tensor,
+    router_logits: torch.Tensor,
+    top_k: int,
+    use_grouped_topk: bool,
+    renormalize: bool,
+    topk_group: Optional[int] = None,
+    num_expert_group: Optional[int] = None,
+    custom_routing_function: Optional[Callable] = None,
+    correction_bias: Optional[torch.Tensor] = None,
+    torch_native: bool = False,
+):
+    # DeepSeek-V2 uses grouped_topk
+    if use_grouped_topk:
+        assert topk_group is not None
+        assert num_expert_group is not None
+        if correction_bias is None:
+            topk_weights, topk_ids = grouped_topk(
+                hidden_states=hidden_states,
+                gating_output=router_logits,
+                topk=top_k,
+                renormalize=renormalize,
+                num_expert_group=num_expert_group,
+                topk_group=topk_group,
+            )
+        else:
+            topk_weights, topk_ids = biased_grouped_topk(
+                hidden_states=hidden_states,
+                gating_output=router_logits,
+                correction_bias=correction_bias,
+                topk=top_k,
+                renormalize=renormalize,
+                num_expert_group=num_expert_group,
+                topk_group=topk_group,
+            )
+    elif torch_native:
+        topk_weights, topk_ids = fused_topk_native(
+            hidden_states=hidden_states,
+            gating_output=router_logits,
+            topk=top_k,
+            renormalize=renormalize,
+        )
+    elif custom_routing_function is None:
+        topk_weights, topk_ids = fused_topk(
+            hidden_states=hidden_states,
+            gating_output=router_logits,
+            topk=top_k,
+            renormalize=renormalize,
+        )
+    else:
+        topk_weights, topk_ids = custom_routing_function(
+            hidden_states=hidden_states,
+            gating_output=router_logits,
+            topk=top_k,
+            renormalize=renormalize,
+        )
+
+    return topk_weights, topk_ids
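
The new module centralizes MoE routing in one place: grouped_topk and biased_grouped_topk implement DeepSeek-V2-style group-restricted routing, fused_topk wraps the vllm topk_softmax custom op, and fused_topk_native is a pure-PyTorch fallback. A minimal usage sketch with toy shapes (it assumes the module path sglang/srt/layers/moe/topk.py from the file list above, and uses the torch-native path so neither a GPU nor the vllm op is required):

import torch

from sglang.srt.layers.moe.topk import select_experts

# 4 tokens, hidden size 16, 8 experts, route each token to its top 2 experts.
hidden_states = torch.randn(4, 16)
router_logits = torch.randn(4, 8)

# torch_native=True dispatches to fused_topk_native (softmax + torch.topk),
# so this sketch runs on CPU without the vllm custom op.
topk_weights, topk_ids = select_experts(
    hidden_states=hidden_states,
    router_logits=router_logits,
    top_k=2,
    use_grouped_topk=False,
    renormalize=True,
    torch_native=True,
)
print(topk_weights.shape, topk_ids.shape)  # torch.Size([4, 2]) torch.Size([4, 2])

# DeepSeek-V2-style grouped routing: 8 experts split into 4 groups, and
# selection is restricted to the top 2 groups before the per-token top-2.
grouped_weights, grouped_ids = select_experts(
    hidden_states=hidden_states,
    router_logits=router_logits,
    top_k=2,
    use_grouped_topk=True,
    renormalize=True,
    num_expert_group=4,
    topk_group=2,
)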
sglang/srt/layers/quantization/__init__.py
@@ -22,7 +22,7 @@ from vllm.model_executor.layers.quantization.qqq import QQQConfig
 from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
 
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
+from sglang.srt.layers.quantization.fp8 import Fp8Config
 
 QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "aqlm": AQLMConfig,
@@ -53,50 +53,6 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     return QUANTIZATION_METHODS[quantization]
 
 
-def fp8_moe_apply(
-    self,
-    layer: torch.nn.Module,
-    x: torch.Tensor,
-    router_logits: torch.Tensor,
-    top_k: int,
-    renormalize: bool,
-    use_grouped_topk: bool,
-    topk_group: Optional[int] = None,
-    num_expert_group: Optional[int] = None,
-    custom_routing_function: Optional[Callable] = None,
-) -> torch.Tensor:
-    """Enhanced apply method for FP8 MoE."""
-    from sglang.srt.layers.fused_moe_triton import FusedMoE
-    from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts
-
-    # Expert selection
-    topk_weights, topk_ids = FusedMoE.select_experts(
-        hidden_states=x,
-        router_logits=router_logits,
-        use_grouped_topk=use_grouped_topk,
-        top_k=top_k,
-        renormalize=renormalize,
-        topk_group=topk_group,
-        num_expert_group=num_expert_group,
-        custom_routing_function=custom_routing_function,
-    )
-
-    # Expert fusion with FP8 quantization
-    return fused_experts(
-        x,
-        layer.w13_weight,
-        layer.w2_weight,
-        topk_weights=topk_weights,
-        topk_ids=topk_ids,
-        inplace=True,
-        use_fp8_w8a8=True,
-        w1_scale=layer.w13_weight_scale,
-        w2_scale=layer.w2_weight_scale,
-        a1_scale=layer.w13_input_scale,
-        a2_scale=layer.w2_input_scale,
-    )
-
-
 def fp8_get_quant_method(self, layer, prefix):
     """Enhanced get_quant_method for FP8 config."""
     from vllm.model_executor.layers.linear import LinearBase
@@ -104,9 +60,9 @@ def fp8_get_quant_method(self, layer, prefix):
         is_layer_skipped,
     )
 
-    from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
     from sglang.srt.layers.linear import UnquantizedLinearMethod
-    from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
+    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
+    from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod, Fp8MoEMethod
 
     if isinstance(layer, LinearBase):
         if is_layer_skipped(prefix, self.ignored_layers):
@@ -124,7 +80,7 @@ def gptq_get_quant_method(self, layer, prefix):
         GPTQMarlinMoEMethod,
     )
 
-    from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
+    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 
     if isinstance(layer, LinearBase):
         return GPTQMarlinLinearMethod(self)
@@ -140,7 +96,7 @@ def awq_get_quant_method(self, layer, prefix):
         AWQMoEMethod,
     )
 
-    from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
+    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 
     if isinstance(layer, LinearBase):
         return AWQMarlinLinearMethod(self)
@@ -151,7 +107,6 @@ def awq_get_quant_method(self, layer, prefix):
 
 def apply_monkey_patches():
     """Apply all monkey patches in one place."""
-    setattr(Fp8MoEMethod, "apply", fp8_moe_apply)
     setattr(Fp8Config, "get_quant_method", fp8_get_quant_method)
     setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
     setattr(AWQMarlinConfig, "get_quant_method", awq_get_quant_method)
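
Net effect of this file's changes: the fp8_moe_apply patch is dropped, so Fp8MoEMethod.apply now runs as defined in sglang's own fp8 module (which changed by +221 -36 in this release, see file 23), and the FusedMoE import follows the layers/moe/ reorganization; only get_quant_method is still monkey-patched onto the vllm config classes. A toy sketch of that setattr pattern, using a hypothetical stand-in class rather than the real vllm config types:

# ToyConfig is a hypothetical stand-in for Fp8Config/GPTQMarlinConfig/etc.
class ToyConfig:
    def get_quant_method(self, layer, prefix):
        return "upstream-default"


def patched_get_quant_method(self, layer, prefix):
    # Replacement bound onto the class in place of the upstream method.
    return f"sglang-override:{prefix}"


# Same pattern as apply_monkey_patches() above.
setattr(ToyConfig, "get_quant_method", patched_get_quant_method)
print(ToyConfig().get_quant_method(layer=None, prefix="model.layers.0.mlp"))
# -> sglang-override:model.layers.0.mlp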