sglang 0.4.9.post4__py3-none-any.whl → 0.4.9.post6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (98)
  1. sglang/lang/chat_template.py +21 -0
  2. sglang/srt/configs/internvl.py +3 -0
  3. sglang/srt/configs/model_config.py +7 -0
  4. sglang/srt/constrained/base_grammar_backend.py +10 -2
  5. sglang/srt/constrained/xgrammar_backend.py +7 -5
  6. sglang/srt/conversation.py +16 -1
  7. sglang/srt/debug_utils/__init__.py +0 -0
  8. sglang/srt/debug_utils/dump_comparator.py +131 -0
  9. sglang/srt/debug_utils/dumper.py +108 -0
  10. sglang/srt/debug_utils/text_comparator.py +172 -0
  11. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +13 -1
  12. sglang/srt/disaggregation/mooncake/conn.py +16 -0
  13. sglang/srt/disaggregation/prefill.py +13 -1
  14. sglang/srt/entrypoints/engine.py +4 -2
  15. sglang/srt/entrypoints/http_server.py +13 -1
  16. sglang/srt/entrypoints/openai/protocol.py +3 -1
  17. sglang/srt/entrypoints/openai/serving_base.py +5 -2
  18. sglang/srt/entrypoints/openai/serving_chat.py +132 -79
  19. sglang/srt/function_call/ebnf_composer.py +10 -3
  20. sglang/srt/function_call/function_call_parser.py +2 -0
  21. sglang/srt/function_call/glm4_moe_detector.py +164 -0
  22. sglang/srt/function_call/qwen3_coder_detector.py +1 -0
  23. sglang/srt/layers/attention/hybrid_attn_backend.py +100 -0
  24. sglang/srt/layers/attention/vision.py +56 -8
  25. sglang/srt/layers/layernorm.py +26 -1
  26. sglang/srt/layers/logits_processor.py +14 -3
  27. sglang/srt/layers/moe/ep_moe/layer.py +323 -242
  28. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +83 -118
  29. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=160,N=320,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -48
  33. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +11 -8
  34. sglang/srt/layers/moe/token_dispatcher/__init__.py +0 -0
  35. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +48 -0
  36. sglang/srt/layers/moe/token_dispatcher/standard.py +19 -0
  37. sglang/srt/layers/moe/topk.py +90 -24
  38. sglang/srt/layers/multimodal.py +11 -8
  39. sglang/srt/layers/quantization/fp8.py +25 -247
  40. sglang/srt/layers/quantization/fp8_kernel.py +78 -48
  41. sglang/srt/layers/quantization/modelopt_quant.py +27 -10
  42. sglang/srt/layers/quantization/unquant.py +24 -76
  43. sglang/srt/layers/quantization/w4afp8.py +68 -17
  44. sglang/srt/lora/lora_registry.py +93 -29
  45. sglang/srt/managers/cache_controller.py +9 -7
  46. sglang/srt/managers/data_parallel_controller.py +4 -0
  47. sglang/srt/managers/io_struct.py +12 -0
  48. sglang/srt/managers/mm_utils.py +154 -35
  49. sglang/srt/managers/multimodal_processor.py +3 -14
  50. sglang/srt/managers/schedule_batch.py +14 -8
  51. sglang/srt/managers/scheduler.py +64 -1
  52. sglang/srt/managers/scheduler_input_blocker.py +106 -0
  53. sglang/srt/managers/tokenizer_manager.py +80 -15
  54. sglang/srt/managers/tp_worker.py +8 -0
  55. sglang/srt/mem_cache/hiradix_cache.py +5 -2
  56. sglang/srt/model_executor/model_runner.py +83 -27
  57. sglang/srt/models/deepseek_v2.py +75 -84
  58. sglang/srt/models/glm4_moe.py +1035 -0
  59. sglang/srt/models/glm4_moe_nextn.py +167 -0
  60. sglang/srt/models/interns1.py +328 -0
  61. sglang/srt/models/internvl.py +143 -47
  62. sglang/srt/models/llava.py +9 -5
  63. sglang/srt/models/minicpmo.py +4 -1
  64. sglang/srt/models/qwen2_moe.py +2 -2
  65. sglang/srt/models/qwen3_moe.py +17 -71
  66. sglang/srt/multimodal/processors/base_processor.py +20 -6
  67. sglang/srt/multimodal/processors/clip.py +2 -2
  68. sglang/srt/multimodal/processors/deepseek_vl_v2.py +2 -2
  69. sglang/srt/multimodal/processors/gemma3.py +2 -2
  70. sglang/srt/multimodal/processors/gemma3n.py +2 -2
  71. sglang/srt/multimodal/processors/internvl.py +21 -8
  72. sglang/srt/multimodal/processors/janus_pro.py +2 -2
  73. sglang/srt/multimodal/processors/kimi_vl.py +2 -2
  74. sglang/srt/multimodal/processors/llava.py +4 -4
  75. sglang/srt/multimodal/processors/minicpm.py +2 -3
  76. sglang/srt/multimodal/processors/mlama.py +2 -2
  77. sglang/srt/multimodal/processors/mllama4.py +18 -111
  78. sglang/srt/multimodal/processors/phi4mm.py +2 -2
  79. sglang/srt/multimodal/processors/pixtral.py +2 -2
  80. sglang/srt/multimodal/processors/qwen_audio.py +2 -2
  81. sglang/srt/multimodal/processors/qwen_vl.py +2 -2
  82. sglang/srt/multimodal/processors/vila.py +3 -1
  83. sglang/srt/poll_based_barrier.py +31 -0
  84. sglang/srt/reasoning_parser.py +2 -1
  85. sglang/srt/server_args.py +65 -6
  86. sglang/srt/two_batch_overlap.py +8 -3
  87. sglang/srt/utils.py +96 -1
  88. sglang/srt/weight_sync/utils.py +119 -0
  89. sglang/test/runners.py +4 -0
  90. sglang/test/test_utils.py +118 -5
  91. sglang/utils.py +19 -0
  92. sglang/version.py +1 -1
  93. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/METADATA +5 -4
  94. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/RECORD +97 -80
  95. sglang/srt/debug_utils.py +0 -74
  96. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/WHEEL +0 -0
  97. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/licenses/LICENSE +0 -0
  98. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post6.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/ep_moe/layer.py

@@ -1,5 +1,7 @@
+from __future__ import annotations
+
 import logging
-from typing import List, Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple
 
 import torch
 
@@ -30,13 +32,13 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.layers.quantization.fp8 import Fp8EPMoEMethod
+from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
 from sglang.srt.layers.quantization.fp8_kernel import (
     is_fp8_fnuz,
     sglang_per_token_group_quant_fp8,
     sglang_per_token_quant_fp8,
 )
-from sglang.srt.layers.quantization.unquant import UnquantizedEPMoEMethod
+from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
 from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -47,23 +49,40 @@ from sglang.srt.utils import (
     get_bool_env_var,
     is_hip,
     is_npu,
+    next_power_of_2,
 )
 
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.ep_moe.token_dispatcher import (
+        DeepEPLLOutput,
+        DeepEPNormalOutput,
+        DispatchOutput,
+    )
+
 _is_hip = is_hip()
 _is_npu = is_npu()
 _is_fp8_fnuz = is_fp8_fnuz()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
+use_flashinfer_trtllm_moe = (
+    global_server_args_dict["enable_flashinfer_trtllm_moe"]
+    and global_server_args_dict["enable_ep_moe"]
+)
 
 if not (_is_npu or _is_hip):
     from sgl_kernel import silu_and_mul
 
-    from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
-
 if _use_aiter:
     from aiter import ActivationType, QuantType
     from aiter.fused_moe import fused_moe
     from aiter.ops.shuffle import shuffle_weight
 
+if use_flashinfer_trtllm_moe:
+    try:
+        import flashinfer.fused_moe as fi_fused_moe
+    except ImportError:
+        fi_fused_moe = None
+        use_flashinfer_trtllm_moe = False
+
 logger = logging.getLogger(__name__)
 
 
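Note on the import changes above: `from __future__ import annotations` plus the TYPE_CHECKING guard lets the new token_dispatcher types appear in annotations without being imported at runtime (avoiding a circular import), and the flashinfer import is wrapped so the feature flag is cleared when the package is absent. A minimal, self-contained sketch of the postponed-annotations pattern (the imported name here is only an illustration, not part of sglang):

    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Only seen by static type checkers; never imported at runtime.
        from decimal import Decimal


    def double(x: Decimal) -> Decimal:
        # With postponed evaluation the annotation stays a string at runtime,
        # so this call works even though Decimal was never actually imported.
        return x * 2


    print(double(21))  # 42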
@@ -140,7 +159,17 @@ class GroupedGemmRunner(torch.nn.Module):
         return c
 
 
-class EPMoE(torch.nn.Module):
+def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
+    # Guess tokens per expert assuming perfect expert distribution first.
+    num_tokens_per_expert = (num_tokens * top_k) // num_experts
+    # And pad the number to the next power of 2.
+    tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
+    # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
+    tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
+    return tile_tokens_dim
+
+
+class EPMoE(FusedMoE):
     """
     MoE Expert Parallel Impl
 
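For reference, a standalone sketch of the tile sizing helper added above, with a local `next_power_of_2` standing in for the one imported from sglang.srt.utils (assumed to round an integer up to the nearest power of two):

    def next_power_of_2(n: int) -> int:
        # Stand-in for sglang.srt.utils.next_power_of_2: round up to a power of two.
        return 1 if n <= 1 else 1 << (n - 1).bit_length()


    def get_tile_tokens_dim(num_tokens: int, top_k: int, num_experts: int) -> int:
        # Tokens per expert under a perfectly even routing assumption.
        num_tokens_per_expert = (num_tokens * top_k) // num_experts
        # Round up to a power of two, then clamp to the 8-64 range the kernel supports.
        return min(max(next_power_of_2(num_tokens_per_expert), 8), 64)


    print(get_tile_tokens_dim(64, 8, 256))    # 2 tokens/expert -> clamped up to 8
    print(get_tile_tokens_dim(4096, 8, 256))  # 128 tokens/expert -> clamped down to 64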
@@ -162,51 +191,60 @@ class EPMoE(torch.nn.Module):
         routed_scaling_factor: Optional[float] = None,
         use_per_token_if_dynamic: bool = True,
     ):
-        super().__init__()
+        super().__init__(
+            num_experts=num_experts,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            top_k=top_k,
+            layer_id=layer_id,
+            params_dtype=params_dtype,
+            quant_config=quant_config,
+            tp_size=tp_size,
+            prefix=prefix,
+            activation=activation,
+            routed_scaling_factor=routed_scaling_factor,
+            enable_ep_moe=True,
+            skip_quant=True,
+        )
 
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
 
-        self.tp_size = (
-            tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
-        )
-        self.tp_rank = get_tensor_model_parallel_rank()
-
         self.layer_id = layer_id
-        self.num_experts = num_experts
-        assert self.num_experts % self.tp_size == 0
-        self.num_experts_per_partition, self.expert_map = self.determine_expert_map()
-        self.start_expert_id = self.tp_rank * self.num_experts_per_partition
-        self.end_expert_id = self.start_expert_id + self.num_experts_per_partition - 1
+        self.num_local_experts, self.expert_map = self.determine_expert_map()
+        self.start_expert_id = self.ep_rank * self.num_local_experts
+        self.end_expert_id = self.start_expert_id + self.num_local_experts - 1
 
-        self.top_k = top_k
         self.intermediate_size = intermediate_size
-        self.activation = activation
-        self.routed_scaling_factor = routed_scaling_factor
         self.use_per_token_if_dynamic = use_per_token_if_dynamic
 
+        # TODO(ch-wan): move quant preparation to FusedMoE
         if quant_config is None:
-            self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod()
+            self.quant_method: Optional[QuantizeMethodBase] = (
+                UnquantizedFusedMoEMethod()
+            )
             self.use_fp8_w8a8 = False
             self.use_block_quant = False
             self.block_shape = None
             self.activation_scheme = None
-            self.use_w4afp8 = False
+            self.w13_input_scale = None
+            self.w2_input_scale = None
+            self.w13_weight_scale = None
+            self.w2_weight_scale = None
         elif isinstance(quant_config, W4AFp8Config):
             self.quant_method: Optional[QuantizeMethodBase] = W4AFp8MoEMethod(
                 quant_config
             )
-            self.use_w4afp8 = True
             self.use_fp8_w8a8 = False
             self.use_block_quant = False
             self.fp8_dtype = torch.float8_e4m3fn
+            self.w13_input_scale = None
+            self.w2_input_scale = None
             self.w13_weight_scale = None
             self.w2_weight_scale = None
             self.activation_scheme = quant_config.moe_activation_scheme
-        else:
-            self.quant_method: Optional[QuantizeMethodBase] = Fp8EPMoEMethod(
-                quant_config
-            )
+        elif isinstance(quant_config, Fp8Config):
+            self.quant_method: Optional[QuantizeMethodBase] = Fp8MoEMethod(quant_config)
             self.use_fp8_w8a8 = True
             self.use_block_quant = getattr(self.quant_method, "block_quant", False)
             self.block_shape = (
@@ -216,11 +254,13 @@ class EPMoE(torch.nn.Module):
             )
             self.fp8_dtype = torch.float8_e4m3fn
             self.activation_scheme = quant_config.activation_scheme
-            self.use_w4afp8 = False
+        else:
+            raise ValueError(f"Unsupported quant_config: {quant_config}")
 
+        self.quant_config = quant_config
         self.quant_method.create_weights(
             layer=self,
-            num_experts_per_partition=self.num_experts_per_partition,
+            num_experts=self.num_local_experts,
             hidden_size=hidden_size,
             intermediate_size=self.intermediate_size,
             params_dtype=params_dtype,
@@ -229,19 +269,6 @@ class EPMoE(torch.nn.Module):
 
         self.grouped_gemm_runner = None
 
-        self.w13_weight_fp8 = (
-            self.w13_weight,
-            (
-                self.w13_weight_scale_inv
-                if self.use_block_quant
-                else self.w13_weight_scale
-            ),
-        )
-        self.w2_weight_fp8 = (
-            self.w2_weight,
-            self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
-        )
-
     # Adapted from https://github.com/vllm-project/vllm/blob/9fb52e523abf7bdaf7e60cf2971edb5a1b13dc08/vllm/model_executor/layers/fused_moe/layer.py#L544C1-L586C43
     # Modifications: use determine_expert_map as a class internal function, set 'global_num_experts' rather than '-1' for experts not assigned to the current rank.
     def determine_expert_map(self) -> Tuple[int, Optional[torch.Tensor]]:
@@ -260,8 +287,8 @@ class EPMoE(torch.nn.Module):
             Contains global_num_experts for experts not assigned to the current rank.
             Returns None if ep_size is 1.
         """
-        ep_size = self.tp_size
-        ep_rank = self.tp_rank
+        ep_size = self.ep_size
+        ep_rank = self.ep_rank
         global_num_experts = self.num_experts
 
         assert ep_size > 0
@@ -271,7 +298,7 @@ class EPMoE(torch.nn.Module):
         local_num_experts = global_num_experts // ep_size
 
         expert_map = torch.full(
-            (global_num_experts,), self.num_experts, dtype=torch.int32
+            (global_num_experts,), global_num_experts, dtype=torch.int32
        )
         if ep_rank < (ep_size - 1):
             expert_map[
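A small worked example of the mapping built by determine_expert_map, using the sentinel convention shown in the hunk above (unassigned experts map to global_num_experts rather than -1); the sizes are illustrative:

    import torch

    # Illustrative sizes: 8 global experts, expert parallelism of 4, this is ep_rank 1.
    global_num_experts, ep_size, ep_rank = 8, 4, 1
    local_num_experts = global_num_experts // ep_size  # 2 experts live on this rank

    # Unassigned experts get the sentinel value global_num_experts.
    expert_map = torch.full((global_num_experts,), global_num_experts, dtype=torch.int32)
    start = ep_rank * local_num_experts
    expert_map[start : start + local_num_experts] = torch.arange(
        local_num_experts, dtype=torch.int32
    )
    print(expert_map.tolist())  # [8, 8, 0, 1, 8, 8, 8, 8]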
@@ -296,6 +323,20 @@ class EPMoE(torch.nn.Module):
         hidden_states: torch.Tensor,
         topk_output: TopKOutput,
     ):
+
+        self.w13_weight_fp8 = (
+            self.w13_weight,
+            (
+                self.w13_weight_scale_inv
+                if self.use_block_quant
+                else self.w13_weight_scale
+            ),
+        )
+        self.w2_weight_fp8 = (
+            self.w2_weight,
+            self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
+        )
+
         assert self.quant_method is not None
         assert self.activation == "silu"
         hidden_states_shape = hidden_states.shape
@@ -435,7 +476,10 @@ class EPMoE(torch.nn.Module):
         return output
 
     def forward_normal(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
-        assert self.quant_method is not None
+        return self.quant_method.apply(self, hidden_states, topk_output)
+
+    def run_moe(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
+
         topk_weights, topk_ids, _ = topk_output
 
         hidden_states_shape = hidden_states.shape
@@ -448,53 +492,11 @@ class EPMoE(torch.nn.Module):
             use_per_token_if_dynamic=self.use_per_token_if_dynamic,
         )
 
-        if self.use_w4afp8:
-            local_topk_ids = topk_ids
-            if self.expert_map is not None:
-                "Translate info from expert_map to topk_ids"
-                local_topk_ids = torch.where(
-                    self.expert_map[topk_ids] != self.num_experts,
-                    self.expert_map[topk_ids],
-                    self.num_experts,
-                )
-
-            output = cutlass_w4a8_moe(
-                self.start_expert_id,
-                self.end_expert_id,
-                self.num_experts,
-                hidden_states,
-                self.w13_weight,
-                self.w2_weight,
-                self.w13_weight_scale_inv,
-                self.w2_weight_scale_inv,
-                topk_weights,
-                topk_ids,
-                local_topk_ids,
-                self.quant_method.a_strides1,
-                self.quant_method.b_strides1,
-                self.quant_method.c_strides1,
-                self.quant_method.a_strides2,
-                self.quant_method.b_strides2,
-                self.quant_method.c_strides2,
-                self.quant_method.s_strides13,
-                self.quant_method.s_strides2,
-                self.quant_method.expert_offsets,
-                self.quant_method.problem_sizes1,
-                self.quant_method.problem_sizes2,
-                self.w13_input_scale,
-                self.w2_input_scale,
-            )
-            return output
-
-        if self.grouped_gemm_runner is None:
-            self.grouped_gemm_runner = GroupedGemmRunner(
-                hidden_states.device,
-                use_flashinfer=False,  # TODO: use flashinfer
-                use_per_token_if_dynamic=self.use_per_token_if_dynamic,
-            )
+        num_experts = self.num_experts
 
         reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
-            topk_ids, self.num_experts
+            topk_ids,
+            num_experts,
         )
 
         gateup_input = torch.empty(
@@ -502,7 +504,7 @@ class EPMoE(torch.nn.Module):
             device=hidden_states.device,
             dtype=(
                 self.fp8_dtype
-                if ((self.use_fp8_w8a8 or self.use_w4afp8) and not self.use_block_quant)
+                if self.use_fp8_w8a8 and not self.use_block_quant
                 else hidden_states.dtype
            ),
        )
@@ -513,7 +515,7 @@ class EPMoE(torch.nn.Module):
         else:
             max_value = (
                 torch.max(hidden_states)
-                .repeat(self.num_experts_per_partition)
+                .repeat(self.num_local_experts)
                 .to(torch.float32)
             )
             self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
@@ -554,7 +556,7 @@ class EPMoE(torch.nn.Module):
         seg_indptr_cur_rank = seg_indptr[self.start_expert_id : self.end_expert_id + 2]
         weight_indices_cur_rank = torch.arange(
             0,
-            self.num_experts_per_partition,
+            self.num_local_experts,
             device=hidden_states_device,
             dtype=torch.int64,
         )
@@ -564,17 +566,13 @@ class EPMoE(torch.nn.Module):
             b=self.w13_weight,
             c=None,
             c_dtype=hidden_states_dtype,
-            batch_size=self.num_experts_per_partition,
+            batch_size=self.num_local_experts,
             weight_column_major=True,
             seg_indptr=seg_indptr_cur_rank,
             weight_indices=weight_indices_cur_rank,
             use_fp8_w8a8=self.use_fp8_w8a8,
             scale_a=self.w13_input_scale,
-            scale_b=(
-                self.w13_weight_scale_inv
-                if self.use_block_quant
-                else self.w13_weight_scale
-            ),
+            scale_b=self.w13_weight_scale,
             block_shape=self.block_shape,
         )
         del gateup_input
@@ -631,7 +629,7 @@ class EPMoE(torch.nn.Module):
             down_input, self.w2_input_scale = sglang_per_token_quant_fp8(down_input)
         else:
             self.w2_input_scale = torch.ones(
-                self.num_experts_per_partition,
+                self.num_local_experts,
                 dtype=torch.float32,
                 device=hidden_states_device,
             )
@@ -647,17 +645,13 @@ class EPMoE(torch.nn.Module):
             a=down_input,
             b=self.w2_weight,
             c=down_output,
-            batch_size=self.num_experts_per_partition,
+            batch_size=self.num_local_experts,
             weight_column_major=True,
             seg_indptr=seg_indptr_cur_rank,
             weight_indices=weight_indices_cur_rank,
             use_fp8_w8a8=self.use_fp8_w8a8,
             scale_a=self.w2_input_scale,
-            scale_b=(
-                self.w2_weight_scale_inv
-                if self.use_block_quant
-                else self.w2_weight_scale
-            ),
+            scale_b=self.w2_weight_scale,
             block_shape=self.block_shape,
         )
         del down_input
@@ -760,95 +754,14 @@ class EPMoE(torch.nn.Module):
             return
         expert_id = expert_id - self.start_expert_id
 
-        if shard_id not in ("w1", "w2", "w3"):
-            raise ValueError(
-                f"shard_id must be ['w1','w2','w3'] but " f"got {shard_id}."
-            )
-
-        # Special case for fp8 scales.
-        if "scale" in weight_name:
-            self._load_fp8_scale(
-                param.data,
-                loaded_weight,
-                weight_name,
-                shard_id,
-                expert_id,
-            )
-            return
-
-        if shard_id == "w2":
-            param.data[expert_id] = loaded_weight
-        elif shard_id == "w1":
-            param.data[expert_id][: self.intermediate_size, :] = loaded_weight
-        elif shard_id == "w3":
-            param.data[expert_id][self.intermediate_size :, :] = loaded_weight
-        else:
-            raise ValueError(f"Expected shard_id w1,w2 or w3 but got {shard_id}")
-
-    def _load_fp8_scale(
-        self,
-        param: torch.nn.Parameter,
-        loaded_weight: torch.Tensor,
-        weight_name: str,
-        shard_id: str,
-        expert_id: int,
-    ) -> None:
-        param_data = param.data
-
-        # Input scales can be loaded directly and should be equal.
-        if "input_scale" in weight_name:
-            if self.use_w4afp8:
-                if shard_id == "w1":
-                    param_data[expert_id][0] = loaded_weight
-                elif shard_id == "w3":
-                    param_data[expert_id][1] = loaded_weight
-                else:
-                    param_data[expert_id] = loaded_weight
-                return
-
-            if (
-                (shard_id == "w1" or shard_id == "w3")
-                and param_data[expert_id] != 1
-                and (param_data[expert_id] - loaded_weight).abs() > 1e-5
-            ):
-                raise ValueError(
-                    "input_scales of w1 and w3 of a layer "
-                    f"must be equal. But got {param_data[expert_id]} "
-                    f"vs. {loaded_weight}"
-                )
-            param_data[expert_id] = loaded_weight
-        # Weight scales
-        elif "weight_scale" in weight_name:
-            if self.use_block_quant:
-                block_n, block_k = self.block_shape[0], self.block_shape[1]
-                if shard_id == "w1":
-                    param_data[expert_id][
-                        : (self.intermediate_size + block_n - 1) // block_n, :
-                    ] = loaded_weight
-                elif shard_id == "w3":
-                    param_data[expert_id][
-                        (self.intermediate_size + block_n - 1) // block_n :, :
-                    ] = loaded_weight
-                else:  # w2
-                    param_data[expert_id] = loaded_weight
-            elif self.use_w4afp8:
-                if shard_id == "w1":
-                    param_data[expert_id][: self.intermediate_size, :] = loaded_weight
-                elif shard_id == "w3":
-                    param_data[expert_id][self.intermediate_size :, :] = loaded_weight
-                else:
-                    param_data[expert_id] = loaded_weight
-            # If we are in merged column case (gate_up_proj)
-            else:
-                if shard_id in ("w1", "w3"):
-                    # We have to keep the weight scales of w1 and w3 because
-                    # we need to re-quantize w1/w3 weights after weight loading.
-                    idx = 0 if shard_id == "w1" else 1
-                    param_data[expert_id][idx] = loaded_weight
-
-                # If we are in the row parallel case (down_proj)
-                else:
-                    param_data[expert_id] = loaded_weight
+        self._weight_loader_impl(
+            param=param,
+            loaded_weight=loaded_weight,
+            weight_name=weight_name,
+            shard_id=shard_id,
+            expert_id=expert_id,
+        )
+        return
 
 
 class DeepEPMoE(EPMoE):
@@ -887,24 +800,37 @@ class DeepEPMoE(EPMoE):
             routed_scaling_factor=routed_scaling_factor,
         )
         self.deepep_mode = deepep_mode
-        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
-            assert self.use_fp8_w8a8, (
-                "DeepGEMM requires an fp8_w8a8 model; "
-                "alternatively, you can disable DeepGEMM by turning off the ENABLE_JIT_DEEPGEMM environment variable."
-            )
+
+        # TODO: move to the beginning of the file
+        from sglang.srt.distributed.parallel_state import get_tp_group
+        from sglang.srt.managers.schedule_batch import global_server_args_dict
+        from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
+
+        self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
+            group=get_tp_group().device_group,
+            router_topk=self.top_k,
+            permute_fusion=True,
+            num_experts=self.num_experts,
+            num_local_experts=self.num_local_experts,
+            hidden_size=hidden_size,
+            params_dtype=params_dtype,
+            deepep_mode=deepep_mode,
+            async_finish=True,  # TODO
+            return_recv_hook=True,
+        )
 
         if self.deepep_mode.enable_low_latency():
             assert (
                 deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
             ), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
         if _use_aiter:
-            # expert_mask is of size (self.num_experts_per_partition + 1),
+            # expert_mask is of size (self.num_local_experts + 1),
            # the extra 1 is for invalid rank_id (in original deepep, the invalid rank_id is -1, but aiter does not allow -1, we use a mask to make those ids invalid)
            # for instance, if we have 4 experts on this rank, we would have a expert_mask like:
            #     self.expert_mask = [1, 1, 1, 1, 0]
            # idx from 0-3 is valid and will be processed, while idx == 4 will be masked out
            self.expert_mask = torch.zeros(
-                (self.num_experts_per_partition + 1),
+                (self.num_local_experts + 1),
                device=torch.cuda.current_device(),
                dtype=torch.int,
            )
@@ -933,37 +859,128 @@ class DeepEPMoE(EPMoE):
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-        reorder_topk_ids: torch.Tensor,
-        seg_indptr: torch.Tensor,
-        masked_m: torch.Tensor,
-        expected_m: int,
-        num_recv_tokens_per_expert: List[int],
         forward_batch: ForwardBatch,
     ):
+        dispatch_output = self.dispatch(
+            hidden_states, topk_idx, topk_weights, forward_batch
+        )
+        hidden_states = self.moe_impl(dispatch_output)
+        hidden_states = self.combine(
+            hidden_states,
+            dispatch_output.topk_idx,
+            dispatch_output.topk_weights,
+            forward_batch,
+        )
+        return hidden_states
+
+    def dispatch(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ):
+        return self.deepep_dispatcher.dispatch(
+            hidden_states=hidden_states,
+            topk_idx=topk_idx,
+            topk_weights=topk_weights,
+            forward_batch=forward_batch,
+        )
+
+    def moe_impl(self, dispatch_output: DispatchOutput):
         if _use_aiter:
             # in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
-            return self.forward_aiter(hidden_states, topk_idx, topk_weights)
-        resolved_deepep_mode = self.deepep_mode.resolve(
-            forward_batch.is_extend_in_batch
-        )
-        if resolved_deepep_mode == DeepEPMode.normal:
-            if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
-                return self.forward_deepgemm_contiguous(
-                    hidden_states, topk_idx, topk_weights, num_recv_tokens_per_expert
-                )
+            return self.forward_aiter(dispatch_output)
+        if dispatch_output.format.is_deepep_normal():
+            if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
+                return self.forward_deepgemm_contiguous(dispatch_output)
             else:
-                return self.forward_normal(hidden_states, reorder_topk_ids, seg_indptr)
-        elif resolved_deepep_mode == DeepEPMode.low_latency:
-            return self.forward_deepgemm_masked(hidden_states, masked_m, expected_m)
+                return self.forward_normal(dispatch_output)
+        elif dispatch_output.format.is_deepep_ll():
+            return self.forward_deepgemm_masked(dispatch_output)
         else:
             raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
 
-    def forward_normal(
+    def combine(
         self,
         hidden_states: torch.Tensor,
-        reorder_topk_ids: torch.Tensor,
-        seg_indptr: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ):
+        return self.deepep_dispatcher.combine(
+            hidden_states=hidden_states,
+            topk_idx=topk_idx,
+            topk_weights=topk_weights,
+            forward_batch=forward_batch,
+        )
+
+    def _prepare_for_normal(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+    ):
+        from sglang.srt.layers.moe.ep_moe.kernels import (
+            deepep_permute_triton_kernel,
+            deepep_run_moe_deep_preprocess,
+        )
+
+        if hidden_states.shape[0] == 0:
+            reorder_topk_ids = torch.empty(
+                (0,), device=hidden_states.device, dtype=torch.int64
+            )
+            seg_indptr = torch.zeros(
+                (self.num_experts + 1,),
+                device=hidden_states.device,
+                dtype=torch.int64,
+            )
+            return reorder_topk_ids, seg_indptr, hidden_states
+        else:
+            if _use_aiter:
+                # skip permutation here as aiter fused_moe has fused inside
+                reorder_topk_ids = torch.empty(
+                    (0,), device=hidden_states.device, dtype=torch.int64
+                )
+                seg_indptr = torch.zeros(
+                    (self.num_experts + 1,),
+                    device=hidden_states.device,
+                    dtype=torch.int64,
+                )
+                return reorder_topk_ids, seg_indptr, hidden_states
+
+            reorder_topk_ids, self.src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
+                topk_idx, self.num_experts
+            )
+            num_total_tokens = reorder_topk_ids.numel()
+            gateup_input = torch.empty(
+                (int(num_total_tokens), hidden_states.shape[1]),
+                device=hidden_states.device,
+                dtype=hidden_states.dtype,
+            )
+            # PreReorder
+            deepep_permute_triton_kernel[(hidden_states.shape[0],)](
+                hidden_states,
+                gateup_input,
+                self.src2dst,
+                topk_idx,
+                None,
+                self.router_topk,
+                hidden_states.shape[1],
+                BLOCK_SIZE=512,
+            )
+            return reorder_topk_ids, seg_indptr, gateup_input
+
+    def forward_normal(
+        self,
+        dispatch_output: DeepEPNormalOutput,
     ):
+        hidden_states, topk_idx = (
+            dispatch_output.hidden_states,
+            dispatch_output.topk_idx,
+        )
+        reorder_topk_ids, seg_indptr, hidden_states = self._prepare_for_normal(
+            hidden_states, topk_idx
+        )
         hidden_states_dtype = hidden_states.dtype
         hidden_states_device = hidden_states.device
 
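The refactor above splits DeepEPMoE.forward into three stages, dispatch, moe_impl, and combine, with the dispatcher object owning the communication and the expert math reading everything from the dispatch output. A toy sketch of that control flow under simplified stand-in types (this is not the real MaybeTboDeepEPDispatcher API):

    from typing import Any, NamedTuple


    class ToyDispatchOutput(NamedTuple):  # stand-in for DeepEPNormalOutput and friends
        hidden_states: Any
        topk_idx: Any
        topk_weights: Any


    class ToyDispatcher:  # stand-in for the DeepEP dispatcher
        def dispatch(self, hidden_states, topk_idx, topk_weights):
            return ToyDispatchOutput(hidden_states, topk_idx, topk_weights)

        def combine(self, hidden_states, topk_idx, topk_weights):
            return hidden_states


    class ToyMoELayer:
        def __init__(self):
            self.dispatcher = ToyDispatcher()

        def moe_impl(self, out: ToyDispatchOutput):
            return out.hidden_states  # the grouped GEMMs would run here

        def forward(self, hidden_states, topk_idx, topk_weights):
            out = self.dispatcher.dispatch(hidden_states, topk_idx, topk_weights)
            hidden_states = self.moe_impl(out)
            return self.dispatcher.combine(hidden_states, out.topk_idx, out.topk_weights)


    print(ToyMoELayer().forward([1.0, 2.0], [0], [1.0]))  # [1.0, 2.0]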
@@ -977,13 +994,13 @@ class DeepEPMoE(EPMoE):
         if self.activation_scheme == "dynamic" and not self.use_block_quant:
             max_value = (
                 torch.max(hidden_states)
-                .repeat(self.num_experts_per_partition)
+                .repeat(self.num_local_experts)
                 .to(torch.float32)
             )
             self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
         weight_indices_cur_rank = torch.arange(
             0,
-            self.num_experts_per_partition,
+            self.num_local_experts,
             device=hidden_states.device,
             dtype=torch.int64,
         )
@@ -995,7 +1012,7 @@ class DeepEPMoE(EPMoE):
             b=self.w13_weight,
             c=None,
             c_dtype=hidden_states.dtype,
-            batch_size=self.num_experts_per_partition,
+            batch_size=self.num_local_experts,
             weight_column_major=True,
             seg_indptr=seg_indptr,
             weight_indices=weight_indices_cur_rank,
@@ -1029,7 +1046,7 @@ class DeepEPMoE(EPMoE):
         )
         if self.w2_input_scale is None and not self.use_block_quant:
             self.w2_input_scale = torch.ones(
-                self.num_experts_per_partition,
+                self.num_local_experts,
                 dtype=torch.float32,
                 device=hidden_states_device,
             )
@@ -1042,7 +1059,7 @@ class DeepEPMoE(EPMoE):
             reorder_topk_ids,
             self.w2_input_scale,
             0,
-            self.num_experts_per_partition - 1,
+            self.num_local_experts - 1,
             BLOCK_SIZE=512,
         )
        else:
@@ -1062,7 +1079,7 @@ class DeepEPMoE(EPMoE):
             a=down_input,
             b=self.w2_weight,
             c=down_output,
-            batch_size=self.num_experts_per_partition,
+            batch_size=self.num_local_experts,
             weight_column_major=True,
             seg_indptr=seg_indptr,
             weight_indices=weight_indices_cur_rank,
@@ -1079,17 +1096,20 @@ class DeepEPMoE(EPMoE):
 
     def forward_aiter(
         self,
-        hidden_states: torch.Tensor,
-        topk_idx: torch.Tensor,
-        topk_weights: torch.Tensor,
+        dispatch_output: DeepEPNormalOutput,
     ):
+        hidden_states, topk_idx, topk_weights = (
+            dispatch_output.hidden_states,
+            dispatch_output.topk_idx,
+            dispatch_output.topk_weights,
+        )
         if hidden_states.shape[0] == 0:
             return hidden_states
         # in original deepep, idx == -1 meaning invalid and will not be processed.
         # aiter does not accept -1, we use a expert mask to make these idx invalid
-        # (idx == num_experts_per_partition) meaning not used in aiter fused_moe
+        # (idx == num_local_experts) meaning not used in aiter fused_moe
         topk_idx_copy = topk_idx.to(torch.int32)
-        topk_idx_copy[topk_idx_copy == -1] = self.num_experts_per_partition
+        topk_idx_copy[topk_idx_copy == -1] = self.num_local_experts
 
         return fused_moe(
             hidden_states,
@@ -1110,11 +1130,11 @@ class DeepEPMoE(EPMoE):
 
     def forward_deepgemm_contiguous(
         self,
-        hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
-        topk_idx,
-        topk_weights,
-        num_recv_tokens_per_expert: List[int],
+        dispatch_output: DeepEPNormalOutput,
     ):
+        hidden_states_fp8, topk_idx, topk_weights, num_recv_tokens_per_expert = (
+            dispatch_output
+        )
         hidden_states_fp8, hidden_states_scale = hidden_states_fp8
         assert self.quant_method is not None
         assert self.activation == "silu"
@@ -1234,10 +1254,9 @@ class DeepEPMoE(EPMoE):
 
     def forward_deepgemm_masked(
         self,
-        hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
-        masked_m: torch.Tensor,
-        expected_m: int,
+        dispatch_output: DeepEPLLOutput,
     ):
+        hidden_states_fp8, _, _, masked_m, expected_m = dispatch_output
         assert self.quant_method is not None
         assert self.activation == "silu"
 
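The positional unpacking above assumes a fixed field order for the dispatch outputs: forward_deepgemm_contiguous reads (hidden_states, topk_idx, topk_weights, num_recv_tokens_per_expert) from a normal-mode output, and forward_deepgemm_masked reads (hidden_states, topk_idx, topk_weights, masked_m, expected_m) from a low-latency output. A hedged sketch of tuple types with that layout (the real definitions live in the new token_dispatcher module and may carry more detail):

    from typing import Any, List, NamedTuple


    class ToyDeepEPNormalOutput(NamedTuple):  # assumed field order, per the unpacking above
        hidden_states: Any
        topk_idx: Any
        topk_weights: Any
        num_recv_tokens_per_expert: List[int]


    class ToyDeepEPLLOutput(NamedTuple):  # assumed field order, per the unpacking above
        hidden_states: Any
        topk_idx: Any
        topk_weights: Any
        masked_m: Any
        expected_m: int


    out = ToyDeepEPLLOutput("hs_fp8", "idx", "weights", "masked_m", 16)
    hidden_states_fp8, _, _, masked_m, expected_m = out  # mirrors forward_deepgemm_masked
    print(masked_m, expected_m)  # masked_m 16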
@@ -1315,12 +1334,74 @@ class DeepEPMoE(EPMoE):
         return down_output
 
 
+class FlashInferEPMoE(EPMoE):
+    def __init__(self, *args, **kwargs):
+        renormalize = kwargs.pop("renormalize", True)
+        num_fused_shared_experts = kwargs.pop("num_fused_shared_experts", 0)
+        use_grouped_topk = kwargs.pop("use_grouped_topk", False)
+        num_expert_group = kwargs.pop("num_expert_group", None)
+        topk_group = kwargs.pop("topk_group", None)
+        correction_bias = kwargs.pop("correction_bias", None)
+        super().__init__(*args, **kwargs)
+        self.renormalize = renormalize
+        self.num_fused_shared_experts = num_fused_shared_experts
+        self.use_grouped_topk = use_grouped_topk
+        if self.use_grouped_topk:
+            assert num_expert_group is not None and topk_group is not None
+        self.num_expert_group = num_expert_group
+        self.topk_group = topk_group
+        self.correction_bias = correction_bias
+        self.use_flashinfer_trtllm_moe = use_flashinfer_trtllm_moe
+
+    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
+        assert use_flashinfer_trtllm_moe
+        assert (
+            self.activation == "silu"
+        ), "Only silu is supported for flashinfer blockscale fp8 moe"
+        assert (
+            self.renormalize
+        ), "Renormalize is required for flashinfer blockscale fp8 moe"
+        assert (
+            self.num_fused_shared_experts == 0
+        ), "Fused shared experts are not supported for flashinfer blockscale fp8 moe"
+        a_q, a_sf = sglang_per_token_group_quant_fp8(hidden_states, self.block_shape[1])
+        # NOTE: scales of hidden states have to be transposed!
+        a_sf_t = a_sf.t().contiguous()
+        assert fi_fused_moe is not None
+        return fi_fused_moe.trtllm_fp8_block_scale_moe(
+            routing_logits=router_logits.to(torch.float32),
+            routing_bias=self.correction_bias.to(hidden_states.dtype),
+            hidden_states=a_q,
+            hidden_states_scale=a_sf_t,
+            gemm1_weights=self.w13_weight,
+            gemm1_weights_scale=self.w13_weight_scale_inv,
+            gemm2_weights=self.w2_weight,
+            gemm2_weights_scale=self.w2_weight_scale_inv,
+            num_experts=self.num_experts,
+            top_k=self.top_k,
+            n_group=self.num_expert_group,
+            topk_group=self.topk_group,
+            intermediate_size=self.w2_weight.shape[2],
+            local_expert_offset=self.start_expert_id,
+            local_num_experts=self.num_local_experts,
+            routed_scaling_factor=self.routed_scaling_factor,
+            tile_tokens_dim=_get_tile_tokens_dim(
+                hidden_states.shape[0], self.top_k, self.num_experts
+            ),
+            routing_method_type=2,  # DeepSeek-styled routing method
+            use_shuffled_weight=False,
+        )
+
+
 def get_moe_impl_class():
     if global_server_args_dict["enable_deepep_moe"]:
         return DeepEPMoE
-    if global_server_args_dict["enable_flashinfer_moe"]:
+    if global_server_args_dict["enable_flashinfer_cutlass_moe"]:
         # Must come before EPMoE because FusedMoE also supports enable_ep_moe
         return FusedMoE
+    if use_flashinfer_trtllm_moe:
+        # Must come before EPMoE because FusedMoE also supports enable_ep_moe
+        return FlashInferEPMoE
     if global_server_args_dict["enable_ep_moe"]:
         return EPMoE
     return FusedMoE
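After this change, get_moe_impl_class resolves the backend in this order: DeepEP, the flashinfer cutlass path (handled by FusedMoE), the new flashinfer trtllm path (FlashInferEPMoE, which also requires EP mode), plain EPMoE, and finally the default FusedMoE. A toy sketch of that precedence with plain flags and strings in place of global_server_args_dict and the real classes:

    def pick_moe_impl(flags: dict) -> str:
        # Mirrors the precedence in get_moe_impl_class above (illustrative only).
        if flags.get("enable_deepep_moe"):
            return "DeepEPMoE"
        if flags.get("enable_flashinfer_cutlass_moe"):
            return "FusedMoE"
        if flags.get("enable_flashinfer_trtllm_moe") and flags.get("enable_ep_moe"):
            return "FlashInferEPMoE"
        if flags.get("enable_ep_moe"):
            return "EPMoE"
        return "FusedMoE"


    print(pick_moe_impl({"enable_ep_moe": True}))  # EPMoE
    print(pick_moe_impl({"enable_ep_moe": True, "enable_flashinfer_trtllm_moe": True}))
    # FlashInferEPMoE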