sglang 0.4.9.post6__py3-none-any.whl → 0.4.10.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. sglang/bench_offline_throughput.py +20 -0
  2. sglang/bench_one_batch.py +3 -0
  3. sglang/srt/configs/__init__.py +8 -0
  4. sglang/srt/configs/model_config.py +4 -0
  5. sglang/srt/configs/step3_vl.py +172 -0
  6. sglang/srt/conversation.py +23 -0
  7. sglang/srt/disaggregation/decode.py +2 -8
  8. sglang/srt/disaggregation/launch_lb.py +5 -20
  9. sglang/srt/disaggregation/mooncake/conn.py +33 -15
  10. sglang/srt/disaggregation/prefill.py +2 -6
  11. sglang/srt/distributed/parallel_state.py +86 -1
  12. sglang/srt/entrypoints/engine.py +14 -18
  13. sglang/srt/entrypoints/http_server.py +10 -2
  14. sglang/srt/entrypoints/openai/serving_chat.py +2 -21
  15. sglang/srt/eplb/expert_distribution.py +5 -0
  16. sglang/srt/eplb/expert_location.py +17 -6
  17. sglang/srt/eplb/expert_location_dispatch.py +1 -0
  18. sglang/srt/eplb/expert_location_updater.py +2 -0
  19. sglang/srt/function_call/function_call_parser.py +2 -0
  20. sglang/srt/function_call/step3_detector.py +436 -0
  21. sglang/srt/hf_transformers_utils.py +2 -0
  22. sglang/srt/jinja_template_utils.py +4 -1
  23. sglang/srt/layers/attention/trtllm_mla_backend.py +372 -0
  24. sglang/srt/layers/attention/utils.py +6 -1
  25. sglang/srt/layers/moe/cutlass_moe.py +2 -1
  26. sglang/srt/layers/moe/ep_moe/layer.py +39 -674
  27. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +26 -13
  28. sglang/srt/layers/moe/fused_moe_triton/layer.py +152 -39
  29. sglang/srt/layers/quantization/fp8.py +52 -18
  30. sglang/srt/layers/quantization/unquant.py +0 -8
  31. sglang/srt/layers/quantization/w4afp8.py +1 -0
  32. sglang/srt/layers/quantization/w8a8_int8.py +4 -1
  33. sglang/srt/managers/cache_controller.py +165 -67
  34. sglang/srt/managers/data_parallel_controller.py +2 -0
  35. sglang/srt/managers/io_struct.py +0 -2
  36. sglang/srt/managers/scheduler.py +90 -671
  37. sglang/srt/managers/scheduler_metrics_mixin.py +229 -0
  38. sglang/srt/managers/scheduler_profiler_mixin.py +279 -0
  39. sglang/srt/managers/scheduler_update_weights_mixin.py +142 -0
  40. sglang/srt/managers/template_manager.py +62 -19
  41. sglang/srt/managers/tokenizer_manager.py +123 -74
  42. sglang/srt/managers/tp_worker.py +4 -0
  43. sglang/srt/managers/tp_worker_overlap_thread.py +2 -1
  44. sglang/srt/mem_cache/hicache_storage.py +60 -17
  45. sglang/srt/mem_cache/hiradix_cache.py +36 -8
  46. sglang/srt/mem_cache/memory_pool.py +15 -118
  47. sglang/srt/mem_cache/memory_pool_host.py +418 -29
  48. sglang/srt/mem_cache/mooncake_store/mooncake_store.py +264 -0
  49. sglang/srt/mem_cache/mooncake_store/unit_test.py +40 -0
  50. sglang/srt/mem_cache/nixl/hicache_nixl.py +163 -0
  51. sglang/srt/mem_cache/nixl/nixl_utils.py +238 -0
  52. sglang/srt/mem_cache/nixl/test_hicache_nixl_storage.py +216 -0
  53. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +183 -0
  54. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +278 -0
  55. sglang/srt/mem_cache/storage/hf3fs/test_hf3fs_utils.py +43 -0
  56. sglang/srt/model_executor/cuda_graph_runner.py +25 -1
  57. sglang/srt/model_executor/model_runner.py +13 -1
  58. sglang/srt/model_loader/weight_utils.py +2 -0
  59. sglang/srt/models/arcee.py +532 -0
  60. sglang/srt/models/deepseek_v2.py +7 -6
  61. sglang/srt/models/glm4_moe.py +6 -4
  62. sglang/srt/models/granitemoe.py +3 -0
  63. sglang/srt/models/grok.py +3 -0
  64. sglang/srt/models/hunyuan.py +1 -0
  65. sglang/srt/models/llama4.py +3 -0
  66. sglang/srt/models/mixtral.py +3 -0
  67. sglang/srt/models/olmoe.py +3 -0
  68. sglang/srt/models/phimoe.py +1 -0
  69. sglang/srt/models/step3_vl.py +991 -0
  70. sglang/srt/multimodal/processors/base_processor.py +15 -16
  71. sglang/srt/multimodal/processors/step3_vl.py +515 -0
  72. sglang/srt/reasoning_parser.py +2 -1
  73. sglang/srt/server_args.py +49 -18
  74. sglang/srt/speculative/eagle_worker.py +2 -0
  75. sglang/srt/utils.py +1 -0
  76. sglang/test/attention/test_trtllm_mla_backend.py +945 -0
  77. sglang/utils.py +0 -11
  78. sglang/version.py +1 -1
  79. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/METADATA +3 -4
  80. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/RECORD +83 -65
  81. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/WHEEL +0 -0
  82. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/licenses/LICENSE +0 -0
  83. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/top_level.txt +0 -0
@@ -25,14 +25,22 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
     silu_and_mul_triton_kernel,
     tma_align_input_scale,
 )
-from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
+from sglang.srt.layers.moe.fused_moe_triton.layer import (
+    FlashInferFusedMoE,
+    FusedMoE,
+    should_use_flashinfer_trtllm_moe,
+)
 from sglang.srt.layers.moe.topk import TopKOutput
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
+from sglang.srt.layers.quantization.fp8 import (
+    Fp8Config,
+    Fp8MoEMethod,
+    get_tile_tokens_dim,
+)
 from sglang.srt.layers.quantization.fp8_kernel import (
     is_fp8_fnuz,
     sglang_per_token_group_quant_fp8,
@@ -49,7 +57,6 @@ from sglang.srt.utils import (
     get_bool_env_var,
     is_hip,
     is_npu,
-    next_power_of_2,
 )
 
 if TYPE_CHECKING:
@@ -63,10 +70,7 @@ _is_hip = is_hip()
 _is_npu = is_npu()
 _is_fp8_fnuz = is_fp8_fnuz()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
-use_flashinfer_trtllm_moe = (
-    global_server_args_dict["enable_flashinfer_trtllm_moe"]
-    and global_server_args_dict["enable_ep_moe"]
-)
+
 
 if not (_is_npu or _is_hip):
     from sgl_kernel import silu_and_mul
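
Note on this hunk: the module-level use_flashinfer_trtllm_moe flag (derived from enable_flashinfer_trtllm_moe and enable_ep_moe) is replaced by calls to should_use_flashinfer_trtllm_moe(), which the first hunk imports from sglang.srt.layers.moe.fused_moe_triton.layer. That helper's body is not part of this diff; the following is only a rough sketch of what such a gate might look like, and the availability check is an assumption:

    # Illustrative sketch only -- the real helper lives in
    # sglang/srt/layers/moe/fused_moe_triton/layer.py and may differ.
    import importlib.util


    def should_use_flashinfer_trtllm_moe(server_args: dict) -> bool:
        """Gate the TRT-LLM FP8 MoE path on server flags and flashinfer availability."""
        wants_trtllm_moe = server_args.get(
            "enable_flashinfer_trtllm_moe", False
        ) and server_args.get("enable_ep_moe", False)
        # Assumed availability check; upstream may resolve this differently.
        return wants_trtllm_moe and importlib.util.find_spec("flashinfer") is not None
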
@@ -76,99 +80,9 @@ if _use_aiter:
     from aiter.fused_moe import fused_moe
     from aiter.ops.shuffle import shuffle_weight
 
-if use_flashinfer_trtllm_moe:
-    try:
-        import flashinfer.fused_moe as fi_fused_moe
-    except ImportError:
-        fi_fused_moe = None
-        use_flashinfer_trtllm_moe = False
-
 logger = logging.getLogger(__name__)
 
 
-class GroupedGemmRunner(torch.nn.Module):
-    flashinfer_gemm_warpper = None
-
-    def __init__(
-        self,
-        device,
-        use_flashinfer: bool = False,
-        use_per_token_if_dynamic: bool = True,
-    ):
-        super().__init__()
-        self.device = device
-        self.use_flashinfer = use_flashinfer
-        self.use_per_token_if_dynamic = use_per_token_if_dynamic
-        if self.use_flashinfer and GroupedGemmRunner.flashinfer_gemm_warpper is None:
-            GroupedGemmRunner._init_flashinfer_wrapper(device)
-
-    @classmethod
-    def _init_flashinfer_wrapper(cls, device):
-        from flashinfer import SegmentGEMMWrapper
-
-        workspace_buffer = torch.empty(
-            128 * 1024 * 1024, dtype=torch.int8, device=device
-        )
-        cls.flashinfer_gemm_warpper = SegmentGEMMWrapper(workspace_buffer)
-
-    # c = a * b
-    def forward(
-        self,
-        a: torch.Tensor,
-        b: torch.Tensor,
-        c: torch.Tensor,
-        batch_size: int,
-        weight_column_major: bool,
-        seg_indptr: Optional[torch.Tensor] = None,
-        weight_indices: Optional[torch.Tensor] = None,
-        use_fp8_w8a8: bool = False,
-        scale_a: torch.Tensor = None,
-        scale_b: torch.Tensor = None,
-        block_shape: Optional[List[int]] = None,
-        c_dtype=None,
-    ):
-        if self.use_flashinfer:
-            # TODO: flashinfer
-            assert False
-            assert GroupedGemmRunner.flashinfer_gemm_warpper is not None
-            c = GroupedGemmRunner.flashinfer_gemm_warpper.run(
-                x=a,
-                weights=b,
-                batch_size=batch_size,
-                weight_column_major=weight_column_major,
-                seg_indptr=seg_indptr,
-                weight_indices=weight_indices,
-            )
-        else:
-            assert weight_column_major == True
-            c = grouped_gemm_triton(
-                a,
-                b,
-                c,
-                batch_size,
-                weight_column_major,
-                seg_indptr,
-                weight_indices,
-                use_fp8_w8a8,
-                scale_a,
-                scale_b,
-                block_shape=block_shape,
-                c_dtype=c_dtype,
-                use_per_token_if_dynamic=self.use_per_token_if_dynamic,
-            )
-        return c
-
-
-def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
-    # Guess tokens per expert assuming perfect expert distribution first.
-    num_tokens_per_expert = (num_tokens * top_k) // num_experts
-    # And pad the number to the next power of 2.
-    tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
-    # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
-    tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
-    return tile_tokens_dim
-
-
 class EPMoE(FusedMoE):
     """
     MoE Expert Parallel Impl
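
The _get_tile_tokens_dim helper removed above is not dropped: hunk 1 imports get_tile_tokens_dim from sglang.srt.layers.quantization.fp8, and the FlashInferEPMoE call site further down switches to that name. Assuming the relocated helper keeps the logic deleted here (the next_power_of_2 call is replaced by an equivalent bit trick in this sketch), it amounts to:

    def get_tile_tokens_dim(num_tokens: int, top_k: int, num_experts: int) -> int:
        # Guess tokens per expert assuming a perfectly even expert distribution,
        # round up to the next power of 2, then clamp to the 8-64 tile range
        # supported by the kernel.
        num_tokens_per_expert = (num_tokens * top_k) // num_experts
        tile_tokens_dim = 1 << max(num_tokens_per_expert - 1, 0).bit_length()
        return min(max(tile_tokens_dim, 8), 64)


    # Example: 512 tokens routed with top_k=8 over 256 experts -> 16 tokens/expert.
    assert get_tile_tokens_dim(512, 8, 256) == 16
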
@@ -183,140 +97,57 @@ class EPMoE(FusedMoE):
         hidden_size: int,
         intermediate_size: int,
         layer_id: int,
+        num_fused_shared_experts: int = 0,
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
         prefix: str = "",
         activation: str = "silu",
         routed_scaling_factor: Optional[float] = None,
-        use_per_token_if_dynamic: bool = True,
     ):
         super().__init__(
             num_experts=num_experts,
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
-            top_k=top_k,
+            num_fused_shared_experts=num_fused_shared_experts,
             layer_id=layer_id,
+            top_k=top_k,
             params_dtype=params_dtype,
             quant_config=quant_config,
             tp_size=tp_size,
             prefix=prefix,
             activation=activation,
+            # apply_router_weight_on_input=apply_router_weight_on_input,
             routed_scaling_factor=routed_scaling_factor,
             enable_ep_moe=True,
-            skip_quant=True,
         )
 
-        if params_dtype is None:
-            params_dtype = torch.get_default_dtype()
-
-        self.layer_id = layer_id
-        self.num_local_experts, self.expert_map = self.determine_expert_map()
-        self.start_expert_id = self.ep_rank * self.num_local_experts
+        self.start_expert_id = self.moe_ep_rank * self.num_local_experts
         self.end_expert_id = self.start_expert_id + self.num_local_experts - 1
 
         self.intermediate_size = intermediate_size
-        self.use_per_token_if_dynamic = use_per_token_if_dynamic
 
-        # TODO(ch-wan): move quant preparation to FusedMoE
-        if quant_config is None:
-            self.quant_method: Optional[QuantizeMethodBase] = (
-                UnquantizedFusedMoEMethod()
-            )
-            self.use_fp8_w8a8 = False
-            self.use_block_quant = False
-            self.block_shape = None
-            self.activation_scheme = None
-            self.w13_input_scale = None
-            self.w2_input_scale = None
-            self.w13_weight_scale = None
-            self.w2_weight_scale = None
-        elif isinstance(quant_config, W4AFp8Config):
-            self.quant_method: Optional[QuantizeMethodBase] = W4AFp8MoEMethod(
-                quant_config
-            )
-            self.use_fp8_w8a8 = False
-            self.use_block_quant = False
-            self.fp8_dtype = torch.float8_e4m3fn
-            self.w13_input_scale = None
-            self.w2_input_scale = None
-            self.w13_weight_scale = None
-            self.w2_weight_scale = None
-            self.activation_scheme = quant_config.moe_activation_scheme
-        elif isinstance(quant_config, Fp8Config):
-            self.quant_method: Optional[QuantizeMethodBase] = Fp8MoEMethod(quant_config)
-            self.use_fp8_w8a8 = True
+        if isinstance(quant_config, Fp8Config):
             self.use_block_quant = getattr(self.quant_method, "block_quant", False)
             self.block_shape = (
                 self.quant_method.quant_config.weight_block_size
                 if self.use_block_quant
                 else None
             )
+            self.use_fp8_w8a8 = True
             self.fp8_dtype = torch.float8_e4m3fn
             self.activation_scheme = quant_config.activation_scheme
         else:
-            raise ValueError(f"Unsupported quant_config: {quant_config}")
-
-        self.quant_config = quant_config
-        self.quant_method.create_weights(
-            layer=self,
-            num_experts=self.num_local_experts,
-            hidden_size=hidden_size,
-            intermediate_size=self.intermediate_size,
-            params_dtype=params_dtype,
-            weight_loader=self.weight_loader,
-        )
-
-        self.grouped_gemm_runner = None
-
-    # Adapted from https://github.com/vllm-project/vllm/blob/9fb52e523abf7bdaf7e60cf2971edb5a1b13dc08/vllm/model_executor/layers/fused_moe/layer.py#L544C1-L586C43
-    # Modifications: use determine_expert_map as a class internal function, set 'global_num_experts' rather than '-1' for experts not assigned to the current rank.
-    def determine_expert_map(self) -> Tuple[int, Optional[torch.Tensor]]:
-        """
-        Calculates how many experts should be assigned to each rank for EP and
-        creates a mapping from global to local expert index. Experts are
-        distributed evenly across ranks. Any remaining are assigned to the
-        last rank.
-
-        Returns:
-            Tuple[int, Optional[torch.Tensor]]: A tuple containing:
-                - local_num_experts (int): The number of experts assigned
-                    to the current rank.
-                - expert_map (Optional[torch.Tensor]): A tensor of shape
-                    (global_num_experts,) mapping from global to local index.
-                    Contains global_num_experts for experts not assigned to the current rank.
-                    Returns None if ep_size is 1.
-        """
-        ep_size = self.ep_size
-        ep_rank = self.ep_rank
-        global_num_experts = self.num_experts
-
-        assert ep_size > 0
-        if ep_size == 1:
-            return (global_num_experts, None)
-
-        local_num_experts = global_num_experts // ep_size
-
-        expert_map = torch.full(
-            (global_num_experts,), global_num_experts, dtype=torch.int32
-        )
-        if ep_rank < (ep_size - 1):
-            expert_map[
-                ep_rank * local_num_experts : (ep_rank + 1) * local_num_experts
-            ] = torch.arange(0, local_num_experts, dtype=torch.int32)
-        else:
-            local_num_experts = global_num_experts - ep_rank * local_num_experts
-
-            expert_map[-local_num_experts:] = torch.arange(
-                0, local_num_experts, dtype=torch.int32
-            )
-        return (local_num_experts, expert_map)
+            self.use_fp8_w8a8 = False
+            self.use_block_quant = False
+            self.block_shape = None
+            self.activation_scheme = None
 
     def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
         if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
             return self.forward_deepgemm(hidden_states, topk_output)
         else:
-            return self.forward_normal(hidden_states, topk_output)
+            return super().forward(hidden_states, topk_output)
 
     def forward_deepgemm(
         self,
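
With determine_expert_map gone, EPMoE leans on the expert partitioning already set up by FusedMoE, and the local expert range now comes from moe_ep_rank. A small worked example of that arithmetic (the numbers are illustrative, not taken from the diff):

    # 256 global experts split evenly over 8 expert-parallel ranks.
    num_experts, ep_size = 256, 8
    num_local_experts = num_experts // ep_size  # 32 experts per rank

    for moe_ep_rank in range(ep_size):
        start_expert_id = moe_ep_rank * num_local_experts
        end_expert_id = start_expert_id + num_local_experts - 1
        # rank 0 owns experts 0-31, rank 1 owns 32-63, ..., rank 7 owns 224-255
        assert 0 <= start_expert_id <= end_expert_id < num_experts
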
@@ -475,294 +306,6 @@ class EPMoE(FusedMoE):
         )
         return output
 
-    def forward_normal(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
-        return self.quant_method.apply(self, hidden_states, topk_output)
-
-    def run_moe(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
-
-        topk_weights, topk_ids, _ = topk_output
-
-        hidden_states_shape = hidden_states.shape
-        hidden_states_dtype = hidden_states.dtype
-        hidden_states_device = hidden_states.device
-        if self.grouped_gemm_runner is None:
-            self.grouped_gemm_runner = GroupedGemmRunner(
-                hidden_states.device,
-                use_flashinfer=False,  # TODO: use flashinfer
-                use_per_token_if_dynamic=self.use_per_token_if_dynamic,
-            )
-
-        num_experts = self.num_experts
-
-        reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
-            topk_ids,
-            num_experts,
-        )
-
-        gateup_input = torch.empty(
-            (int(hidden_states.shape[0] * self.top_k), hidden_states.shape[1]),
-            device=hidden_states.device,
-            dtype=(
-                self.fp8_dtype
-                if self.use_fp8_w8a8 and not self.use_block_quant
-                else hidden_states.dtype
-            ),
-        )
-        if self.activation_scheme == "dynamic" and not self.use_block_quant:
-            if self.use_per_token_if_dynamic:
-                max_value = torch.max(hidden_states, dim=1).values.to(torch.float32)
-                self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
-            else:
-                max_value = (
-                    torch.max(hidden_states)
-                    .repeat(self.num_local_experts)
-                    .to(torch.float32)
-                )
-                self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
-
-        # PreReorder
-        pre_reorder_triton_kernel[(hidden_states.shape[0],)](
-            hidden_states,
-            gateup_input,
-            src2dst,
-            topk_ids,
-            self.w13_input_scale,
-            self.start_expert_id,
-            self.end_expert_id,
-            self.top_k,
-            hidden_states.shape[1],
-            BLOCK_SIZE=512,
-            use_per_token_if_dynamic=self.use_per_token_if_dynamic,
-        )
-        dispose_tensor(hidden_states)
-
-        if (
-            self.activation_scheme == "dynamic"
-            and not self.use_block_quant
-            and self.use_per_token_if_dynamic
-        ):
-            scale = torch.empty(
-                hidden_states_shape[0] * self.top_k,
-                device=hidden_states_device,
-                dtype=torch.float32,
-            )
-            scale[src2dst] = (
-                self.w13_input_scale.unsqueeze(1)
-                .expand(hidden_states_shape[0], self.top_k)
-                .reshape(-1)
-            )
-            self.w13_input_scale = scale
-
-        seg_indptr_cur_rank = seg_indptr[self.start_expert_id : self.end_expert_id + 2]
-        weight_indices_cur_rank = torch.arange(
-            0,
-            self.num_local_experts,
-            device=hidden_states_device,
-            dtype=torch.int64,
-        )
-        # GroupGemm-0
-        gateup_output = self.grouped_gemm_runner(
-            a=gateup_input,
-            b=self.w13_weight,
-            c=None,
-            c_dtype=hidden_states_dtype,
-            batch_size=self.num_local_experts,
-            weight_column_major=True,
-            seg_indptr=seg_indptr_cur_rank,
-            weight_indices=weight_indices_cur_rank,
-            use_fp8_w8a8=self.use_fp8_w8a8,
-            scale_a=self.w13_input_scale,
-            scale_b=self.w13_weight_scale,
-            block_shape=self.block_shape,
-        )
-        del gateup_input
-
-        # Act
-        if self.activation_scheme == "dynamic" and not self.use_block_quant:
-            self.w2_input_scale = None
-            down_input = torch.empty(
-                gateup_output.shape[0],
-                gateup_output.shape[1] // 2,
-                device=gateup_output.device,
-                dtype=hidden_states_dtype,
-            )
-        else:
-            down_input = torch.empty(
-                gateup_output.shape[0],
-                gateup_output.shape[1] // 2,
-                device=gateup_output.device,
-                dtype=(
-                    self.fp8_dtype
-                    if (self.use_fp8_w8a8 and not self.use_block_quant)
-                    else hidden_states_dtype
-                ),
-            )
-
-        if self.activation == "silu":
-            silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
-                gateup_output,
-                down_input,
-                gateup_output.shape[1],
-                reorder_topk_ids,
-                self.w2_input_scale,
-                self.start_expert_id,
-                self.end_expert_id,
-                BLOCK_SIZE=512,
-            )
-        elif self.activation == "gelu":
-            gelu_and_mul_triton_kernel[(gateup_output.shape[0],)](
-                gateup_output,
-                down_input,
-                gateup_output.shape[1],
-                reorder_topk_ids,
-                self.w2_input_scale,
-                self.start_expert_id,
-                self.end_expert_id,
-                BLOCK_SIZE=512,
-            )
-        else:
-            raise ValueError(f"Unsupported activation: {self.activation=}")
-        del gateup_output
-
-        if self.activation_scheme == "dynamic" and not self.use_block_quant:
-            if self.use_per_token_if_dynamic:
-                down_input, self.w2_input_scale = sglang_per_token_quant_fp8(down_input)
-            else:
-                self.w2_input_scale = torch.ones(
-                    self.num_local_experts,
-                    dtype=torch.float32,
-                    device=hidden_states_device,
-                )
-
-        # GroupGemm-1
-        down_output = torch.empty(
-            down_input.shape[0],
-            self.w2_weight.shape[1],
-            device=hidden_states_device,
-            dtype=hidden_states_dtype,
-        )
-        down_output = self.grouped_gemm_runner(
-            a=down_input,
-            b=self.w2_weight,
-            c=down_output,
-            batch_size=self.num_local_experts,
-            weight_column_major=True,
-            seg_indptr=seg_indptr_cur_rank,
-            weight_indices=weight_indices_cur_rank,
-            use_fp8_w8a8=self.use_fp8_w8a8,
-            scale_a=self.w2_input_scale,
-            scale_b=self.w2_weight_scale,
-            block_shape=self.block_shape,
-        )
-        del down_input
-
-        # PostReorder
-        output = torch.empty(
-            hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
-        )
-        post_reorder_triton_kernel[(hidden_states_shape[0],)](
-            down_output,
-            output,
-            src2dst,
-            topk_ids,
-            topk_weights,
-            self.start_expert_id,
-            self.end_expert_id,
-            self.top_k,
-            hidden_states_shape[1],
-            0,
-            BLOCK_SIZE=512,
-        )
-        return output
-
-    @classmethod
-    def make_expert_params_mapping(
-        cls,
-        ckpt_gate_proj_name: str,
-        ckpt_down_proj_name: str,
-        ckpt_up_proj_name: str,
-        num_experts: int,
-    ) -> List[Tuple[str, str, int, str]]:
-        return [
-            # (param_name, weight_name, expert_id, shard_id)
-            (
-                (
-                    "experts.w13_"
-                    if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name]
-                    else "experts.w2_"
-                ),
-                f"experts.{expert_id}.{weight_name}.",
-                expert_id,
-                shard_id,
-            )
-            for expert_id in range(num_experts)
-            for shard_id, weight_name in [
-                ("w1", ckpt_gate_proj_name),
-                ("w2", ckpt_down_proj_name),
-                ("w3", ckpt_up_proj_name),
-            ]
-        ]
-
-    @classmethod
-    def make_expert_input_scale_params_mapping(
-        cls,
-        num_experts: int,
-    ) -> List[Tuple[str, str, int, str]]:
-        # (param_name, weight_name, expert_id, shard_id)
-        return [
-            (
-                "experts.w13_" if shard_id in ["w1", "w3"] else "experts.w2_",
-                f"experts.{expert_id}.{shard_id}.",
-                expert_id,
-                shard_id,
-            )
-            for expert_id in range(num_experts)
-            for shard_id in ["w1", "w2", "w3"]
-        ]
-
-    def weight_loader(
-        self,
-        param: torch.nn.Parameter,
-        loaded_weight: torch.Tensor,
-        weight_name: str,
-        shard_id: str,
-        expert_id: int,
-    ) -> None:
-        physical_expert_ids = (
-            get_global_expert_location_metadata().logical_to_all_physical(
-                self.layer_id, expert_id
-            )
-        )
-        for physical_expert_id in physical_expert_ids:
-            self._weight_loader_physical(
-                param=param,
-                loaded_weight=loaded_weight,
-                weight_name=weight_name,
-                shard_id=shard_id,
-                expert_id=physical_expert_id,
-            )
-
-    def _weight_loader_physical(
-        self,
-        param: torch.nn.Parameter,
-        loaded_weight: torch.Tensor,
-        weight_name: str,
-        shard_id: str,
-        expert_id: int,
-    ) -> None:
-        if expert_id < self.start_expert_id or expert_id > self.end_expert_id:
-            return
-        expert_id = expert_id - self.start_expert_id
-
-        self._weight_loader_impl(
-            param=param,
-            loaded_weight=loaded_weight,
-            weight_name=weight_name,
-            shard_id=shard_id,
-            expert_id=expert_id,
-        )
-        return
-
 
 class DeepEPMoE(EPMoE):
     """
@@ -778,6 +321,7 @@ class DeepEPMoE(EPMoE):
         hidden_size: int,
         intermediate_size: int,
         layer_id: int,
+        num_fused_shared_experts: int = 0,
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
@@ -792,6 +336,7 @@ class DeepEPMoE(EPMoE):
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
             layer_id=layer_id,
+            num_fused_shared_experts=num_fused_shared_experts,
             params_dtype=params_dtype,
             quant_config=quant_config,
             tp_size=tp_size,
@@ -892,14 +437,15 @@ class DeepEPMoE(EPMoE):
             # in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
             return self.forward_aiter(dispatch_output)
         if dispatch_output.format.is_deepep_normal():
-            if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
-                return self.forward_deepgemm_contiguous(dispatch_output)
-            else:
-                return self.forward_normal(dispatch_output)
+            assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
+            return self.forward_deepgemm_contiguous(dispatch_output)
         elif dispatch_output.format.is_deepep_ll():
+            assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
             return self.forward_deepgemm_masked(dispatch_output)
         else:
-            raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
+            raise ValueError(
+                f"Dispatch output format {dispatch_output.format} is not supported"
+            )
 
     def combine(
         self,
@@ -915,185 +461,6 @@ class DeepEPMoE(EPMoE):
             forward_batch=forward_batch,
         )
 
-    def _prepare_for_normal(
-        self,
-        hidden_states: torch.Tensor,
-        topk_idx: torch.Tensor,
-    ):
-        from sglang.srt.layers.moe.ep_moe.kernels import (
-            deepep_permute_triton_kernel,
-            deepep_run_moe_deep_preprocess,
-        )
-
-        if hidden_states.shape[0] == 0:
-            reorder_topk_ids = torch.empty(
-                (0,), device=hidden_states.device, dtype=torch.int64
-            )
-            seg_indptr = torch.zeros(
-                (self.num_experts + 1,),
-                device=hidden_states.device,
-                dtype=torch.int64,
-            )
-            return reorder_topk_ids, seg_indptr, hidden_states
-        else:
-            if _use_aiter:
-                # skip permutation here as aiter fused_moe has fused inside
-                reorder_topk_ids = torch.empty(
-                    (0,), device=hidden_states.device, dtype=torch.int64
-                )
-                seg_indptr = torch.zeros(
-                    (self.num_experts + 1,),
-                    device=hidden_states.device,
-                    dtype=torch.int64,
-                )
-                return reorder_topk_ids, seg_indptr, hidden_states
-
-            reorder_topk_ids, self.src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
-                topk_idx, self.num_experts
-            )
-            num_total_tokens = reorder_topk_ids.numel()
-            gateup_input = torch.empty(
-                (int(num_total_tokens), hidden_states.shape[1]),
-                device=hidden_states.device,
-                dtype=hidden_states.dtype,
-            )
-            # PreReorder
-            deepep_permute_triton_kernel[(hidden_states.shape[0],)](
-                hidden_states,
-                gateup_input,
-                self.src2dst,
-                topk_idx,
-                None,
-                self.router_topk,
-                hidden_states.shape[1],
-                BLOCK_SIZE=512,
-            )
-            return reorder_topk_ids, seg_indptr, gateup_input
-
-    def forward_normal(
-        self,
-        dispatch_output: DeepEPNormalOutput,
-    ):
-        hidden_states, topk_idx = (
-            dispatch_output.hidden_states,
-            dispatch_output.topk_idx,
-        )
-        reorder_topk_ids, seg_indptr, hidden_states = self._prepare_for_normal(
-            hidden_states, topk_idx
-        )
-        hidden_states_dtype = hidden_states.dtype
-        hidden_states_device = hidden_states.device
-
-        assert self.quant_method is not None
-        assert self.activation == "silu"
-        if self.grouped_gemm_runner is None:
-            self.grouped_gemm_runner = GroupedGemmRunner(
-                hidden_states.device, use_flashinfer=False  # TODO: use flashinfer
-            )
-
-        if self.activation_scheme == "dynamic" and not self.use_block_quant:
-            max_value = (
-                torch.max(hidden_states)
-                .repeat(self.num_local_experts)
-                .to(torch.float32)
-            )
-            self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
-        weight_indices_cur_rank = torch.arange(
-            0,
-            self.num_local_experts,
-            device=hidden_states.device,
-            dtype=torch.int64,
-        )
-
-        # GroupGemm-0
-        if hidden_states.shape[0] > 0:
-            gateup_output = self.grouped_gemm_runner(
-                a=hidden_states,
-                b=self.w13_weight,
-                c=None,
-                c_dtype=hidden_states.dtype,
-                batch_size=self.num_local_experts,
-                weight_column_major=True,
-                seg_indptr=seg_indptr,
-                weight_indices=weight_indices_cur_rank,
-                use_fp8_w8a8=self.use_fp8_w8a8,
-                scale_a=self.w13_input_scale,
-                scale_b=(
-                    self.w13_weight_scale_inv
-                    if self.use_block_quant
-                    else self.w13_weight_scale
-                ),
-                block_shape=self.block_shape,
-            )
-        else:
-            gateup_output = torch.empty(
-                hidden_states.shape[0],
-                self.w13_weight.shape[1],
-                device=hidden_states.device,
-                dtype=hidden_states.dtype,
-            )
-
-        # Act
-        down_input = torch.empty(
-            gateup_output.shape[0],
-            gateup_output.shape[1] // 2,
-            device=gateup_output.device,
-            dtype=(
-                self.fp8_dtype
-                if (self.use_fp8_w8a8 and not self.use_block_quant)
-                else hidden_states_dtype
-            ),
-        )
-        if self.w2_input_scale is None and not self.use_block_quant:
-            self.w2_input_scale = torch.ones(
-                self.num_local_experts,
-                dtype=torch.float32,
-                device=hidden_states_device,
-            )
-
-        if self.activation == "silu":
-            silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
-                gateup_output,
-                down_input,
-                gateup_output.shape[1],
-                reorder_topk_ids,
-                self.w2_input_scale,
-                0,
-                self.num_local_experts - 1,
-                BLOCK_SIZE=512,
-            )
-        else:
-            raise ValueError(f"Unsupported activation: {self.activation=}")
-
-        del gateup_output
-
-        # GroupGemm-1
-        down_output = torch.empty(
-            down_input.shape[0],
-            self.w2_weight.shape[1],
-            device=hidden_states_device,
-            dtype=hidden_states_dtype,
-        )
-        if down_input.shape[0] > 0:
-            down_output = self.grouped_gemm_runner(
-                a=down_input,
-                b=self.w2_weight,
-                c=down_output,
-                batch_size=self.num_local_experts,
-                weight_column_major=True,
-                seg_indptr=seg_indptr,
-                weight_indices=weight_indices_cur_rank,
-                use_fp8_w8a8=self.use_fp8_w8a8,
-                scale_a=self.w2_input_scale,
-                scale_b=(
-                    self.w2_weight_scale_inv
-                    if self.use_block_quant
-                    else self.w2_weight_scale
-                ),
-                block_shape=self.block_shape,
-            )
-        return down_output
-
     def forward_aiter(
         self,
         dispatch_output: DeepEPNormalOutput,
@@ -1351,10 +718,10 @@ class FlashInferEPMoE(EPMoE):
         self.num_expert_group = num_expert_group
         self.topk_group = topk_group
         self.correction_bias = correction_bias
-        self.use_flashinfer_trtllm_moe = use_flashinfer_trtllm_moe
+        self.use_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe()
 
     def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
-        assert use_flashinfer_trtllm_moe
+        assert self.use_flashinfer_trtllm_moe
         assert (
             self.activation == "silu"
         ), "Only silu is supported for flashinfer blockscale fp8 moe"
@@ -1367,8 +734,9 @@ class FlashInferEPMoE(EPMoE):
         a_q, a_sf = sglang_per_token_group_quant_fp8(hidden_states, self.block_shape[1])
         # NOTE: scales of hidden states have to be transposed!
         a_sf_t = a_sf.t().contiguous()
-        assert fi_fused_moe is not None
-        return fi_fused_moe.trtllm_fp8_block_scale_moe(
+        from flashinfer.fused_moe import trtllm_fp8_block_scale_moe
+
+        return trtllm_fp8_block_scale_moe(
             routing_logits=router_logits.to(torch.float32),
             routing_bias=self.correction_bias.to(hidden_states.dtype),
             hidden_states=a_q,
@@ -1385,7 +753,7 @@ class FlashInferEPMoE(EPMoE):
             local_expert_offset=self.start_expert_id,
             local_num_experts=self.num_local_experts,
             routed_scaling_factor=self.routed_scaling_factor,
-            tile_tokens_dim=_get_tile_tokens_dim(
+            tile_tokens_dim=get_tile_tokens_dim(
                 hidden_states.shape[0], self.top_k, self.num_experts
             ),
             routing_method_type=2,  # DeepSeek-styled routing method
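
Related to the two hunks above: the trtllm kernel is now imported inside forward rather than through the module-level fi_fused_moe try/except removed earlier, so a missing flashinfer only fails requests that actually take this path. A generic sketch of that lazy-import pattern (not sglang's code):

    def _load_trtllm_block_scale_moe():
        # Defer the flashinfer import until the TRT-LLM FP8 MoE path is used,
        # so environments without flashinfer can still import the module.
        try:
            from flashinfer.fused_moe import trtllm_fp8_block_scale_moe
        except ImportError as exc:
            raise RuntimeError(
                "flashinfer is required for the TRT-LLM FP8 block-scale MoE path"
            ) from exc
        return trtllm_fp8_block_scale_moe
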
@@ -1399,9 +767,6 @@ def get_moe_impl_class():
     if global_server_args_dict["enable_flashinfer_cutlass_moe"]:
         # Must come before EPMoE because FusedMoE also supports enable_ep_moe
         return FusedMoE
-    if use_flashinfer_trtllm_moe:
-        # Must come before EPMoE because FusedMoE also supports enable_ep_moe
-        return FlashInferEPMoE
     if global_server_args_dict["enable_ep_moe"]:
-        return EPMoE
-    return FusedMoE
+        return FlashInferEPMoE if should_use_flashinfer_trtllm_moe() else EPMoE
+    return FlashInferFusedMoE if should_use_flashinfer_trtllm_moe() else FusedMoE
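
Putting the hunk above together with its unchanged context, the backend selection in get_moe_impl_class now appears to read as follows (a reconstruction from the visible lines; branches outside this hunk are not shown, and the referenced names come from the module's own scope):

    def get_moe_impl_class():
        if global_server_args_dict["enable_flashinfer_cutlass_moe"]:
            # Must come before EPMoE because FusedMoE also supports enable_ep_moe
            return FusedMoE
        if global_server_args_dict["enable_ep_moe"]:
            return FlashInferEPMoE if should_use_flashinfer_trtllm_moe() else EPMoE
        return FlashInferFusedMoE if should_use_flashinfer_trtllm_moe() else FusedMoE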