sglang 0.4.9.post5__py3-none-any.whl → 0.4.10__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.
Files changed (84)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/srt/configs/__init__.py +8 -0
  3. sglang/srt/configs/model_config.py +6 -0
  4. sglang/srt/configs/step3_vl.py +172 -0
  5. sglang/srt/conversation.py +23 -0
  6. sglang/srt/disaggregation/decode.py +2 -8
  7. sglang/srt/disaggregation/prefill.py +2 -6
  8. sglang/srt/distributed/parallel_state.py +86 -1
  9. sglang/srt/entrypoints/engine.py +14 -18
  10. sglang/srt/entrypoints/http_server.py +23 -3
  11. sglang/srt/entrypoints/openai/protocol.py +3 -1
  12. sglang/srt/entrypoints/openai/serving_base.py +5 -2
  13. sglang/srt/entrypoints/openai/serving_chat.py +2 -21
  14. sglang/srt/eplb/expert_distribution.py +5 -0
  15. sglang/srt/eplb/expert_location.py +17 -6
  16. sglang/srt/eplb/expert_location_dispatch.py +1 -0
  17. sglang/srt/eplb/expert_location_updater.py +2 -0
  18. sglang/srt/function_call/function_call_parser.py +2 -0
  19. sglang/srt/function_call/step3_detector.py +436 -0
  20. sglang/srt/hf_transformers_utils.py +2 -0
  21. sglang/srt/jinja_template_utils.py +4 -1
  22. sglang/srt/layers/moe/cutlass_moe.py +2 -1
  23. sglang/srt/layers/moe/ep_moe/layer.py +98 -603
  24. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +83 -118
  25. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  27. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +26 -13
  28. sglang/srt/layers/moe/fused_moe_triton/layer.py +97 -38
  29. sglang/srt/layers/moe/token_dispatcher/__init__.py +0 -0
  30. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +48 -0
  31. sglang/srt/layers/moe/token_dispatcher/standard.py +19 -0
  32. sglang/srt/layers/moe/topk.py +6 -2
  33. sglang/srt/layers/quantization/fp8.py +0 -18
  34. sglang/srt/layers/quantization/modelopt_quant.py +2 -0
  35. sglang/srt/layers/quantization/unquant.py +0 -8
  36. sglang/srt/layers/quantization/w4afp8.py +1 -0
  37. sglang/srt/managers/cache_controller.py +143 -45
  38. sglang/srt/managers/data_parallel_controller.py +6 -0
  39. sglang/srt/managers/io_struct.py +12 -2
  40. sglang/srt/managers/scheduler.py +116 -669
  41. sglang/srt/managers/scheduler_input_blocker.py +106 -0
  42. sglang/srt/managers/scheduler_metrics_mixin.py +229 -0
  43. sglang/srt/managers/scheduler_profiler_mixin.py +279 -0
  44. sglang/srt/managers/scheduler_update_weights_mixin.py +142 -0
  45. sglang/srt/managers/template_manager.py +62 -19
  46. sglang/srt/managers/tokenizer_manager.py +166 -83
  47. sglang/srt/managers/tp_worker.py +9 -0
  48. sglang/srt/managers/tp_worker_overlap_thread.py +2 -1
  49. sglang/srt/mem_cache/hicache_storage.py +45 -11
  50. sglang/srt/mem_cache/hiradix_cache.py +15 -4
  51. sglang/srt/mem_cache/memory_pool_host.py +73 -1
  52. sglang/srt/mem_cache/mooncake_store/mooncake_store.py +264 -0
  53. sglang/srt/mem_cache/mooncake_store/unit_test.py +40 -0
  54. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +177 -0
  55. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +278 -0
  56. sglang/srt/mem_cache/storage/hf3fs/test_hf3fs_utils.py +43 -0
  57. sglang/srt/model_executor/model_runner.py +20 -13
  58. sglang/srt/models/arcee.py +532 -0
  59. sglang/srt/models/deepseek_v2.py +15 -56
  60. sglang/srt/models/glm4_moe.py +3 -1
  61. sglang/srt/models/granitemoe.py +3 -0
  62. sglang/srt/models/grok.py +3 -0
  63. sglang/srt/models/hunyuan.py +1 -0
  64. sglang/srt/models/llama4.py +3 -0
  65. sglang/srt/models/mixtral.py +3 -0
  66. sglang/srt/models/olmoe.py +3 -0
  67. sglang/srt/models/phimoe.py +1 -0
  68. sglang/srt/models/qwen3_moe.py +12 -69
  69. sglang/srt/models/step3_vl.py +994 -0
  70. sglang/srt/multimodal/processors/base_processor.py +15 -16
  71. sglang/srt/multimodal/processors/step3_vl.py +515 -0
  72. sglang/srt/poll_based_barrier.py +31 -0
  73. sglang/srt/reasoning_parser.py +2 -1
  74. sglang/srt/server_args.py +18 -13
  75. sglang/srt/speculative/eagle_worker.py +2 -0
  76. sglang/srt/two_batch_overlap.py +8 -3
  77. sglang/test/test_utils.py +53 -0
  78. sglang/utils.py +0 -11
  79. sglang/version.py +1 -1
  80. {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/METADATA +4 -4
  81. {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/RECORD +84 -64
  82. {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/WHEEL +0 -0
  83. {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/licenses/LICENSE +0 -0
  84. {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/top_level.txt +0 -0
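The hunks reproduced below cover sglang/srt/layers/moe/ep_moe/layer.py (entry 23 above, +98 -603). The most visible change there is that DeepEPMoE now owns its DeepEP token dispatcher and splits forward_deepep into dispatch, moe_impl, and combine steps, while much of the old EPMoE-specific plumbing (GroupedGemmRunner, determine_expert_map, run_moe) moves into or is replaced by the FusedMoE base class. For orientation only, here is a minimal, self-contained sketch of that control flow; ToyDispatcher, ToyDeepEPMoE, and the simplified ToyDispatchOutput are illustrative stand-ins, not sglang's actual API.

# Minimal sketch (stand-ins, not sglang's API) of the dispatch -> moe_impl -> combine
# pipeline that the layer.py hunks below introduce.
from dataclasses import dataclass
from typing import List


@dataclass
class ToyDispatchOutput:
    # stand-in for the DispatchOutput / DeepEPNormalOutput types referenced in the diff
    hidden_states: List[float]
    topk_idx: List[int]
    topk_weights: List[float]


class ToyDispatcher:
    # stand-in for MaybeTboDeepEPDispatcher: permutes tokens to experts and back
    def dispatch(self, hidden_states, topk_idx, topk_weights) -> ToyDispatchOutput:
        return ToyDispatchOutput(hidden_states, topk_idx, topk_weights)

    def combine(self, hidden_states, topk_idx, topk_weights):
        return hidden_states


class ToyDeepEPMoE:
    def __init__(self):
        self.deepep_dispatcher = ToyDispatcher()

    def moe_impl(self, dispatch_output: ToyDispatchOutput):
        # the real layer picks forward_aiter / forward_deepgemm_contiguous /
        # forward_deepgemm_masked based on the dispatch output format; the toy
        # just scales each token by its routing weight
        return [
            h * w
            for h, w in zip(dispatch_output.hidden_states, dispatch_output.topk_weights)
        ]

    def forward_deepep(self, hidden_states, topk_idx, topk_weights):
        dispatch_output = self.deepep_dispatcher.dispatch(
            hidden_states, topk_idx, topk_weights
        )
        out = self.moe_impl(dispatch_output)
        return self.deepep_dispatcher.combine(
            out, dispatch_output.topk_idx, dispatch_output.topk_weights
        )


if __name__ == "__main__":
    layer = ToyDeepEPMoE()
    print(layer.forward_deepep([1.0, 2.0], [0, 1], [0.5, 0.25]))  # [0.5, 0.5]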
@@ -1,5 +1,7 @@
+from __future__ import annotations
+
 import logging
-from typing import List, Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple
 
 import torch
 
@@ -50,6 +52,13 @@ from sglang.srt.utils import (
     next_power_of_2,
 )
 
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.ep_moe.token_dispatcher import (
+        DeepEPLLOutput,
+        DeepEPNormalOutput,
+        DispatchOutput,
+    )
+
 _is_hip = is_hip()
 _is_npu = is_npu()
 _is_fp8_fnuz = is_fp8_fnuz()
@@ -77,79 +86,6 @@ if use_flashinfer_trtllm_moe:
 logger = logging.getLogger(__name__)
 
 
-class GroupedGemmRunner(torch.nn.Module):
-    flashinfer_gemm_warpper = None
-
-    def __init__(
-        self,
-        device,
-        use_flashinfer: bool = False,
-        use_per_token_if_dynamic: bool = True,
-    ):
-        super().__init__()
-        self.device = device
-        self.use_flashinfer = use_flashinfer
-        self.use_per_token_if_dynamic = use_per_token_if_dynamic
-        if self.use_flashinfer and GroupedGemmRunner.flashinfer_gemm_warpper is None:
-            GroupedGemmRunner._init_flashinfer_wrapper(device)
-
-    @classmethod
-    def _init_flashinfer_wrapper(cls, device):
-        from flashinfer import SegmentGEMMWrapper
-
-        workspace_buffer = torch.empty(
-            128 * 1024 * 1024, dtype=torch.int8, device=device
-        )
-        cls.flashinfer_gemm_warpper = SegmentGEMMWrapper(workspace_buffer)
-
-    # c = a * b
-    def forward(
-        self,
-        a: torch.Tensor,
-        b: torch.Tensor,
-        c: torch.Tensor,
-        batch_size: int,
-        weight_column_major: bool,
-        seg_indptr: Optional[torch.Tensor] = None,
-        weight_indices: Optional[torch.Tensor] = None,
-        use_fp8_w8a8: bool = False,
-        scale_a: torch.Tensor = None,
-        scale_b: torch.Tensor = None,
-        block_shape: Optional[List[int]] = None,
-        c_dtype=None,
-    ):
-        if self.use_flashinfer:
-            # TODO: flashinfer
-            assert False
-            assert GroupedGemmRunner.flashinfer_gemm_warpper is not None
-            c = GroupedGemmRunner.flashinfer_gemm_warpper.run(
-                x=a,
-                weights=b,
-                batch_size=batch_size,
-                weight_column_major=weight_column_major,
-                seg_indptr=seg_indptr,
-                weight_indices=weight_indices,
-            )
-        else:
-            assert weight_column_major == True
-            c = grouped_gemm_triton(
-                a,
-                b,
-                c,
-                batch_size,
-                weight_column_major,
-                seg_indptr,
-                weight_indices,
-                use_fp8_w8a8,
-                scale_a,
-                scale_b,
-                block_shape=block_shape,
-                c_dtype=c_dtype,
-                use_per_token_if_dynamic=self.use_per_token_if_dynamic,
-            )
-        return c
-
-
 def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
     # Guess tokens per expert assuming perfect expert distribution first.
     num_tokens_per_expert = (num_tokens * top_k) // num_experts
@@ -174,140 +110,57 @@ class EPMoE(FusedMoE):
         hidden_size: int,
         intermediate_size: int,
         layer_id: int,
+        num_fused_shared_experts: int = 0,
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
         prefix: str = "",
         activation: str = "silu",
         routed_scaling_factor: Optional[float] = None,
-        use_per_token_if_dynamic: bool = True,
     ):
         super().__init__(
             num_experts=num_experts,
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
-            top_k=top_k,
+            num_fused_shared_experts=num_fused_shared_experts,
             layer_id=layer_id,
+            top_k=top_k,
            params_dtype=params_dtype,
             quant_config=quant_config,
             tp_size=tp_size,
             prefix=prefix,
             activation=activation,
+            # apply_router_weight_on_input=apply_router_weight_on_input,
             routed_scaling_factor=routed_scaling_factor,
             enable_ep_moe=True,
-            skip_quant=True,
         )
 
-        if params_dtype is None:
-            params_dtype = torch.get_default_dtype()
-
-        self.layer_id = layer_id
-        self.num_local_experts, self.expert_map = self.determine_expert_map()
-        self.start_expert_id = self.ep_rank * self.num_local_experts
+        self.start_expert_id = self.moe_ep_rank * self.num_local_experts
         self.end_expert_id = self.start_expert_id + self.num_local_experts - 1
 
         self.intermediate_size = intermediate_size
-        self.use_per_token_if_dynamic = use_per_token_if_dynamic
 
-        # TODO(ch-wan): move quant preparation to FusedMoE
-        if quant_config is None:
-            self.quant_method: Optional[QuantizeMethodBase] = (
-                UnquantizedFusedMoEMethod()
-            )
-            self.use_fp8_w8a8 = False
-            self.use_block_quant = False
-            self.block_shape = None
-            self.activation_scheme = None
-            self.w13_input_scale = None
-            self.w2_input_scale = None
-            self.w13_weight_scale = None
-            self.w2_weight_scale = None
-        elif isinstance(quant_config, W4AFp8Config):
-            self.quant_method: Optional[QuantizeMethodBase] = W4AFp8MoEMethod(
-                quant_config
-            )
-            self.use_fp8_w8a8 = False
-            self.use_block_quant = False
-            self.fp8_dtype = torch.float8_e4m3fn
-            self.w13_input_scale = None
-            self.w2_input_scale = None
-            self.w13_weight_scale = None
-            self.w2_weight_scale = None
-            self.activation_scheme = quant_config.moe_activation_scheme
-        elif isinstance(quant_config, Fp8Config):
-            self.quant_method: Optional[QuantizeMethodBase] = Fp8MoEMethod(quant_config)
-            self.use_fp8_w8a8 = True
+        if isinstance(quant_config, Fp8Config):
             self.use_block_quant = getattr(self.quant_method, "block_quant", False)
             self.block_shape = (
                 self.quant_method.quant_config.weight_block_size
                 if self.use_block_quant
                 else None
            )
+            self.use_fp8_w8a8 = True
             self.fp8_dtype = torch.float8_e4m3fn
             self.activation_scheme = quant_config.activation_scheme
         else:
-            raise ValueError(f"Unsupported quant_config: {quant_config}")
-
-        self.quant_config = quant_config
-        self.quant_method.create_weights(
-            layer=self,
-            num_experts=self.num_local_experts,
-            hidden_size=hidden_size,
-            intermediate_size=self.intermediate_size,
-            params_dtype=params_dtype,
-            weight_loader=self.weight_loader,
-        )
-
-        self.grouped_gemm_runner = None
-
-    # Adapted from https://github.com/vllm-project/vllm/blob/9fb52e523abf7bdaf7e60cf2971edb5a1b13dc08/vllm/model_executor/layers/fused_moe/layer.py#L544C1-L586C43
-    # Modifications: use determine_expert_map as a class internal function, set 'global_num_experts' rather than '-1' for experts not assigned to the current rank.
-    def determine_expert_map(self) -> Tuple[int, Optional[torch.Tensor]]:
-        """
-        Calculates how many experts should be assigned to each rank for EP and
-        creates a mapping from global to local expert index. Experts are
-        distributed evenly across ranks. Any remaining are assigned to the
-        last rank.
-
-        Returns:
-            Tuple[int, Optional[torch.Tensor]]: A tuple containing:
-                - local_num_experts (int): The number of experts assigned
-                    to the current rank.
-                - expert_map (Optional[torch.Tensor]): A tensor of shape
-                    (global_num_experts,) mapping from global to local index.
-                    Contains global_num_experts for experts not assigned to the current rank.
-                    Returns None if ep_size is 1.
-        """
-        ep_size = self.ep_size
-        ep_rank = self.ep_rank
-        global_num_experts = self.num_experts
-
-        assert ep_size > 0
-        if ep_size == 1:
-            return (global_num_experts, None)
-
-        local_num_experts = global_num_experts // ep_size
-
-        expert_map = torch.full(
-            (global_num_experts,), global_num_experts, dtype=torch.int32
-        )
-        if ep_rank < (ep_size - 1):
-            expert_map[
-                ep_rank * local_num_experts : (ep_rank + 1) * local_num_experts
-            ] = torch.arange(0, local_num_experts, dtype=torch.int32)
-        else:
-            local_num_experts = global_num_experts - ep_rank * local_num_experts
-
-            expert_map[-local_num_experts:] = torch.arange(
-                0, local_num_experts, dtype=torch.int32
-            )
-        return (local_num_experts, expert_map)
+            self.use_fp8_w8a8 = False
+            self.use_block_quant = False
+            self.block_shape = None
+            self.activation_scheme = None
 
     def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
         if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
             return self.forward_deepgemm(hidden_states, topk_output)
         else:
-            return self.forward_normal(hidden_states, topk_output)
+            return super().forward(hidden_states, topk_output)
 
     def forward_deepgemm(
         self,
@@ -466,294 +319,6 @@ class EPMoE(FusedMoE):
         )
         return output
 
-    def forward_normal(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
-        return self.quant_method.apply(self, hidden_states, topk_output)
-
-    def run_moe(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
-
-        topk_weights, topk_ids, _ = topk_output
-
-        hidden_states_shape = hidden_states.shape
-        hidden_states_dtype = hidden_states.dtype
-        hidden_states_device = hidden_states.device
-        if self.grouped_gemm_runner is None:
-            self.grouped_gemm_runner = GroupedGemmRunner(
-                hidden_states.device,
-                use_flashinfer=False,  # TODO: use flashinfer
-                use_per_token_if_dynamic=self.use_per_token_if_dynamic,
-            )
-
-        num_experts = self.num_experts
-
-        reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
-            topk_ids,
-            num_experts,
-        )
-
-        gateup_input = torch.empty(
-            (int(hidden_states.shape[0] * self.top_k), hidden_states.shape[1]),
-            device=hidden_states.device,
-            dtype=(
-                self.fp8_dtype
-                if self.use_fp8_w8a8 and not self.use_block_quant
-                else hidden_states.dtype
-            ),
-        )
-        if self.activation_scheme == "dynamic" and not self.use_block_quant:
-            if self.use_per_token_if_dynamic:
-                max_value = torch.max(hidden_states, dim=1).values.to(torch.float32)
-                self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
-            else:
-                max_value = (
-                    torch.max(hidden_states)
-                    .repeat(self.num_local_experts)
-                    .to(torch.float32)
-                )
-                self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
-
-        # PreReorder
-        pre_reorder_triton_kernel[(hidden_states.shape[0],)](
-            hidden_states,
-            gateup_input,
-            src2dst,
-            topk_ids,
-            self.w13_input_scale,
-            self.start_expert_id,
-            self.end_expert_id,
-            self.top_k,
-            hidden_states.shape[1],
-            BLOCK_SIZE=512,
-            use_per_token_if_dynamic=self.use_per_token_if_dynamic,
-        )
-        dispose_tensor(hidden_states)
-
-        if (
-            self.activation_scheme == "dynamic"
-            and not self.use_block_quant
-            and self.use_per_token_if_dynamic
-        ):
-            scale = torch.empty(
-                hidden_states_shape[0] * self.top_k,
-                device=hidden_states_device,
-                dtype=torch.float32,
-            )
-            scale[src2dst] = (
-                self.w13_input_scale.unsqueeze(1)
-                .expand(hidden_states_shape[0], self.top_k)
-                .reshape(-1)
-            )
-            self.w13_input_scale = scale
-
-        seg_indptr_cur_rank = seg_indptr[self.start_expert_id : self.end_expert_id + 2]
-        weight_indices_cur_rank = torch.arange(
-            0,
-            self.num_local_experts,
-            device=hidden_states_device,
-            dtype=torch.int64,
-        )
-        # GroupGemm-0
-        gateup_output = self.grouped_gemm_runner(
-            a=gateup_input,
-            b=self.w13_weight,
-            c=None,
-            c_dtype=hidden_states_dtype,
-            batch_size=self.num_local_experts,
-            weight_column_major=True,
-            seg_indptr=seg_indptr_cur_rank,
-            weight_indices=weight_indices_cur_rank,
-            use_fp8_w8a8=self.use_fp8_w8a8,
-            scale_a=self.w13_input_scale,
-            scale_b=self.w13_weight_scale,
-            block_shape=self.block_shape,
-        )
-        del gateup_input
-
-        # Act
-        if self.activation_scheme == "dynamic" and not self.use_block_quant:
-            self.w2_input_scale = None
-            down_input = torch.empty(
-                gateup_output.shape[0],
-                gateup_output.shape[1] // 2,
-                device=gateup_output.device,
-                dtype=hidden_states_dtype,
-            )
-        else:
-            down_input = torch.empty(
-                gateup_output.shape[0],
-                gateup_output.shape[1] // 2,
-                device=gateup_output.device,
-                dtype=(
-                    self.fp8_dtype
-                    if (self.use_fp8_w8a8 and not self.use_block_quant)
-                    else hidden_states_dtype
-                ),
-            )
-
-        if self.activation == "silu":
-            silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
-                gateup_output,
-                down_input,
-                gateup_output.shape[1],
-                reorder_topk_ids,
-                self.w2_input_scale,
-                self.start_expert_id,
-                self.end_expert_id,
-                BLOCK_SIZE=512,
-            )
-        elif self.activation == "gelu":
-            gelu_and_mul_triton_kernel[(gateup_output.shape[0],)](
-                gateup_output,
-                down_input,
-                gateup_output.shape[1],
-                reorder_topk_ids,
-                self.w2_input_scale,
-                self.start_expert_id,
-                self.end_expert_id,
-                BLOCK_SIZE=512,
-            )
-        else:
-            raise ValueError(f"Unsupported activation: {self.activation=}")
-        del gateup_output
-
-        if self.activation_scheme == "dynamic" and not self.use_block_quant:
-            if self.use_per_token_if_dynamic:
-                down_input, self.w2_input_scale = sglang_per_token_quant_fp8(down_input)
-            else:
-                self.w2_input_scale = torch.ones(
-                    self.num_local_experts,
-                    dtype=torch.float32,
-                    device=hidden_states_device,
-                )
-
-        # GroupGemm-1
-        down_output = torch.empty(
-            down_input.shape[0],
-            self.w2_weight.shape[1],
-            device=hidden_states_device,
-            dtype=hidden_states_dtype,
-        )
-        down_output = self.grouped_gemm_runner(
-            a=down_input,
-            b=self.w2_weight,
-            c=down_output,
-            batch_size=self.num_local_experts,
-            weight_column_major=True,
-            seg_indptr=seg_indptr_cur_rank,
-            weight_indices=weight_indices_cur_rank,
-            use_fp8_w8a8=self.use_fp8_w8a8,
-            scale_a=self.w2_input_scale,
-            scale_b=self.w2_weight_scale,
-            block_shape=self.block_shape,
-        )
-        del down_input
-
-        # PostReorder
-        output = torch.empty(
-            hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
-        )
-        post_reorder_triton_kernel[(hidden_states_shape[0],)](
-            down_output,
-            output,
-            src2dst,
-            topk_ids,
-            topk_weights,
-            self.start_expert_id,
-            self.end_expert_id,
-            self.top_k,
-            hidden_states_shape[1],
-            0,
-            BLOCK_SIZE=512,
-        )
-        return output
-
-    @classmethod
-    def make_expert_params_mapping(
-        cls,
-        ckpt_gate_proj_name: str,
-        ckpt_down_proj_name: str,
-        ckpt_up_proj_name: str,
-        num_experts: int,
-    ) -> List[Tuple[str, str, int, str]]:
-        return [
-            # (param_name, weight_name, expert_id, shard_id)
-            (
-                (
-                    "experts.w13_"
-                    if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name]
-                    else "experts.w2_"
-                ),
-                f"experts.{expert_id}.{weight_name}.",
-                expert_id,
-                shard_id,
-            )
-            for expert_id in range(num_experts)
-            for shard_id, weight_name in [
-                ("w1", ckpt_gate_proj_name),
-                ("w2", ckpt_down_proj_name),
-                ("w3", ckpt_up_proj_name),
-            ]
-        ]
-
-    @classmethod
-    def make_expert_input_scale_params_mapping(
-        cls,
-        num_experts: int,
-    ) -> List[Tuple[str, str, int, str]]:
-        # (param_name, weight_name, expert_id, shard_id)
-        return [
-            (
-                "experts.w13_" if shard_id in ["w1", "w3"] else "experts.w2_",
-                f"experts.{expert_id}.{shard_id}.",
-                expert_id,
-                shard_id,
-            )
-            for expert_id in range(num_experts)
-            for shard_id in ["w1", "w2", "w3"]
-        ]
-
-    def weight_loader(
-        self,
-        param: torch.nn.Parameter,
-        loaded_weight: torch.Tensor,
-        weight_name: str,
-        shard_id: str,
-        expert_id: int,
-    ) -> None:
-        physical_expert_ids = (
-            get_global_expert_location_metadata().logical_to_all_physical(
-                self.layer_id, expert_id
-            )
-        )
-        for physical_expert_id in physical_expert_ids:
-            self._weight_loader_physical(
-                param=param,
-                loaded_weight=loaded_weight,
-                weight_name=weight_name,
-                shard_id=shard_id,
-                expert_id=physical_expert_id,
-            )
-
-    def _weight_loader_physical(
-        self,
-        param: torch.nn.Parameter,
-        loaded_weight: torch.Tensor,
-        weight_name: str,
-        shard_id: str,
-        expert_id: int,
-    ) -> None:
-        if expert_id < self.start_expert_id or expert_id > self.end_expert_id:
-            return
-        expert_id = expert_id - self.start_expert_id
-
-        self._weight_loader_impl(
-            param=param,
-            loaded_weight=loaded_weight,
-            weight_name=weight_name,
-            shard_id=shard_id,
-            expert_id=expert_id,
-        )
-        return
-
 
 
 class DeepEPMoE(EPMoE):
     """
@@ -769,6 +334,7 @@ class DeepEPMoE(EPMoE):
         hidden_size: int,
         intermediate_size: int,
         layer_id: int,
+        num_fused_shared_experts: int = 0,
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
@@ -783,6 +349,7 @@ class DeepEPMoE(EPMoE):
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
             layer_id=layer_id,
+            num_fused_shared_experts=num_fused_shared_experts,
             params_dtype=params_dtype,
             quant_config=quant_config,
             tp_size=tp_size,
@@ -791,11 +358,24 @@ class DeepEPMoE(EPMoE):
             routed_scaling_factor=routed_scaling_factor,
         )
         self.deepep_mode = deepep_mode
-        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
-            assert self.use_fp8_w8a8, (
-                "DeepGEMM requires an fp8_w8a8 model; "
-                "alternatively, you can disable DeepGEMM by turning off the ENABLE_JIT_DEEPGEMM environment variable."
-            )
+
+        # TODO: move to the beginning of the file
+        from sglang.srt.distributed.parallel_state import get_tp_group
+        from sglang.srt.managers.schedule_batch import global_server_args_dict
+        from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
+
+        self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
+            group=get_tp_group().device_group,
+            router_topk=self.top_k,
+            permute_fusion=True,
+            num_experts=self.num_experts,
+            num_local_experts=self.num_local_experts,
+            hidden_size=hidden_size,
+            params_dtype=params_dtype,
+            deepep_mode=deepep_mode,
+            async_finish=True,  # TODO
+            return_recv_hook=True,
+        )
 
         if self.deepep_mode.enable_low_latency():
             assert (
@@ -837,156 +417,72 @@ class DeepEPMoE(EPMoE):
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-        reorder_topk_ids: torch.Tensor,
-        seg_indptr: torch.Tensor,
-        masked_m: torch.Tensor,
-        expected_m: int,
-        num_recv_tokens_per_expert: List[int],
         forward_batch: ForwardBatch,
     ):
-        if _use_aiter:
-            # in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
-            return self.forward_aiter(hidden_states, topk_idx, topk_weights)
-        resolved_deepep_mode = self.deepep_mode.resolve(
-            forward_batch.is_extend_in_batch
+        dispatch_output = self.dispatch(
+            hidden_states, topk_idx, topk_weights, forward_batch
         )
-        if resolved_deepep_mode == DeepEPMode.normal:
-            if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
-                return self.forward_deepgemm_contiguous(
-                    hidden_states, topk_idx, topk_weights, num_recv_tokens_per_expert
-                )
-            else:
-                return self.forward_normal(hidden_states, reorder_topk_ids, seg_indptr)
-        elif resolved_deepep_mode == DeepEPMode.low_latency:
-            return self.forward_deepgemm_masked(hidden_states, masked_m, expected_m)
-        else:
-            raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
+        hidden_states = self.moe_impl(dispatch_output)
+        hidden_states = self.combine(
+            hidden_states,
+            dispatch_output.topk_idx,
+            dispatch_output.topk_weights,
+            forward_batch,
+        )
+        return hidden_states
 
-    def forward_normal(
+    def dispatch(
         self,
         hidden_states: torch.Tensor,
-        reorder_topk_ids: torch.Tensor,
-        seg_indptr: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+        forward_batch: ForwardBatch,
     ):
-        hidden_states_dtype = hidden_states.dtype
-        hidden_states_device = hidden_states.device
-
-        assert self.quant_method is not None
-        assert self.activation == "silu"
-        if self.grouped_gemm_runner is None:
-            self.grouped_gemm_runner = GroupedGemmRunner(
-                hidden_states.device, use_flashinfer=False  # TODO: use flashinfer
-            )
-
-        if self.activation_scheme == "dynamic" and not self.use_block_quant:
-            max_value = (
-                torch.max(hidden_states)
-                .repeat(self.num_local_experts)
-                .to(torch.float32)
-            )
-            self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
-        weight_indices_cur_rank = torch.arange(
-            0,
-            self.num_local_experts,
-            device=hidden_states.device,
-            dtype=torch.int64,
+        return self.deepep_dispatcher.dispatch(
+            hidden_states=hidden_states,
+            topk_idx=topk_idx,
+            topk_weights=topk_weights,
+            forward_batch=forward_batch,
         )
 
-        # GroupGemm-0
-        if hidden_states.shape[0] > 0:
-            gateup_output = self.grouped_gemm_runner(
-                a=hidden_states,
-                b=self.w13_weight,
-                c=None,
-                c_dtype=hidden_states.dtype,
-                batch_size=self.num_local_experts,
-                weight_column_major=True,
-                seg_indptr=seg_indptr,
-                weight_indices=weight_indices_cur_rank,
-                use_fp8_w8a8=self.use_fp8_w8a8,
-                scale_a=self.w13_input_scale,
-                scale_b=(
-                    self.w13_weight_scale_inv
-                    if self.use_block_quant
-                    else self.w13_weight_scale
-                ),
-                block_shape=self.block_shape,
-            )
-        else:
-            gateup_output = torch.empty(
-                hidden_states.shape[0],
-                self.w13_weight.shape[1],
-                device=hidden_states.device,
-                dtype=hidden_states.dtype,
-            )
-
-        # Act
-        down_input = torch.empty(
-            gateup_output.shape[0],
-            gateup_output.shape[1] // 2,
-            device=gateup_output.device,
-            dtype=(
-                self.fp8_dtype
-                if (self.use_fp8_w8a8 and not self.use_block_quant)
-                else hidden_states_dtype
-            ),
-        )
-        if self.w2_input_scale is None and not self.use_block_quant:
-            self.w2_input_scale = torch.ones(
-                self.num_local_experts,
-                dtype=torch.float32,
-                device=hidden_states_device,
-            )
-
-        if self.activation == "silu":
-            silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
-                gateup_output,
-                down_input,
-                gateup_output.shape[1],
-                reorder_topk_ids,
-                self.w2_input_scale,
-                0,
-                self.num_local_experts - 1,
-                BLOCK_SIZE=512,
-            )
+    def moe_impl(self, dispatch_output: DispatchOutput):
+        if _use_aiter:
+            # in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
+            return self.forward_aiter(dispatch_output)
+        if dispatch_output.format.is_deepep_normal():
+            assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
+            return self.forward_deepgemm_contiguous(dispatch_output)
+        elif dispatch_output.format.is_deepep_ll():
+            assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
            return self.forward_deepgemm_masked(dispatch_output)
         else:
-            raise ValueError(f"Unsupported activation: {self.activation=}")
-
-        del gateup_output
-
-        # GroupGemm-1
-        down_output = torch.empty(
-            down_input.shape[0],
-            self.w2_weight.shape[1],
-            device=hidden_states_device,
-            dtype=hidden_states_dtype,
-        )
-        if down_input.shape[0] > 0:
-            down_output = self.grouped_gemm_runner(
-                a=down_input,
-                b=self.w2_weight,
-                c=down_output,
-                batch_size=self.num_local_experts,
-                weight_column_major=True,
-                seg_indptr=seg_indptr,
-                weight_indices=weight_indices_cur_rank,
-                use_fp8_w8a8=self.use_fp8_w8a8,
-                scale_a=self.w2_input_scale,
-                scale_b=(
-                    self.w2_weight_scale_inv
-                    if self.use_block_quant
-                    else self.w2_weight_scale
-                ),
-                block_shape=self.block_shape,
+            raise ValueError(
                f"Dispatch output format {dispatch_output.format} is not supported"
             )
-        return down_output
 
-    def forward_aiter(
+    def combine(
         self,
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
+        forward_batch: ForwardBatch,
     ):
+        return self.deepep_dispatcher.combine(
+            hidden_states=hidden_states,
+            topk_idx=topk_idx,
+            topk_weights=topk_weights,
+            forward_batch=forward_batch,
+        )
+
+    def forward_aiter(
+        self,
+        dispatch_output: DeepEPNormalOutput,
+    ):
+        hidden_states, topk_idx, topk_weights = (
+            dispatch_output.hidden_states,
+            dispatch_output.topk_idx,
+            dispatch_output.topk_weights,
+        )
         if hidden_states.shape[0] == 0:
             return hidden_states
         # in original deepep, idx == -1 meaning invalid and will not be processed.
@@ -1014,11 +510,11 @@ class DeepEPMoE(EPMoE):
 
     def forward_deepgemm_contiguous(
         self,
-        hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
-        topk_idx,
-        topk_weights,
-        num_recv_tokens_per_expert: List[int],
+        dispatch_output: DeepEPNormalOutput,
     ):
+        hidden_states_fp8, topk_idx, topk_weights, num_recv_tokens_per_expert = (
+            dispatch_output
+        )
         hidden_states_fp8, hidden_states_scale = hidden_states_fp8
         assert self.quant_method is not None
         assert self.activation == "silu"
@@ -1138,10 +634,9 @@ class DeepEPMoE(EPMoE):
 
     def forward_deepgemm_masked(
         self,
-        hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
-        masked_m: torch.Tensor,
-        expected_m: int,
+        dispatch_output: DeepEPLLOutput,
     ):
+        hidden_states_fp8, _, _, masked_m, expected_m = dispatch_output
         assert self.quant_method is not None
         assert self.activation == "silu"
 
@@ -1268,7 +763,7 @@ class FlashInferEPMoE(EPMoE):
             topk_group=self.topk_group,
             intermediate_size=self.w2_weight.shape[2],
             local_expert_offset=self.start_expert_id,
-            local_num_experts=self.num_experts_per_partition,
+            local_num_experts=self.num_local_experts,
             routed_scaling_factor=self.routed_scaling_factor,
             tile_tokens_dim=_get_tile_tokens_dim(
                 hidden_states.shape[0], self.top_k, self.num_experts