sglang 0.4.0.post1__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. sglang/bench_offline_throughput.py +6 -6
  2. sglang/bench_one_batch.py +1 -0
  3. sglang/bench_serving.py +9 -1
  4. sglang/check_env.py +140 -48
  5. sglang/lang/backend/runtime_endpoint.py +1 -0
  6. sglang/lang/chat_template.py +32 -0
  7. sglang/llama3_eval.py +316 -0
  8. sglang/srt/aio_rwlock.py +100 -0
  9. sglang/srt/configs/model_config.py +8 -1
  10. sglang/srt/constrained/xgrammar_backend.py +4 -1
  11. sglang/srt/layers/attention/flashinfer_backend.py +51 -5
  12. sglang/srt/layers/attention/triton_backend.py +16 -25
  13. sglang/srt/layers/attention/triton_ops/decode_attention.py +305 -350
  14. sglang/srt/layers/linear.py +20 -2
  15. sglang/srt/layers/logits_processor.py +133 -95
  16. sglang/srt/layers/{ep_moe → moe/ep_moe}/layer.py +18 -39
  17. sglang/srt/layers/moe/fused_moe_native.py +46 -0
  18. sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/__init__.py +3 -7
  19. sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/fused_moe.py +174 -119
  20. sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/layer.py +17 -49
  21. sglang/srt/layers/moe/topk.py +191 -0
  22. sglang/srt/layers/quantization/__init__.py +5 -50
  23. sglang/srt/layers/quantization/fp8.py +221 -36
  24. sglang/srt/layers/quantization/fp8_kernel.py +278 -0
  25. sglang/srt/layers/quantization/fp8_utils.py +90 -1
  26. sglang/srt/layers/radix_attention.py +8 -1
  27. sglang/srt/layers/sampler.py +27 -5
  28. sglang/srt/layers/torchao_utils.py +31 -0
  29. sglang/srt/managers/detokenizer_manager.py +37 -17
  30. sglang/srt/managers/io_struct.py +39 -10
  31. sglang/srt/managers/schedule_batch.py +54 -34
  32. sglang/srt/managers/schedule_policy.py +64 -5
  33. sglang/srt/managers/scheduler.py +171 -136
  34. sglang/srt/managers/tokenizer_manager.py +184 -133
  35. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  36. sglang/srt/mem_cache/chunk_cache.py +2 -2
  37. sglang/srt/mem_cache/memory_pool.py +15 -8
  38. sglang/srt/mem_cache/radix_cache.py +12 -2
  39. sglang/srt/model_executor/cuda_graph_runner.py +25 -11
  40. sglang/srt/model_executor/model_runner.py +28 -14
  41. sglang/srt/model_parallel.py +66 -5
  42. sglang/srt/models/dbrx.py +1 -1
  43. sglang/srt/models/deepseek.py +1 -1
  44. sglang/srt/models/deepseek_v2.py +67 -18
  45. sglang/srt/models/gemma2.py +34 -0
  46. sglang/srt/models/gemma2_reward.py +0 -1
  47. sglang/srt/models/granite.py +517 -0
  48. sglang/srt/models/grok.py +73 -9
  49. sglang/srt/models/llama.py +22 -0
  50. sglang/srt/models/llama_classification.py +11 -23
  51. sglang/srt/models/llama_reward.py +0 -2
  52. sglang/srt/models/llava.py +37 -14
  53. sglang/srt/models/mixtral.py +2 -2
  54. sglang/srt/models/olmoe.py +1 -1
  55. sglang/srt/models/qwen2.py +20 -0
  56. sglang/srt/models/qwen2_moe.py +1 -1
  57. sglang/srt/models/xverse_moe.py +1 -1
  58. sglang/srt/openai_api/adapter.py +8 -0
  59. sglang/srt/openai_api/protocol.py +9 -4
  60. sglang/srt/server.py +2 -1
  61. sglang/srt/server_args.py +19 -9
  62. sglang/srt/utils.py +40 -54
  63. sglang/test/test_block_fp8.py +341 -0
  64. sglang/test/test_utils.py +3 -2
  65. sglang/utils.py +10 -3
  66. sglang/version.py +1 -1
  67. {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/METADATA +12 -7
  68. {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/RECORD +73 -67
  69. sglang/srt/layers/fused_moe_patch.py +0 -133
  70. /sglang/srt/layers/{ep_moe → moe/ep_moe}/__init__.py +0 -0
  71. /sglang/srt/layers/{ep_moe → moe/ep_moe}/kernels.py +0 -0
  72. {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/LICENSE +0 -0
  73. {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/WHEEL +0 -0
  74. {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/top_level.txt +0 -0
@@ -30,6 +30,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.layers.quantization.fp8_utils import BlockQuantScaleParameter
 from sglang.srt.utils import set_weight_attrs
 
 logger = logging.getLogger(__name__)
@@ -628,8 +629,19 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         assert loaded_shard_id < len(self.output_sizes)
 
         tp_size = get_tensor_model_parallel_world_size()
-        shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
-        shard_size = self.output_sizes[loaded_shard_id] // tp_size
+
+        if isinstance(param, BlockQuantScaleParameter):
+            weight_block_size = self.quant_method.quant_config.weight_block_size
+            block_n, _ = weight_block_size[0], weight_block_size[1]
+            shard_offset = (
+                (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) // block_n
+            ) // tp_size
+            shard_size = (
+                (self.output_sizes[loaded_shard_id] + block_n - 1) // block_n // tp_size
+            )
+        else:
+            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
+            shard_size = self.output_sizes[loaded_shard_id] // tp_size
 
         param.load_merged_column_weight(
             loaded_weight=loaded_weight,
@@ -795,6 +807,12 @@ class QKVParallelLinear(ColumnParallelLinear):
         shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
         shard_size = self._get_shard_size_mapping(loaded_shard_id)
 
+        if isinstance(param, BlockQuantScaleParameter):
+            weight_block_size = self.quant_method.quant_config.weight_block_size
+            block_n, _ = weight_block_size[0], weight_block_size[1]
+            shard_offset = (shard_offset + block_n - 1) // block_n
+            shard_size = (shard_size + block_n - 1) // block_n
+
         param.load_qkv_weight(
             loaded_weight=loaded_weight,
             num_heads=self.num_kv_head_replicas,
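In both hunks above, the scale tensor of a block-quantized weight has one entry per block_n output rows, so shard offsets and sizes are converted from row units to block units with a ceiling division. A minimal sketch of that arithmetic (the helper name rows_to_blocks and the sizes are illustrative, not part of the diff):

# Sketch of the block-aligned shard arithmetic used in the diff above.
# `rows_to_blocks` is a hypothetical helper, not part of sglang.
def rows_to_blocks(num_rows: int, block_n: int) -> int:
    # One scale entry covers block_n output rows; round up for partial blocks.
    return (num_rows + block_n - 1) // block_n

# Example: two merged projections of 7168 and 3584 rows, block_n = 128.
output_sizes = [7168, 3584]
block_n = 128
shard_offset = rows_to_blocks(sum(output_sizes[:1]), block_n)  # 56 blocks
shard_size = rows_to_blocks(output_sizes[1], block_n)          # 28 blocks
print(shard_offset, shard_size)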
@@ -39,10 +39,12 @@ class LogitsProcessorOutput:
     # The logprobs of input tokens. shape: [#token, vocab_size]
     input_token_logprobs: torch.Tensor = None
 
-    # The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
-    input_top_logprobs: List = None
-    # The logprob and id of the top-k tokens in output positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
-    output_top_logprobs: List = None
+    # The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k]
+    input_top_logprobs_val: List = None
+    input_top_logprobs_idx: List = None
+    # The logprob and id of the top-k tokens in output positions. shape [#seq, #token, k]
+    output_top_logprobs_val: List = None
+    output_top_logprobs_idx: List = None
 
 
 @dataclasses.dataclass
@@ -89,76 +91,18 @@ class LogitsMetadata:
 
 
 class LogitsProcessor(nn.Module):
-    def __init__(self, config, skip_all_gather: bool = False):
+    def __init__(
+        self, config, skip_all_gather: bool = False, logit_scale: Optional[float] = None
+    ):
         super().__init__()
         self.config = config
+        self.logit_scale = logit_scale
         self.do_tensor_parallel_all_gather = (
             not skip_all_gather and get_tensor_model_parallel_world_size() > 1
         )
-
-    def _get_normalized_prompt_logprobs(
-        self,
-        input_token_logprobs: torch.Tensor,
-        logits_metadata: LogitsMetadata,
-    ):
-        logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32)
-        pruned_lens = torch.tensor(
-            logits_metadata.extend_logprob_pruned_lens_cpu, device="cuda"
-        )
-
-        start = torch.zeros_like(pruned_lens)
-        start[1:] = torch.cumsum(pruned_lens[:-1], dim=0)
-        end = torch.clamp(
-            start + pruned_lens - 2, min=0, max=logprobs_cumsum.shape[0] - 1
-        )
-        sum_logp = (
-            logprobs_cumsum[end] - logprobs_cumsum[start] + input_token_logprobs[start]
+        self.final_logit_softcapping = getattr(
+            self.config, "final_logit_softcapping", None
         )
-        normalized_prompt_logprobs = sum_logp / (pruned_lens - 1).clamp(min=1)
-        return normalized_prompt_logprobs
-
-    @staticmethod
-    def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata):
-        max_k = max(logits_metadata.top_logprobs_nums)
-        ret = all_logprobs.topk(max_k, dim=1)
-        values = ret.values.tolist()
-        indices = ret.indices.tolist()
-
-        if logits_metadata.forward_mode.is_decode():
-            output_top_logprobs = []
-            for i, k in enumerate(logits_metadata.top_logprobs_nums):
-                output_top_logprobs.append(list(zip(values[i][:k], indices[i][:k])))
-            return None, output_top_logprobs
-        else:
-            input_top_logprobs, output_top_logprobs = [], []
-
-            pt = 0
-            for k, pruned_len in zip(
-                logits_metadata.top_logprobs_nums,
-                logits_metadata.extend_logprob_pruned_lens_cpu,
-            ):
-                if pruned_len <= 0:
-                    input_top_logprobs.append([])
-                    output_top_logprobs.append([])
-                    continue
-
-                input_top_logprobs.append(
-                    [
-                        list(zip(values[pt + j][:k], indices[pt + j][:k]))
-                        for j in range(pruned_len - 1)
-                    ]
-                )
-                output_top_logprobs.append(
-                    list(
-                        zip(
-                            values[pt + pruned_len - 1][:k],
-                            indices[pt + pruned_len - 1][:k],
-                        )
-                    )
-                )
-                pt += pruned_len
-
-            return input_top_logprobs, output_top_logprobs
 
     def forward(
         self,
@@ -184,38 +128,33 @@ class LogitsProcessor(nn.Module):
             last_logits = tensor_model_parallel_all_gather(last_logits)
         last_logits = last_logits[:, : self.config.vocab_size].float()
 
-        if hasattr(self.config, "final_logit_softcapping"):
-            last_logits.div_(self.config.final_logit_softcapping)
+        if self.final_logit_softcapping:
+            last_logits.div_(self.final_logit_softcapping)
             torch.tanh(last_logits, out=last_logits)
-            last_logits.mul_(self.config.final_logit_softcapping)
+            last_logits.mul_(self.final_logit_softcapping)
 
         # Return only last_logits if logprob is not requested
         if not logits_metadata.return_logprob:
             return LogitsProcessorOutput(
                 next_token_logits=last_logits,
-                next_token_logprobs=None,
-                normalized_prompt_logprobs=None,
-                input_token_logprobs=None,
-                input_top_logprobs=None,
-                output_top_logprobs=None,
             )
         else:
-            last_logprobs = torch.nn.functional.log_softmax(last_logits, dim=-1)
+            last_logprobs = self.compute_temp_top_p_normalized_logprobs(
+                last_logits, logits_metadata
+            )
 
             if logits_metadata.forward_mode.is_decode():
                 if logits_metadata.return_top_logprob:
-                    output_top_logprobs = self.get_top_logprobs(
-                        last_logprobs, logits_metadata
-                    )[1]
+                    output_top_logprobs_val, output_top_logprobs_idx = (
+                        self.get_top_logprobs(last_logprobs, logits_metadata)[2:4]
+                    )
                 else:
-                    output_top_logprobs = None
+                    output_top_logprobs_val = output_top_logprobs_idx = None
                 return LogitsProcessorOutput(
                     next_token_logits=last_logits,
                     next_token_logprobs=last_logprobs,
-                    normalized_prompt_logprobs=None,
-                    input_token_logprobs=None,
-                    input_top_logprobs=None,
-                    output_top_logprobs=output_top_logprobs,
+                    output_top_logprobs_val=output_top_logprobs_val,
+                    output_top_logprobs_idx=output_top_logprobs_idx,
                 )
             else:
                 # Slice the requested tokens to compute logprob
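The softcap value is now read once in __init__ via getattr, so models without final_logit_softcapping skip the branch entirely. As a hedged sketch, the in-place div/tanh/mul sequence above computes the following bounded transform (function name and sample values are illustrative only):

import torch

# Sketch of tanh soft-capping: logits are squashed into (-cap, cap)
# while remaining monotone, matching the div_/tanh/mul_ steps in the diff.
def softcap(logits: torch.Tensor, cap: float) -> torch.Tensor:
    return cap * torch.tanh(logits / cap)

x = torch.tensor([-100.0, -5.0, 0.0, 5.0, 100.0])
print(softcap(x, cap=30.0))  # all values bounded by +/-30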
@@ -233,24 +172,35 @@ class LogitsProcessor(nn.Module):
                 all_logits = self._get_logits(states, lm_head)
                 if self.do_tensor_parallel_all_gather:
                     all_logits = tensor_model_parallel_all_gather(all_logits)
+
+                # The LM head's weights may be zero-padded for parallelism. Remove any
+                # extra logits that this padding may have produced.
                 all_logits = all_logits[:, : self.config.vocab_size].float()
 
-                if hasattr(self.config, "final_logit_softcapping"):
-                    all_logits.div_(self.config.final_logit_softcapping)
+                if self.final_logit_softcapping:
+                    all_logits.div_(self.final_logit_softcapping)
                     torch.tanh(all_logits, out=all_logits)
-                    all_logits.mul_(self.config.final_logit_softcapping)
+                    all_logits.mul_(self.final_logit_softcapping)
 
                 all_logprobs = all_logits
                 del all_logits, hidden_states
-                all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
+
+                all_logprobs = self.compute_temp_top_p_normalized_logprobs(
+                    all_logprobs, logits_metadata
+                )
 
                 # Get the logprob of top-k tokens
                 if logits_metadata.return_top_logprob:
-                    input_top_logprobs, output_top_logprobs = self.get_top_logprobs(
-                        all_logprobs, logits_metadata
-                    )
+                    (
+                        input_top_logprobs_val,
+                        input_top_logprobs_idx,
+                        output_top_logprobs_val,
+                        output_top_logprobs_idx,
+                    ) = self.get_top_logprobs(all_logprobs, logits_metadata)
                 else:
-                    input_top_logprobs = output_top_logprobs = None
+                    input_top_logprobs_val = input_top_logprobs_idx = (
+                        output_top_logprobs_val
+                    ) = output_top_logprobs_idx = None
 
                 # Compute the normalized logprobs for the requested tokens.
                 # Note that we pad a zero at the end for easy batching.
@@ -273,8 +223,10 @@ class LogitsProcessor(nn.Module):
                     next_token_logprobs=last_logprobs,
                     normalized_prompt_logprobs=normalized_prompt_logprobs,
                     input_token_logprobs=input_token_logprobs,
-                    input_top_logprobs=input_top_logprobs,
-                    output_top_logprobs=output_top_logprobs,
+                    input_top_logprobs_val=input_top_logprobs_val,
+                    input_top_logprobs_idx=input_top_logprobs_idx,
+                    output_top_logprobs_val=output_top_logprobs_val,
+                    output_top_logprobs_idx=output_top_logprobs_idx,
                 )
 
     def _get_logits(
@@ -288,8 +240,94 @@ class LogitsProcessor(nn.Module):
         else:
             # GGUF models
             logits = lm_head.linear_method.apply(lm_head, hidden_states, embedding_bias)
+
+        # Optional scaling factor
+        if self.logit_scale is not None:
+            logits.mul_(self.logit_scale)  # In-place multiply
         return logits
 
+    @staticmethod
+    def _get_normalized_prompt_logprobs(
+        input_token_logprobs: torch.Tensor,
+        logits_metadata: LogitsMetadata,
+    ):
+        logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32)
+        pruned_lens = torch.tensor(
+            logits_metadata.extend_logprob_pruned_lens_cpu, device="cuda"
+        )
+
+        start = torch.zeros_like(pruned_lens)
+        start[1:] = torch.cumsum(pruned_lens[:-1], dim=0)
+        end = torch.clamp(
+            start + pruned_lens - 2, min=0, max=logprobs_cumsum.shape[0] - 1
+        )
+        sum_logp = (
+            logprobs_cumsum[end] - logprobs_cumsum[start] + input_token_logprobs[start]
+        )
+        normalized_prompt_logprobs = sum_logp / (pruned_lens - 1).clamp(min=1)
+        return normalized_prompt_logprobs
+
+    @staticmethod
+    def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata):
+        max_k = max(logits_metadata.top_logprobs_nums)
+        ret = all_logprobs.topk(max_k, dim=1)
+        values = ret.values.tolist()
+        indices = ret.indices.tolist()
+
+        if logits_metadata.forward_mode.is_decode():
+            output_top_logprobs_val = []
+            output_top_logprobs_idx = []
+            for i, k in enumerate(logits_metadata.top_logprobs_nums):
+                output_top_logprobs_val.append(values[i][:k])
+                output_top_logprobs_idx.append(indices[i][:k])
+            return None, None, output_top_logprobs_val, output_top_logprobs_idx
+        else:
+            input_top_logprobs_val, input_top_logprobs_idx = [], []
+            output_top_logprobs_val, output_top_logprobs_idx = [], []
+
+            pt = 0
+            for k, pruned_len in zip(
+                logits_metadata.top_logprobs_nums,
+                logits_metadata.extend_logprob_pruned_lens_cpu,
+            ):
+                if pruned_len <= 0:
+                    input_top_logprobs_val.append([])
+                    input_top_logprobs_idx.append([])
+                    output_top_logprobs_val.append([])
+                    output_top_logprobs_idx.append([])
+                    continue
+
+                input_top_logprobs_val.append(
+                    [values[pt + j][:k] for j in range(pruned_len - 1)]
+                )
+                input_top_logprobs_idx.append(
+                    [indices[pt + j][:k] for j in range(pruned_len - 1)]
+                )
+                output_top_logprobs_val.append(
+                    list(
+                        values[pt + pruned_len - 1][:k],
+                    )
+                )
+                output_top_logprobs_idx.append(
+                    list(
+                        indices[pt + pruned_len - 1][:k],
+                    )
+                )
+                pt += pruned_len
+
+            return (
+                input_top_logprobs_val,
+                input_top_logprobs_idx,
+                output_top_logprobs_val,
+                output_top_logprobs_idx,
+            )
+
+    @staticmethod
+    def compute_temp_top_p_normalized_logprobs(
+        last_logits: torch.Tensor, logits_metadata: LogitsMetadata
+    ) -> torch.Tensor:
+        return torch.nn.functional.log_softmax(last_logits, dim=-1)
+
 
 def test():
     all_logprobs = torch.tensor(
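The top-k logprobs are now returned as parallel value and index lists instead of lists of (logprob, token_id) tuples. A minimal sketch of the decode-path behaviour of get_top_logprobs, using a small random logprob matrix (shapes and per-request k values are illustrative only):

import torch

# Sketch of the per-request top-k split into value/index lists, mirroring
# the decode branch of get_top_logprobs in the diff above.
logprobs = torch.log_softmax(torch.randn(2, 8), dim=-1)  # [#seq, vocab]
top_logprobs_nums = [3, 1]                               # per-request k

max_k = max(top_logprobs_nums)
ret = logprobs.topk(max_k, dim=1)
values, indices = ret.values.tolist(), ret.indices.tolist()

output_top_logprobs_val = [values[i][:k] for i, k in enumerate(top_logprobs_nums)]
output_top_logprobs_idx = [indices[i][:k] for i, k in enumerate(top_logprobs_nums)]
print(output_top_logprobs_val, output_top_logprobs_idx)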
@@ -12,15 +12,15 @@ from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
 
 from sglang.srt.layers.custom_op_util import register_custom_op
-from sglang.srt.layers.ep_moe.kernels import (
+from sglang.srt.layers.moe.ep_moe.kernels import (
     grouped_gemm_triton,
     post_reorder_triton_kernel,
     pre_reorder_triton_kernel,
     run_moe_ep_preproess,
     silu_and_mul_triton_kernel,
 )
-from sglang.srt.layers.fused_moe_triton.fused_moe import fused_topk, grouped_topk
-from sglang.srt.layers.fused_moe_triton.layer import FusedMoEMethodBase
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoEMethodBase
+from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
@@ -113,6 +113,7 @@ class EPMoE(torch.nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
         prefix: str = "",
+        correction_bias: Optional[torch.Tensor] = None,
     ):
         super().__init__()
 
@@ -138,6 +139,7 @@ class EPMoE(torch.nn.Module):
             assert num_expert_group is not None and topk_group is not None
             self.num_expert_group = num_expert_group
             self.topk_group = topk_group
+        self.correction_bias = correction_bias
 
         if quant_config is None:
            self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod()
@@ -170,13 +172,15 @@ class EPMoE(torch.nn.Module):
             hidden_states.device, use_flashinfer=False  # TODO: use flashinfer
         )
 
-        topk_weights, topk_ids = self.select_experts(
-            hidden_states,
-            router_logits,
-            self.top_k,
-            self.renormalize,
-            self.topk_group,
-            self.num_expert_group,
+        topk_weights, topk_ids = select_experts(
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+            top_k=self.top_k,
+            use_grouped_topk=self.use_grouped_topk,
+            renormalize=self.renormalize,
+            topk_group=self.topk_group,
+            num_expert_group=self.num_expert_group,
+            correction_bias=self.correction_bias,
         )
 
         reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
@@ -297,35 +301,6 @@ class EPMoE(torch.nn.Module):
         )
         return output
 
-    def select_experts(
-        self,
-        hidden_states: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-    ):
-        if self.use_grouped_topk:
-            assert topk_group is not None
-            assert num_expert_group is not None
-            topk_weights, topk_ids = grouped_topk(
-                hidden_states=hidden_states,
-                gating_output=router_logits,
-                topk=top_k,
-                renormalize=renormalize,
-                num_expert_group=num_expert_group,
-                topk_group=topk_group,
-            )
-        else:
-            topk_weights, topk_ids = fused_topk(
-                hidden_states=hidden_states,
-                gating_output=router_logits,
-                topk=top_k,
-                renormalize=renormalize,
-            )
-        return topk_weights, topk_ids.to(torch.int32)
-
     @classmethod
     def make_expert_params_mapping(
         cls,
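The per-class select_experts method is removed; routing now goes through the shared sglang.srt.layers.moe.topk.select_experts helper, which also threads through correction_bias. As a hedged sketch, the plain (non-grouped) branch of such routing amounts to softmax gating plus top-k with optional renormalization; the function below is written against torch only and does not assume the exact sglang signature:

import torch

# Sketch of plain top-k expert routing (the non-grouped branch), assuming
# softmax gating followed by optional renormalization of the kept weights.
def naive_topk_routing(router_logits: torch.Tensor, top_k: int, renormalize: bool):
    gates = torch.softmax(router_logits, dim=-1)            # [tokens, num_experts]
    topk_weights, topk_ids = torch.topk(gates, top_k, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids.to(torch.int32)

logits = torch.randn(4, 8)  # 4 tokens, 8 experts
w, ids = naive_topk_routing(logits, top_k=2, renormalize=True)
print(w.shape, ids.shape)  # torch.Size([4, 2]) torch.Size([4, 2])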
@@ -644,6 +619,10 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
                         "QuantConfig has static quantization, but found "
                         "activation scales are None."
                     )
+            layer.w13_weight_scale = torch.nn.Parameter(
+                torch.max(layer.w13_weight_scale, dim=1).values,
+                requires_grad=False,
+            )
             return
 
     def apply(
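For fp8 checkpoints the merged w1/w3 shards each carry their own scale, while the kernel expects a single scale per expert; keeping the larger of the two is the safe choice. A toy illustration of the torch.max(..., dim=1).values reduction above (all numbers are made up):

import torch

# Toy illustration: collapse per-shard fp8 scales to one scale per expert,
# keeping the larger (safer) scale, as in the diff above.
w13_weight_scale = torch.tensor([[0.01, 0.03],
                                 [0.02, 0.02],
                                 [0.05, 0.01],
                                 [0.04, 0.06]])  # [num_experts, 2 shards]
per_expert_scale = torch.max(w13_weight_scale, dim=1).values
print(per_expert_scale)  # tensor([0.0300, 0.0200, 0.0500, 0.0600])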
@@ -0,0 +1,46 @@
+"""
+Torch-native implementation for FusedMoE. This is used for torch.compile.
+It is based on https://github.com/pytorch-labs/gpt-fast/blob/32971d3129541c5bfb4f715abc33d1c5f408d204/mixtral-moe/model.py#L204
+"""
+
+from typing import Callable, Optional
+
+import torch
+from torch.nn import functional as F
+
+from sglang.srt.layers.moe.topk import select_experts
+
+
+def fused_moe_forward_native(
+    layer: torch.nn.Module,
+    x: torch.Tensor,
+    use_grouped_topk: bool,
+    top_k: int,
+    router_logits: torch.Tensor,
+    renormalize: bool,
+    topk_group: Optional[int] = None,
+    num_expert_group: Optional[int] = None,
+    custom_routing_function: Optional[Callable] = None,
+    correction_bias: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    topk_weights, topk_ids = select_experts(
+        hidden_states=x,
+        router_logits=router_logits,
+        use_grouped_topk=use_grouped_topk,
+        top_k=top_k,
+        renormalize=renormalize,
+        topk_group=topk_group,
+        num_expert_group=num_expert_group,
+        custom_routing_function=custom_routing_function,
+        correction_bias=correction_bias,
+        torch_native=True,
+    )
+
+    w13_weights = layer.w13_weight[topk_ids]
+    w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
+    w2_weights = layer.w2_weight[topk_ids]
+    x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
+    x1 = F.silu(x1)
+    x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
+    expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
+    return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
@@ -1,14 +1,12 @@
 from contextlib import contextmanager
 from typing import Any, Dict, Optional
 
-import sglang.srt.layers.fused_moe_triton.fused_moe  # noqa
-from sglang.srt.layers.fused_moe_triton.fused_moe import (
+import sglang.srt.layers.moe.fused_moe_triton.fused_moe  # noqa
+from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
     fused_experts,
-    fused_topk,
     get_config_file_name,
-    grouped_topk,
 )
-from sglang.srt.layers.fused_moe_triton.layer import (
+from sglang.srt.layers.moe.fused_moe_triton.layer import (
     FusedMoE,
     FusedMoEMethodBase,
     FusedMoeWeightScaleSupported,
@@ -37,8 +35,6 @@ __all__ = [
     "override_config",
     "get_config",
     "fused_moe",
-    "fused_topk",
     "fused_experts",
     "get_config_file_name",
-    "grouped_topk",
 ]