sglang 0.4.9.post6__py3-none-any.whl → 0.4.10__py3-none-any.whl

This diff compares the contents of two publicly released package versions as they appear in their public registries. It is provided for informational purposes only.
Files changed (69)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/srt/configs/__init__.py +8 -0
  3. sglang/srt/configs/model_config.py +3 -0
  4. sglang/srt/configs/step3_vl.py +172 -0
  5. sglang/srt/conversation.py +23 -0
  6. sglang/srt/disaggregation/decode.py +2 -8
  7. sglang/srt/disaggregation/prefill.py +2 -6
  8. sglang/srt/distributed/parallel_state.py +86 -1
  9. sglang/srt/entrypoints/engine.py +14 -18
  10. sglang/srt/entrypoints/http_server.py +10 -2
  11. sglang/srt/entrypoints/openai/serving_chat.py +2 -21
  12. sglang/srt/eplb/expert_distribution.py +5 -0
  13. sglang/srt/eplb/expert_location.py +17 -6
  14. sglang/srt/eplb/expert_location_dispatch.py +1 -0
  15. sglang/srt/eplb/expert_location_updater.py +2 -0
  16. sglang/srt/function_call/function_call_parser.py +2 -0
  17. sglang/srt/function_call/step3_detector.py +436 -0
  18. sglang/srt/hf_transformers_utils.py +2 -0
  19. sglang/srt/jinja_template_utils.py +4 -1
  20. sglang/srt/layers/moe/cutlass_moe.py +2 -1
  21. sglang/srt/layers/moe/ep_moe/layer.py +20 -640
  22. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +26 -13
  23. sglang/srt/layers/moe/fused_moe_triton/layer.py +97 -38
  24. sglang/srt/layers/quantization/fp8.py +0 -18
  25. sglang/srt/layers/quantization/unquant.py +0 -8
  26. sglang/srt/layers/quantization/w4afp8.py +1 -0
  27. sglang/srt/managers/cache_controller.py +143 -45
  28. sglang/srt/managers/data_parallel_controller.py +2 -0
  29. sglang/srt/managers/io_struct.py +0 -2
  30. sglang/srt/managers/scheduler.py +89 -671
  31. sglang/srt/managers/scheduler_metrics_mixin.py +229 -0
  32. sglang/srt/managers/scheduler_profiler_mixin.py +279 -0
  33. sglang/srt/managers/scheduler_update_weights_mixin.py +142 -0
  34. sglang/srt/managers/template_manager.py +62 -19
  35. sglang/srt/managers/tokenizer_manager.py +123 -74
  36. sglang/srt/managers/tp_worker.py +4 -0
  37. sglang/srt/managers/tp_worker_overlap_thread.py +2 -1
  38. sglang/srt/mem_cache/hicache_storage.py +45 -11
  39. sglang/srt/mem_cache/hiradix_cache.py +15 -4
  40. sglang/srt/mem_cache/memory_pool_host.py +73 -1
  41. sglang/srt/mem_cache/mooncake_store/mooncake_store.py +264 -0
  42. sglang/srt/mem_cache/mooncake_store/unit_test.py +40 -0
  43. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +177 -0
  44. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +278 -0
  45. sglang/srt/mem_cache/storage/hf3fs/test_hf3fs_utils.py +43 -0
  46. sglang/srt/model_executor/model_runner.py +5 -0
  47. sglang/srt/models/arcee.py +532 -0
  48. sglang/srt/models/deepseek_v2.py +2 -0
  49. sglang/srt/models/glm4_moe.py +3 -1
  50. sglang/srt/models/granitemoe.py +3 -0
  51. sglang/srt/models/grok.py +3 -0
  52. sglang/srt/models/hunyuan.py +1 -0
  53. sglang/srt/models/llama4.py +3 -0
  54. sglang/srt/models/mixtral.py +3 -0
  55. sglang/srt/models/olmoe.py +3 -0
  56. sglang/srt/models/phimoe.py +1 -0
  57. sglang/srt/models/step3_vl.py +994 -0
  58. sglang/srt/multimodal/processors/base_processor.py +15 -16
  59. sglang/srt/multimodal/processors/step3_vl.py +515 -0
  60. sglang/srt/reasoning_parser.py +2 -1
  61. sglang/srt/server_args.py +10 -13
  62. sglang/srt/speculative/eagle_worker.py +2 -0
  63. sglang/utils.py +0 -11
  64. sglang/version.py +1 -1
  65. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.dist-info}/METADATA +3 -4
  66. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.dist-info}/RECORD +69 -56
  67. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.dist-info}/WHEEL +0 -0
  68. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.dist-info}/licenses/LICENSE +0 -0
  69. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.dist-info}/top_level.txt +0 -0
@@ -86,79 +86,6 @@ if use_flashinfer_trtllm_moe:
 logger = logging.getLogger(__name__)


-class GroupedGemmRunner(torch.nn.Module):
-    flashinfer_gemm_warpper = None
-
-    def __init__(
-        self,
-        device,
-        use_flashinfer: bool = False,
-        use_per_token_if_dynamic: bool = True,
-    ):
-        super().__init__()
-        self.device = device
-        self.use_flashinfer = use_flashinfer
-        self.use_per_token_if_dynamic = use_per_token_if_dynamic
-        if self.use_flashinfer and GroupedGemmRunner.flashinfer_gemm_warpper is None:
-            GroupedGemmRunner._init_flashinfer_wrapper(device)
-
-    @classmethod
-    def _init_flashinfer_wrapper(cls, device):
-        from flashinfer import SegmentGEMMWrapper
-
-        workspace_buffer = torch.empty(
-            128 * 1024 * 1024, dtype=torch.int8, device=device
-        )
-        cls.flashinfer_gemm_warpper = SegmentGEMMWrapper(workspace_buffer)
-
-    # c = a * b
-    def forward(
-        self,
-        a: torch.Tensor,
-        b: torch.Tensor,
-        c: torch.Tensor,
-        batch_size: int,
-        weight_column_major: bool,
-        seg_indptr: Optional[torch.Tensor] = None,
-        weight_indices: Optional[torch.Tensor] = None,
-        use_fp8_w8a8: bool = False,
-        scale_a: torch.Tensor = None,
-        scale_b: torch.Tensor = None,
-        block_shape: Optional[List[int]] = None,
-        c_dtype=None,
-    ):
-        if self.use_flashinfer:
-            # TODO: flashinfer
-            assert False
-            assert GroupedGemmRunner.flashinfer_gemm_warpper is not None
-            c = GroupedGemmRunner.flashinfer_gemm_warpper.run(
-                x=a,
-                weights=b,
-                batch_size=batch_size,
-                weight_column_major=weight_column_major,
-                seg_indptr=seg_indptr,
-                weight_indices=weight_indices,
-            )
-        else:
-            assert weight_column_major == True
-            c = grouped_gemm_triton(
-                a,
-                b,
-                c,
-                batch_size,
-                weight_column_major,
-                seg_indptr,
-                weight_indices,
-                use_fp8_w8a8,
-                scale_a,
-                scale_b,
-                block_shape=block_shape,
-                c_dtype=c_dtype,
-                use_per_token_if_dynamic=self.use_per_token_if_dynamic,
-            )
-        return c
-
-
 def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
     # Guess tokens per expert assuming perfect expert distribution first.
     num_tokens_per_expert = (num_tokens * top_k) // num_experts
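
The removed GroupedGemmRunner (and the grouped_gemm_triton kernel it wrapped) ran one GEMM per expert over contiguous token segments. As a point of reference for that contract only, here is a minimal pure-PyTorch sketch; the function and variable names are illustrative rather than sglang APIs, and it ignores the fp8 scales, block quantization, and flashinfer path the real kernel handled.

import torch

def grouped_gemm_reference(
    a: torch.Tensor,          # (total_tokens, K): rows pre-sorted by expert
    b: torch.Tensor,          # (num_experts, N, K): one weight matrix per expert
    seg_indptr: torch.Tensor, # (num_experts + 1,): row range owned by each expert
) -> torch.Tensor:
    """Plain-PyTorch stand-in for the grouped GEMM contract (illustration only)."""
    total_tokens, _ = a.shape
    num_experts, n, _ = b.shape
    c = a.new_zeros((total_tokens, n))
    for e in range(num_experts):
        start, end = int(seg_indptr[e]), int(seg_indptr[e + 1])
        if start == end:  # this expert received no tokens
            continue
        # weight_column_major layout: multiply by the transposed expert weight
        c[start:end] = a[start:end] @ b[e].T
    return c

if __name__ == "__main__":
    a = torch.randn(6, 8)
    b = torch.randn(3, 4, 8)
    seg_indptr = torch.tensor([0, 2, 2, 6])  # expert 0: rows 0-1, expert 1: none, expert 2: rows 2-5
    print(grouped_gemm_reference(a, b, seg_indptr).shape)  # torch.Size([6, 4])
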
@@ -183,140 +110,57 @@ class EPMoE(FusedMoE):
         hidden_size: int,
         intermediate_size: int,
         layer_id: int,
+        num_fused_shared_experts: int = 0,
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
         prefix: str = "",
         activation: str = "silu",
         routed_scaling_factor: Optional[float] = None,
-        use_per_token_if_dynamic: bool = True,
     ):
         super().__init__(
             num_experts=num_experts,
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
-            top_k=top_k,
+            num_fused_shared_experts=num_fused_shared_experts,
             layer_id=layer_id,
+            top_k=top_k,
             params_dtype=params_dtype,
             quant_config=quant_config,
             tp_size=tp_size,
             prefix=prefix,
             activation=activation,
+            # apply_router_weight_on_input=apply_router_weight_on_input,
             routed_scaling_factor=routed_scaling_factor,
             enable_ep_moe=True,
-            skip_quant=True,
         )

-        if params_dtype is None:
-            params_dtype = torch.get_default_dtype()
-
-        self.layer_id = layer_id
-        self.num_local_experts, self.expert_map = self.determine_expert_map()
-        self.start_expert_id = self.ep_rank * self.num_local_experts
+        self.start_expert_id = self.moe_ep_rank * self.num_local_experts
         self.end_expert_id = self.start_expert_id + self.num_local_experts - 1

         self.intermediate_size = intermediate_size
-        self.use_per_token_if_dynamic = use_per_token_if_dynamic

-        # TODO(ch-wan): move quant preparation to FusedMoE
-        if quant_config is None:
-            self.quant_method: Optional[QuantizeMethodBase] = (
-                UnquantizedFusedMoEMethod()
-            )
-            self.use_fp8_w8a8 = False
-            self.use_block_quant = False
-            self.block_shape = None
-            self.activation_scheme = None
-            self.w13_input_scale = None
-            self.w2_input_scale = None
-            self.w13_weight_scale = None
-            self.w2_weight_scale = None
-        elif isinstance(quant_config, W4AFp8Config):
-            self.quant_method: Optional[QuantizeMethodBase] = W4AFp8MoEMethod(
-                quant_config
-            )
-            self.use_fp8_w8a8 = False
-            self.use_block_quant = False
-            self.fp8_dtype = torch.float8_e4m3fn
-            self.w13_input_scale = None
-            self.w2_input_scale = None
-            self.w13_weight_scale = None
-            self.w2_weight_scale = None
-            self.activation_scheme = quant_config.moe_activation_scheme
-        elif isinstance(quant_config, Fp8Config):
-            self.quant_method: Optional[QuantizeMethodBase] = Fp8MoEMethod(quant_config)
-            self.use_fp8_w8a8 = True
+        if isinstance(quant_config, Fp8Config):
             self.use_block_quant = getattr(self.quant_method, "block_quant", False)
             self.block_shape = (
                 self.quant_method.quant_config.weight_block_size
                 if self.use_block_quant
                 else None
             )
+            self.use_fp8_w8a8 = True
             self.fp8_dtype = torch.float8_e4m3fn
             self.activation_scheme = quant_config.activation_scheme
         else:
-            raise ValueError(f"Unsupported quant_config: {quant_config}")
-
-        self.quant_config = quant_config
-        self.quant_method.create_weights(
-            layer=self,
-            num_experts=self.num_local_experts,
-            hidden_size=hidden_size,
-            intermediate_size=self.intermediate_size,
-            params_dtype=params_dtype,
-            weight_loader=self.weight_loader,
-        )
-
-        self.grouped_gemm_runner = None
-
-    # Adapted from https://github.com/vllm-project/vllm/blob/9fb52e523abf7bdaf7e60cf2971edb5a1b13dc08/vllm/model_executor/layers/fused_moe/layer.py#L544C1-L586C43
-    # Modifications: use determine_expert_map as a class internal function, set 'global_num_experts' rather than '-1' for experts not assigned to the current rank.
-    def determine_expert_map(self) -> Tuple[int, Optional[torch.Tensor]]:
-        """
-        Calculates how many experts should be assigned to each rank for EP and
-        creates a mapping from global to local expert index. Experts are
-        distributed evenly across ranks. Any remaining are assigned to the
-        last rank.
-
-        Returns:
-            Tuple[int, Optional[torch.Tensor]]: A tuple containing:
-                - local_num_experts (int): The number of experts assigned
-                    to the current rank.
-                - expert_map (Optional[torch.Tensor]): A tensor of shape
-                    (global_num_experts,) mapping from global to local index.
-                    Contains global_num_experts for experts not assigned to the current rank.
-                    Returns None if ep_size is 1.
-        """
-        ep_size = self.ep_size
-        ep_rank = self.ep_rank
-        global_num_experts = self.num_experts
-
-        assert ep_size > 0
-        if ep_size == 1:
-            return (global_num_experts, None)
-
-        local_num_experts = global_num_experts // ep_size
-
-        expert_map = torch.full(
-            (global_num_experts,), global_num_experts, dtype=torch.int32
-        )
-        if ep_rank < (ep_size - 1):
-            expert_map[
-                ep_rank * local_num_experts : (ep_rank + 1) * local_num_experts
-            ] = torch.arange(0, local_num_experts, dtype=torch.int32)
-        else:
-            local_num_experts = global_num_experts - ep_rank * local_num_experts
-
-            expert_map[-local_num_experts:] = torch.arange(
-                0, local_num_experts, dtype=torch.int32
-            )
-        return (local_num_experts, expert_map)
+            self.use_fp8_w8a8 = False
+            self.use_block_quant = False
+            self.block_shape = None
+            self.activation_scheme = None

     def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
         if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
             return self.forward_deepgemm(hidden_states, topk_output)
         else:
-            return self.forward_normal(hidden_states, topk_output)
+            return super().forward(hidden_states, topk_output)

     def forward_deepgemm(
         self,
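
With determine_expert_map removed, EPMoE now relies on the num_local_experts and moe_ep_rank attributes it expects from its FusedMoE parent. The removed docstring above describes the partitioning scheme: experts are split evenly across EP ranks, any remainder goes to the last rank, and experts not owned locally are marked with the sentinel value global_num_experts. Below is a standalone sketch of that scheme; the function name and example numbers are illustrative, not sglang API.

import torch
from typing import Optional, Tuple

def determine_expert_map_sketch(
    ep_size: int, ep_rank: int, global_num_experts: int
) -> Tuple[int, Optional[torch.Tensor]]:
    """Standalone restatement of the even-split-with-remainder partitioning."""
    assert ep_size > 0
    if ep_size == 1:
        return global_num_experts, None

    local_num_experts = global_num_experts // ep_size
    # Sentinel value (global_num_experts) marks experts not owned by this rank.
    expert_map = torch.full((global_num_experts,), global_num_experts, dtype=torch.int32)
    if ep_rank < ep_size - 1:
        start = ep_rank * local_num_experts
        expert_map[start : start + local_num_experts] = torch.arange(
            local_num_experts, dtype=torch.int32
        )
    else:
        # The last rank also absorbs any remainder experts.
        local_num_experts = global_num_experts - ep_rank * local_num_experts
        expert_map[-local_num_experts:] = torch.arange(local_num_experts, dtype=torch.int32)
    return local_num_experts, expert_map

# 10 experts over 4 ranks: ranks 0-2 own 2 experts each, rank 3 owns the remaining 4.
n_local, emap = determine_expert_map_sketch(ep_size=4, ep_rank=3, global_num_experts=10)
print(n_local)        # 4
print(emap.tolist())  # [10, 10, 10, 10, 10, 10, 0, 1, 2, 3]
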
@@ -475,294 +319,6 @@ class EPMoE(FusedMoE):
         )
         return output

-    def forward_normal(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
-        return self.quant_method.apply(self, hidden_states, topk_output)
-
-    def run_moe(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
-
-        topk_weights, topk_ids, _ = topk_output
-
-        hidden_states_shape = hidden_states.shape
-        hidden_states_dtype = hidden_states.dtype
-        hidden_states_device = hidden_states.device
-        if self.grouped_gemm_runner is None:
-            self.grouped_gemm_runner = GroupedGemmRunner(
-                hidden_states.device,
-                use_flashinfer=False,  # TODO: use flashinfer
-                use_per_token_if_dynamic=self.use_per_token_if_dynamic,
-            )
-
-        num_experts = self.num_experts
-
-        reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
-            topk_ids,
-            num_experts,
-        )
-
-        gateup_input = torch.empty(
-            (int(hidden_states.shape[0] * self.top_k), hidden_states.shape[1]),
-            device=hidden_states.device,
-            dtype=(
-                self.fp8_dtype
-                if self.use_fp8_w8a8 and not self.use_block_quant
-                else hidden_states.dtype
-            ),
-        )
-        if self.activation_scheme == "dynamic" and not self.use_block_quant:
-            if self.use_per_token_if_dynamic:
-                max_value = torch.max(hidden_states, dim=1).values.to(torch.float32)
-                self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
-            else:
-                max_value = (
-                    torch.max(hidden_states)
-                    .repeat(self.num_local_experts)
-                    .to(torch.float32)
-                )
-                self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
-
-        # PreReorder
-        pre_reorder_triton_kernel[(hidden_states.shape[0],)](
-            hidden_states,
-            gateup_input,
-            src2dst,
-            topk_ids,
-            self.w13_input_scale,
-            self.start_expert_id,
-            self.end_expert_id,
-            self.top_k,
-            hidden_states.shape[1],
-            BLOCK_SIZE=512,
-            use_per_token_if_dynamic=self.use_per_token_if_dynamic,
-        )
-        dispose_tensor(hidden_states)
-
-        if (
-            self.activation_scheme == "dynamic"
-            and not self.use_block_quant
-            and self.use_per_token_if_dynamic
-        ):
-            scale = torch.empty(
-                hidden_states_shape[0] * self.top_k,
-                device=hidden_states_device,
-                dtype=torch.float32,
-            )
-            scale[src2dst] = (
-                self.w13_input_scale.unsqueeze(1)
-                .expand(hidden_states_shape[0], self.top_k)
-                .reshape(-1)
-            )
-            self.w13_input_scale = scale
-
-        seg_indptr_cur_rank = seg_indptr[self.start_expert_id : self.end_expert_id + 2]
-        weight_indices_cur_rank = torch.arange(
-            0,
-            self.num_local_experts,
-            device=hidden_states_device,
-            dtype=torch.int64,
-        )
-        # GroupGemm-0
-        gateup_output = self.grouped_gemm_runner(
-            a=gateup_input,
-            b=self.w13_weight,
-            c=None,
-            c_dtype=hidden_states_dtype,
-            batch_size=self.num_local_experts,
-            weight_column_major=True,
-            seg_indptr=seg_indptr_cur_rank,
-            weight_indices=weight_indices_cur_rank,
-            use_fp8_w8a8=self.use_fp8_w8a8,
-            scale_a=self.w13_input_scale,
-            scale_b=self.w13_weight_scale,
-            block_shape=self.block_shape,
-        )
-        del gateup_input
-
-        # Act
-        if self.activation_scheme == "dynamic" and not self.use_block_quant:
-            self.w2_input_scale = None
-            down_input = torch.empty(
-                gateup_output.shape[0],
-                gateup_output.shape[1] // 2,
-                device=gateup_output.device,
-                dtype=hidden_states_dtype,
-            )
-        else:
-            down_input = torch.empty(
-                gateup_output.shape[0],
-                gateup_output.shape[1] // 2,
-                device=gateup_output.device,
-                dtype=(
-                    self.fp8_dtype
-                    if (self.use_fp8_w8a8 and not self.use_block_quant)
-                    else hidden_states_dtype
-                ),
-            )
-
-        if self.activation == "silu":
-            silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
-                gateup_output,
-                down_input,
-                gateup_output.shape[1],
-                reorder_topk_ids,
-                self.w2_input_scale,
-                self.start_expert_id,
-                self.end_expert_id,
-                BLOCK_SIZE=512,
-            )
-        elif self.activation == "gelu":
-            gelu_and_mul_triton_kernel[(gateup_output.shape[0],)](
-                gateup_output,
-                down_input,
-                gateup_output.shape[1],
-                reorder_topk_ids,
-                self.w2_input_scale,
-                self.start_expert_id,
-                self.end_expert_id,
-                BLOCK_SIZE=512,
-            )
-        else:
-            raise ValueError(f"Unsupported activation: {self.activation=}")
-        del gateup_output
-
-        if self.activation_scheme == "dynamic" and not self.use_block_quant:
-            if self.use_per_token_if_dynamic:
-                down_input, self.w2_input_scale = sglang_per_token_quant_fp8(down_input)
-            else:
-                self.w2_input_scale = torch.ones(
-                    self.num_local_experts,
-                    dtype=torch.float32,
-                    device=hidden_states_device,
-                )
-
-        # GroupGemm-1
-        down_output = torch.empty(
-            down_input.shape[0],
-            self.w2_weight.shape[1],
-            device=hidden_states_device,
-            dtype=hidden_states_dtype,
-        )
-        down_output = self.grouped_gemm_runner(
-            a=down_input,
-            b=self.w2_weight,
-            c=down_output,
-            batch_size=self.num_local_experts,
-            weight_column_major=True,
-            seg_indptr=seg_indptr_cur_rank,
-            weight_indices=weight_indices_cur_rank,
-            use_fp8_w8a8=self.use_fp8_w8a8,
-            scale_a=self.w2_input_scale,
-            scale_b=self.w2_weight_scale,
-            block_shape=self.block_shape,
-        )
-        del down_input
-
-        # PostReorder
-        output = torch.empty(
-            hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
-        )
-        post_reorder_triton_kernel[(hidden_states_shape[0],)](
-            down_output,
-            output,
-            src2dst,
-            topk_ids,
-            topk_weights,
-            self.start_expert_id,
-            self.end_expert_id,
-            self.top_k,
-            hidden_states_shape[1],
-            0,
-            BLOCK_SIZE=512,
-        )
-        return output
-
-    @classmethod
-    def make_expert_params_mapping(
-        cls,
-        ckpt_gate_proj_name: str,
-        ckpt_down_proj_name: str,
-        ckpt_up_proj_name: str,
-        num_experts: int,
-    ) -> List[Tuple[str, str, int, str]]:
-        return [
-            # (param_name, weight_name, expert_id, shard_id)
-            (
-                (
-                    "experts.w13_"
-                    if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name]
-                    else "experts.w2_"
-                ),
-                f"experts.{expert_id}.{weight_name}.",
-                expert_id,
-                shard_id,
-            )
-            for expert_id in range(num_experts)
-            for shard_id, weight_name in [
-                ("w1", ckpt_gate_proj_name),
-                ("w2", ckpt_down_proj_name),
-                ("w3", ckpt_up_proj_name),
-            ]
-        ]
-
-    @classmethod
-    def make_expert_input_scale_params_mapping(
-        cls,
-        num_experts: int,
-    ) -> List[Tuple[str, str, int, str]]:
-        # (param_name, weight_name, expert_id, shard_id)
-        return [
-            (
-                "experts.w13_" if shard_id in ["w1", "w3"] else "experts.w2_",
-                f"experts.{expert_id}.{shard_id}.",
-                expert_id,
-                shard_id,
-            )
-            for expert_id in range(num_experts)
-            for shard_id in ["w1", "w2", "w3"]
-        ]
-
-    def weight_loader(
-        self,
-        param: torch.nn.Parameter,
-        loaded_weight: torch.Tensor,
-        weight_name: str,
-        shard_id: str,
-        expert_id: int,
-    ) -> None:
-        physical_expert_ids = (
-            get_global_expert_location_metadata().logical_to_all_physical(
-                self.layer_id, expert_id
-            )
-        )
-        for physical_expert_id in physical_expert_ids:
-            self._weight_loader_physical(
-                param=param,
-                loaded_weight=loaded_weight,
-                weight_name=weight_name,
-                shard_id=shard_id,
-                expert_id=physical_expert_id,
-            )
-
-    def _weight_loader_physical(
-        self,
-        param: torch.nn.Parameter,
-        loaded_weight: torch.Tensor,
-        weight_name: str,
-        shard_id: str,
-        expert_id: int,
-    ) -> None:
-        if expert_id < self.start_expert_id or expert_id > self.end_expert_id:
-            return
-        expert_id = expert_id - self.start_expert_id
-
-        self._weight_loader_impl(
-            param=param,
-            loaded_weight=loaded_weight,
-            weight_name=weight_name,
-            shard_id=shard_id,
-            expert_id=expert_id,
-        )
-        return
-

 class DeepEPMoE(EPMoE):
     """
@@ -778,6 +334,7 @@ class DeepEPMoE(EPMoE):
         hidden_size: int,
         intermediate_size: int,
         layer_id: int,
+        num_fused_shared_experts: int = 0,
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
@@ -792,6 +349,7 @@ class DeepEPMoE(EPMoE):
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
             layer_id=layer_id,
+            num_fused_shared_experts=num_fused_shared_experts,
             params_dtype=params_dtype,
             quant_config=quant_config,
             tp_size=tp_size,
@@ -892,14 +450,15 @@ class DeepEPMoE(EPMoE):
             # in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
             return self.forward_aiter(dispatch_output)
         if dispatch_output.format.is_deepep_normal():
-            if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
-                return self.forward_deepgemm_contiguous(dispatch_output)
-            else:
-                return self.forward_normal(dispatch_output)
+            assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
+            return self.forward_deepgemm_contiguous(dispatch_output)
         elif dispatch_output.format.is_deepep_ll():
+            assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
             return self.forward_deepgemm_masked(dispatch_output)
         else:
-            raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
+            raise ValueError(
+                f"Dispatch output format {dispatch_output.format} is not supported"
+            )

     def combine(
         self,
@@ -915,185 +474,6 @@ class DeepEPMoE(EPMoE):
             forward_batch=forward_batch,
         )

-    def _prepare_for_normal(
-        self,
-        hidden_states: torch.Tensor,
-        topk_idx: torch.Tensor,
-    ):
-        from sglang.srt.layers.moe.ep_moe.kernels import (
-            deepep_permute_triton_kernel,
-            deepep_run_moe_deep_preprocess,
-        )
-
-        if hidden_states.shape[0] == 0:
-            reorder_topk_ids = torch.empty(
-                (0,), device=hidden_states.device, dtype=torch.int64
-            )
-            seg_indptr = torch.zeros(
-                (self.num_experts + 1,),
-                device=hidden_states.device,
-                dtype=torch.int64,
-            )
-            return reorder_topk_ids, seg_indptr, hidden_states
-        else:
-            if _use_aiter:
-                # skip permutation here as aiter fused_moe has fused inside
-                reorder_topk_ids = torch.empty(
-                    (0,), device=hidden_states.device, dtype=torch.int64
-                )
-                seg_indptr = torch.zeros(
-                    (self.num_experts + 1,),
-                    device=hidden_states.device,
-                    dtype=torch.int64,
-                )
-                return reorder_topk_ids, seg_indptr, hidden_states
-
-            reorder_topk_ids, self.src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
-                topk_idx, self.num_experts
-            )
-            num_total_tokens = reorder_topk_ids.numel()
-            gateup_input = torch.empty(
-                (int(num_total_tokens), hidden_states.shape[1]),
-                device=hidden_states.device,
-                dtype=hidden_states.dtype,
-            )
-            # PreReorder
-            deepep_permute_triton_kernel[(hidden_states.shape[0],)](
-                hidden_states,
-                gateup_input,
-                self.src2dst,
-                topk_idx,
-                None,
-                self.router_topk,
-                hidden_states.shape[1],
-                BLOCK_SIZE=512,
-            )
-            return reorder_topk_ids, seg_indptr, gateup_input
-
-    def forward_normal(
-        self,
-        dispatch_output: DeepEPNormalOutput,
-    ):
-        hidden_states, topk_idx = (
-            dispatch_output.hidden_states,
-            dispatch_output.topk_idx,
-        )
-        reorder_topk_ids, seg_indptr, hidden_states = self._prepare_for_normal(
-            hidden_states, topk_idx
-        )
-        hidden_states_dtype = hidden_states.dtype
-        hidden_states_device = hidden_states.device
-
-        assert self.quant_method is not None
-        assert self.activation == "silu"
-        if self.grouped_gemm_runner is None:
-            self.grouped_gemm_runner = GroupedGemmRunner(
-                hidden_states.device, use_flashinfer=False  # TODO: use flashinfer
-            )
-
-        if self.activation_scheme == "dynamic" and not self.use_block_quant:
-            max_value = (
-                torch.max(hidden_states)
-                .repeat(self.num_local_experts)
-                .to(torch.float32)
-            )
-            self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
-        weight_indices_cur_rank = torch.arange(
-            0,
-            self.num_local_experts,
-            device=hidden_states.device,
-            dtype=torch.int64,
-        )
-
-        # GroupGemm-0
-        if hidden_states.shape[0] > 0:
-            gateup_output = self.grouped_gemm_runner(
-                a=hidden_states,
-                b=self.w13_weight,
-                c=None,
-                c_dtype=hidden_states.dtype,
-                batch_size=self.num_local_experts,
-                weight_column_major=True,
-                seg_indptr=seg_indptr,
-                weight_indices=weight_indices_cur_rank,
-                use_fp8_w8a8=self.use_fp8_w8a8,
-                scale_a=self.w13_input_scale,
-                scale_b=(
-                    self.w13_weight_scale_inv
-                    if self.use_block_quant
-                    else self.w13_weight_scale
-                ),
-                block_shape=self.block_shape,
-            )
-        else:
-            gateup_output = torch.empty(
-                hidden_states.shape[0],
-                self.w13_weight.shape[1],
-                device=hidden_states.device,
-                dtype=hidden_states.dtype,
-            )
-
-        # Act
-        down_input = torch.empty(
-            gateup_output.shape[0],
-            gateup_output.shape[1] // 2,
-            device=gateup_output.device,
-            dtype=(
-                self.fp8_dtype
-                if (self.use_fp8_w8a8 and not self.use_block_quant)
-                else hidden_states_dtype
-            ),
-        )
-        if self.w2_input_scale is None and not self.use_block_quant:
-            self.w2_input_scale = torch.ones(
-                self.num_local_experts,
-                dtype=torch.float32,
-                device=hidden_states_device,
-            )
-
-        if self.activation == "silu":
-            silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
-                gateup_output,
-                down_input,
-                gateup_output.shape[1],
-                reorder_topk_ids,
-                self.w2_input_scale,
-                0,
-                self.num_local_experts - 1,
-                BLOCK_SIZE=512,
-            )
-        else:
-            raise ValueError(f"Unsupported activation: {self.activation=}")
-
-        del gateup_output
-
-        # GroupGemm-1
-        down_output = torch.empty(
-            down_input.shape[0],
-            self.w2_weight.shape[1],
-            device=hidden_states_device,
-            dtype=hidden_states_dtype,
-        )
-        if down_input.shape[0] > 0:
-            down_output = self.grouped_gemm_runner(
-                a=down_input,
-                b=self.w2_weight,
-                c=down_output,
-                batch_size=self.num_local_experts,
-                weight_column_major=True,
-                seg_indptr=seg_indptr,
-                weight_indices=weight_indices_cur_rank,
-                use_fp8_w8a8=self.use_fp8_w8a8,
-                scale_a=self.w2_input_scale,
-                scale_b=(
-                    self.w2_weight_scale_inv
-                    if self.use_block_quant
-                    else self.w2_weight_scale
-                ),
-                block_shape=self.block_shape,
-            )
-        return down_output
-
     def forward_aiter(
         self,
         dispatch_output: DeepEPNormalOutput,