sglang 0.5.4__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88) hide show
  1. sglang/bench_serving.py +56 -12
  2. sglang/launch_server.py +2 -0
  3. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +101 -4
  4. sglang/srt/compilation/backend.py +1 -1
  5. sglang/srt/configs/model_config.py +5 -5
  6. sglang/srt/distributed/parallel_state.py +0 -7
  7. sglang/srt/entrypoints/engine.py +18 -15
  8. sglang/srt/entrypoints/grpc_server.py +0 -1
  9. sglang/srt/entrypoints/http_server.py +75 -94
  10. sglang/srt/environ.py +16 -2
  11. sglang/srt/eplb/expert_distribution.py +30 -0
  12. sglang/srt/function_call/function_call_parser.py +2 -0
  13. sglang/srt/function_call/minimax_m2.py +367 -0
  14. sglang/srt/layers/activation.py +6 -0
  15. sglang/srt/layers/attention/flashattention_backend.py +12 -2
  16. sglang/srt/layers/attention/flashinfer_backend.py +10 -1
  17. sglang/srt/layers/attention/flashinfer_mla_backend.py +18 -10
  18. sglang/srt/layers/attention/trtllm_mla_backend.py +1 -13
  19. sglang/srt/layers/attention/utils.py +78 -0
  20. sglang/srt/layers/communicator.py +1 -0
  21. sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
  22. sglang/srt/layers/layernorm.py +19 -4
  23. sglang/srt/layers/logits_processor.py +5 -0
  24. sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
  25. sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
  26. sglang/srt/layers/moe/ep_moe/layer.py +79 -272
  27. sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
  28. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
  29. sglang/srt/layers/moe/moe_runner/deep_gemm.py +287 -22
  30. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  31. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  32. sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
  33. sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
  34. sglang/srt/layers/moe/token_dispatcher/deepep.py +18 -14
  35. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  36. sglang/srt/layers/moe/topk.py +4 -4
  37. sglang/srt/layers/moe/utils.py +3 -4
  38. sglang/srt/layers/quantization/__init__.py +3 -5
  39. sglang/srt/layers/quantization/awq.py +0 -3
  40. sglang/srt/layers/quantization/base_config.py +7 -0
  41. sglang/srt/layers/quantization/fp8.py +68 -63
  42. sglang/srt/layers/quantization/gguf.py +566 -0
  43. sglang/srt/layers/quantization/mxfp4.py +30 -38
  44. sglang/srt/layers/quantization/unquant.py +23 -45
  45. sglang/srt/layers/quantization/w4afp8.py +38 -2
  46. sglang/srt/layers/radix_attention.py +5 -2
  47. sglang/srt/layers/rotary_embedding.py +13 -1
  48. sglang/srt/layers/sampler.py +12 -1
  49. sglang/srt/managers/io_struct.py +3 -0
  50. sglang/srt/managers/multi_tokenizer_mixin.py +17 -1
  51. sglang/srt/managers/scheduler.py +21 -15
  52. sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
  53. sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
  54. sglang/srt/managers/tokenizer_manager.py +11 -19
  55. sglang/srt/mem_cache/hicache_storage.py +7 -1
  56. sglang/srt/mem_cache/memory_pool.py +82 -0
  57. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  58. sglang/srt/model_executor/forward_batch_info.py +44 -3
  59. sglang/srt/model_executor/model_runner.py +1 -149
  60. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
  61. sglang/srt/models/deepseek_v2.py +147 -44
  62. sglang/srt/models/glm4_moe.py +322 -354
  63. sglang/srt/models/glm4_moe_nextn.py +4 -14
  64. sglang/srt/models/glm4v_moe.py +29 -196
  65. sglang/srt/models/minimax_m2.py +922 -0
  66. sglang/srt/models/nvila.py +355 -0
  67. sglang/srt/models/nvila_lite.py +184 -0
  68. sglang/srt/models/qwen2.py +22 -1
  69. sglang/srt/models/qwen3.py +34 -4
  70. sglang/srt/models/qwen3_moe.py +2 -4
  71. sglang/srt/multimodal/processors/base_processor.py +1 -0
  72. sglang/srt/multimodal/processors/glm4v.py +1 -1
  73. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  74. sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
  75. sglang/srt/parser/reasoning_parser.py +28 -1
  76. sglang/srt/server_args.py +365 -186
  77. sglang/srt/single_batch_overlap.py +2 -7
  78. sglang/srt/utils/common.py +87 -42
  79. sglang/srt/utils/hf_transformers_utils.py +7 -3
  80. sglang/test/test_deterministic.py +235 -12
  81. sglang/test/test_deterministic_utils.py +2 -1
  82. sglang/version.py +1 -1
  83. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +7 -6
  84. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +87 -82
  85. sglang/srt/models/vila.py +0 -306
  86. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
  87. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
  88. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
@@ -1,7 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import logging
4
- from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
4
+ from typing import TYPE_CHECKING, Any, Dict, Optional, Union
5
5
 
6
6
  import torch
7
7
 
@@ -13,29 +13,23 @@ from sglang.srt.layers.moe import (
13
13
  get_moe_runner_backend,
14
14
  should_use_flashinfer_trtllm_moe,
15
15
  )
16
- from sglang.srt.layers.moe.ep_moe.kernels import (
17
- ep_gather,
18
- ep_scatter,
19
- silu_and_mul_masked_post_quant_fwd,
20
- tma_align_input_scale,
21
- )
22
16
  from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE
17
+ from sglang.srt.layers.moe.token_dispatcher.deepep import (
18
+ DeepEPLLCombineInput,
19
+ DeepEPNormalCombineInput,
20
+ )
23
21
  from sglang.srt.layers.moe.topk import TopKOutput
24
22
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
25
23
  from sglang.srt.layers.quantization.fp8 import Fp8Config
26
- from sglang.srt.layers.quantization.fp8_kernel import (
27
- is_fp8_fnuz,
28
- sglang_per_token_group_quant_fp8,
29
- )
24
+ from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
30
25
  from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod
31
26
  from sglang.srt.single_batch_overlap import DownGemmOverlapArgs
32
- from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu
33
- from sglang.srt.utils.offloader import get_offloader
27
+ from sglang.srt.utils import get_bool_env_var, is_hip, is_npu
34
28
 
35
29
  if TYPE_CHECKING:
36
30
  from sglang.srt.layers.moe.token_dispatcher import (
37
- DeepEPLLOutput,
38
- DeepEPNormalOutput,
31
+ DeepEPLLDispatchOutput,
32
+ DeepEPNormalDispatchOutput,
39
33
  DispatchOutput,
40
34
  )
41
35
 
@@ -45,7 +39,7 @@ _is_fp8_fnuz = is_fp8_fnuz()
45
39
  _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
46
40
 
47
41
  if not (_is_npu or _is_hip):
48
- from sgl_kernel import silu_and_mul
42
+ pass
49
43
 
50
44
  if _use_aiter:
51
45
  from aiter import ActivationType, QuantType
@@ -90,6 +84,18 @@ class DeepEPMoE(FusedMoE):
90
84
  routed_scaling_factor=routed_scaling_factor,
91
85
  )
92
86
 
87
+ if _use_aiter or _is_npu:
88
+ self.deprecate_flag = False
89
+ elif deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and isinstance(
90
+ quant_config, Fp8Config
91
+ ):
92
+ self.deprecate_flag = True
93
+ else:
94
+ self.deprecate_flag = False
95
+
96
+ if self.deprecate_flag:
97
+ return
98
+
93
99
  if isinstance(quant_config, Fp8Config):
94
100
  self.use_block_quant = getattr(self.quant_method, "block_quant", False)
95
101
  self.use_fp8_w8a8 = True
@@ -100,6 +106,7 @@ class DeepEPMoE(FusedMoE):
100
106
  self.use_fp8_w8a8 = False
101
107
  self.use_block_quant = False
102
108
  else:
109
+ self.use_w4afp8 = False
103
110
  self.use_fp8_w8a8 = False
104
111
  self.use_block_quant = False
105
112
  self.use_w4afp8 = False
@@ -151,6 +158,14 @@ class DeepEPMoE(FusedMoE):
151
158
  disable_sbo=False,
152
159
  ):
153
160
 
161
+ if self.deprecate_flag:
162
+ assert forward_shared_experts is None
163
+ assert alt_stream is None
164
+ return super().forward(
165
+ hidden_states,
166
+ topk_output,
167
+ )
168
+
154
169
  # We have to call SBO inside MoE to be compatible with hooks used in offloading
155
170
  return single_batch_overlap.execute_sbo(
156
171
  hidden_states=hidden_states,
@@ -177,35 +192,51 @@ class DeepEPMoE(FusedMoE):
177
192
  dispatch_output: DispatchOutput,
178
193
  down_gemm_overlap_args: Optional[DownGemmOverlapArgs] = None,
179
194
  ):
195
+
196
+ if self.deprecate_flag:
197
+ assert down_gemm_overlap_args is None
198
+ return super().run_moe_core(
199
+ dispatch_output,
200
+ )
201
+
180
202
  from sglang.srt.layers.moe.token_dispatcher import DispatchOutputChecker
181
203
 
182
204
  if _use_aiter:
183
205
  assert DispatchOutputChecker.format_is_deepep(dispatch_output)
184
206
  # in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
185
- return self.forward_aiter(dispatch_output)
186
- if _is_npu:
207
+ output = self.forward_aiter(dispatch_output)
208
+ elif _is_npu:
187
209
  assert DispatchOutputChecker.format_is_deepep(dispatch_output)
188
- return self.forward_npu(dispatch_output)
189
- if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
210
+ output = self.forward_npu(dispatch_output)
211
+ elif DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
190
212
  if self.use_w4afp8:
191
- return self.forward_cutlass_w4afp8(dispatch_output)
192
- assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
193
- return self.forward_deepgemm_contiguous(dispatch_output)
213
+ output = self.forward_cutlass_w4afp8(dispatch_output)
214
+ else:
215
+ assert False, "forward_deepgemm_contiguous is deprecated"
194
216
  elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
195
217
  if (
196
218
  get_moe_runner_backend().is_flashinfer_cutedsl()
197
219
  and self.quant_config.get_name() == "modelopt_fp4"
198
220
  ):
199
- return self.forward_flashinfer_cutedsl(
221
+ output = self.forward_flashinfer_cutedsl(
200
222
  dispatch_output, down_gemm_overlap_args=down_gemm_overlap_args
201
223
  )
202
- assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
203
- assert down_gemm_overlap_args is None
204
- return self.forward_deepgemm_masked(dispatch_output)
205
- else:
206
- raise ValueError(
207
- f"Dispatch output format {dispatch_output.format} is not supported"
208
- )
224
+ elif self.use_w4afp8:
225
+ output = self.forward_cutlass_w4afp8_masked(dispatch_output)
226
+ else:
227
+ assert False, "forward_deepgemm_masked is deprecated"
228
+
229
+ combine_input_wrapper = (
230
+ DeepEPNormalCombineInput
231
+ if DispatchOutputChecker.format_is_deepep_normal(dispatch_output)
232
+ else DeepEPLLCombineInput
233
+ )
234
+ return combine_input_wrapper(
235
+ hidden_states=output,
236
+ topk_ids=dispatch_output.topk_ids,
237
+ topk_weights=dispatch_output.topk_weights,
238
+ overlap_args=down_gemm_overlap_args,
239
+ )
209
240
 
210
241
  def combine(
211
242
  self,
@@ -223,7 +254,7 @@ class DeepEPMoE(FusedMoE):
223
254
 
224
255
  def forward_aiter(
225
256
  self,
226
- dispatch_output: Union[DeepEPNormalOutput, DeepEPLLOutput],
257
+ dispatch_output: Union[DeepEPNormalDispatchOutput, DeepEPLLDispatchOutput],
227
258
  ):
228
259
  hidden_states, topk_ids, topk_weights = (
229
260
  dispatch_output.hidden_states,
@@ -255,158 +286,9 @@ class DeepEPMoE(FusedMoE):
255
286
  expert_mask=self.expert_mask,
256
287
  )
257
288
 
258
- def forward_deepgemm_contiguous(
259
- self,
260
- dispatch_output: DeepEPNormalOutput,
261
- ):
262
- (
263
- hidden_states,
264
- hidden_states_scale,
265
- topk_ids,
266
- topk_weights,
267
- num_recv_tokens_per_expert,
268
- ) = dispatch_output
269
- assert self.quant_method is not None
270
- assert self.moe_runner_config.activation == "silu"
271
- if num_recv_tokens_per_expert is None:
272
- return hidden_states.bfloat16()
273
- all_tokens = sum(num_recv_tokens_per_expert)
274
- if all_tokens <= 0:
275
- return hidden_states.bfloat16()
276
- M, K = hidden_states.size()
277
- N = self.w13_weight.size(1)
278
- scale_block_size = 128
279
-
280
- w13_weight_fp8 = (
281
- self.w13_weight,
282
- (
283
- self.w13_weight_scale_inv
284
- if self.use_block_quant
285
- else self.w13_weight_scale
286
- ),
287
- )
288
- w2_weight_fp8 = (
289
- self.w2_weight,
290
- (
291
- self.w2_weight_scale_inv
292
- if self.use_block_quant
293
- else self.w2_weight_scale
294
- ),
295
- )
296
-
297
- hidden_states_shape = hidden_states.shape
298
- hidden_states_device = hidden_states.device
299
- hidden_states_dtype = hidden_states.dtype
300
-
301
- input_tensor = [
302
- torch.empty(
303
- (all_tokens, K),
304
- device=hidden_states.device,
305
- dtype=hidden_states.dtype,
306
- ),
307
- (
308
- # TODO check whether need `zeros`
309
- torch.zeros(
310
- (ceil_div(K // 128, 4), all_tokens),
311
- device=hidden_states.device,
312
- dtype=torch.int,
313
- ).transpose(0, 1)
314
- if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
315
- else torch.empty(
316
- (all_tokens, K // 128),
317
- device=hidden_states.device,
318
- dtype=torch.float32,
319
- )
320
- ),
321
- ]
322
- m_indices = torch.empty(
323
- all_tokens, device=hidden_states.device, dtype=torch.int32
324
- )
325
- output_index = torch.empty_like(topk_ids)
326
-
327
- if get_offloader().forbid_copy_engine_usage:
328
- num_recv_tokens_per_expert_gpu = copy_list_to_gpu_no_ce(
329
- num_recv_tokens_per_expert
330
- )
331
- else:
332
- num_recv_tokens_per_expert_gpu = torch.tensor(
333
- num_recv_tokens_per_expert,
334
- dtype=torch.int32,
335
- pin_memory=True,
336
- device="cpu",
337
- ).cuda(non_blocking=True)
338
- expert_start_loc = torch.empty_like(num_recv_tokens_per_expert_gpu)
339
-
340
- ep_scatter(
341
- hidden_states,
342
- hidden_states_scale,
343
- topk_ids,
344
- num_recv_tokens_per_expert_gpu,
345
- expert_start_loc,
346
- input_tensor[0],
347
- input_tensor[1],
348
- m_indices,
349
- output_index,
350
- scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
351
- )
352
- dispose_tensor(hidden_states)
353
-
354
- gateup_output = torch.empty(
355
- (all_tokens, N),
356
- device=hidden_states_device,
357
- dtype=torch.bfloat16,
358
- )
359
- if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
360
- input_tensor[1] = tma_align_input_scale(input_tensor[1])
361
- deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
362
- input_tensor, w13_weight_fp8, gateup_output, m_indices
363
- )
364
- del input_tensor
365
- down_input = torch.empty(
366
- (
367
- all_tokens,
368
- N // 2,
369
- ),
370
- device=gateup_output.device,
371
- dtype=torch.bfloat16,
372
- )
373
- silu_and_mul(gateup_output.view(-1, N), down_input)
374
- del gateup_output
375
- down_output = torch.empty(
376
- (all_tokens, K),
377
- device=hidden_states_device,
378
- dtype=torch.bfloat16,
379
- )
380
- down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8(
381
- down_input,
382
- scale_block_size,
383
- column_major_scales=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
384
- scale_tma_aligned=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
385
- scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
386
- )
387
- del down_input
388
- if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
389
- down_input_scale = tma_align_input_scale(down_input_scale)
390
- deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
391
- (down_input_fp8, down_input_scale),
392
- w2_weight_fp8,
393
- down_output,
394
- m_indices,
395
- )
396
- del down_input_fp8, down_input_scale
397
-
398
- gather_out = torch.empty(
399
- hidden_states_shape,
400
- device=hidden_states_device,
401
- dtype=torch.bfloat16,
402
- )
403
- ep_gather(down_output, topk_ids, topk_weights, output_index, gather_out)
404
-
405
- return gather_out
406
-
407
289
  def forward_flashinfer_cutedsl(
408
290
  self,
409
- dispatch_output: DeepEPLLOutput,
291
+ dispatch_output: DeepEPLLDispatchOutput,
410
292
  down_gemm_overlap_args: Optional[DownGemmOverlapArgs],
411
293
  ):
412
294
  hidden_states, hidden_states_scale, _, _, masked_m, _ = dispatch_output
@@ -424,7 +306,7 @@ class DeepEPMoE(FusedMoE):
424
306
 
425
307
  def forward_cutlass_w4afp8(
426
308
  self,
427
- dispatch_output: DeepEPNormalOutput,
309
+ dispatch_output: DeepEPNormalDispatchOutput,
428
310
  ):
429
311
  assert self.moe_runner_config.activation == "silu"
430
312
  assert isinstance(self.quant_method, W4AFp8MoEMethod)
@@ -433,89 +315,23 @@ class DeepEPMoE(FusedMoE):
433
315
  dispatch_output=dispatch_output,
434
316
  )
435
317
 
436
- def forward_deepgemm_masked(
318
+ def forward_cutlass_w4afp8_masked(
437
319
  self,
438
- dispatch_output: DeepEPLLOutput,
320
+ dispatch_output: DeepEPLLDispatchOutput,
439
321
  ):
440
- hidden_states, hidden_states_scale, _, _, masked_m, expected_m = dispatch_output
441
- assert self.quant_method is not None
442
322
  assert self.moe_runner_config.activation == "silu"
443
- assert (
444
- hidden_states_scale.dtype == torch.float32
445
- ), f"hidden_states_scale.dtype: {hidden_states_scale.dtype}"
446
-
447
- # GroupGemm-0
448
- num_groups, m, k = hidden_states.size()
449
- n = self.w13_weight.size(1)
450
- expected_m = min(expected_m, m)
451
- gateup_output = torch.empty(
452
- (num_groups, m, n), device=hidden_states.device, dtype=torch.bfloat16
453
- )
454
- deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
455
- (hidden_states, hidden_states_scale),
456
- self.w13_weight_fp8,
457
- gateup_output,
458
- masked_m,
459
- expected_m,
460
- )
461
- dispose_tensor(hidden_states)
462
-
463
- # Act
464
- down_input = torch.empty(
465
- (
466
- gateup_output.shape[0],
467
- gateup_output.shape[1],
468
- gateup_output.shape[2] // 2,
469
- ),
470
- device=gateup_output.device,
471
- dtype=self.fp8_dtype,
472
- )
473
- scale_block_size = 128
474
- down_input_scale = torch.empty(
475
- (
476
- gateup_output.shape[0],
477
- gateup_output.shape[1],
478
- gateup_output.shape[2] // 2 // scale_block_size,
479
- ),
480
- device=gateup_output.device,
481
- dtype=torch.float32,
482
- )
483
- silu_and_mul_masked_post_quant_fwd(
484
- gateup_output,
485
- down_input,
486
- down_input_scale,
487
- scale_block_size,
488
- masked_m,
489
- scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
490
- )
491
- del gateup_output
492
-
493
- # GroupGemm-1
494
- n = self.w2_weight.size(1)
495
- down_input_fp8 = (
496
- down_input,
497
- (
498
- down_input_scale
499
- if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
500
- else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(down_input_scale)
501
- ),
502
- )
503
- down_output = torch.empty(
504
- (num_groups, m, n), device=down_input.device, dtype=torch.bfloat16
505
- )
506
- deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
507
- down_input_fp8,
508
- self.w2_weight_fp8,
509
- down_output,
510
- masked_m,
511
- expected_m,
323
+ assert isinstance(self.quant_method, W4AFp8MoEMethod)
324
+ assert get_bool_env_var(
325
+ "SGLANG_DEEPEP_BF16_DISPATCH"
326
+ ), "W4AFP8 does not support FP8 dispatch; please set SGLANG_DEEPEP_BF16_DISPATCH=1."
327
+ return self.quant_method.apply_deepep_ll(
328
+ layer=self,
329
+ dispatch_output=dispatch_output,
512
330
  )
513
331
 
514
- return down_output
515
-
516
332
  def forward_npu(
517
333
  self,
518
- dispatch_output: Union[DeepEPNormalOutput, DeepEPLLOutput],
334
+ dispatch_output: Union[DeepEPNormalDispatchOutput, DeepEPLLDispatchOutput],
519
335
  ):
520
336
  assert self.quant_method is not None
521
337
  assert self.moe_runner_config.activation == "silu"
@@ -528,9 +344,9 @@ class DeepEPMoE(FusedMoE):
528
344
  output_dtype = torch.bfloat16
529
345
  group_list_type = 1
530
346
 
531
- def _forward_normal(dispatch_output: DeepEPNormalOutput):
347
+ def _forward_normal(dispatch_output: DeepEPNormalDispatchOutput):
532
348
  if TYPE_CHECKING:
533
- assert isinstance(dispatch_output, DeepEPNormalOutput)
349
+ assert isinstance(dispatch_output, DeepEPNormalDispatchOutput)
534
350
  hidden_states, hidden_states_scale, _, _, num_recv_tokens_per_expert = (
535
351
  dispatch_output
536
352
  )
@@ -600,9 +416,9 @@ class DeepEPMoE(FusedMoE):
600
416
 
601
417
  return hidden_states
602
418
 
603
- def _forward_ll(dispatch_output: DeepEPLLOutput):
419
+ def _forward_ll(dispatch_output: DeepEPLLDispatchOutput):
604
420
  if TYPE_CHECKING:
605
- assert isinstance(dispatch_output, DeepEPLLOutput)
421
+ assert isinstance(dispatch_output, DeepEPLLDispatchOutput)
606
422
  (
607
423
  hidden_states,
608
424
  hidden_states_scale,
@@ -713,12 +529,3 @@ def get_moe_impl_class(quant_config: Optional[QuantizationConfig]):
713
529
  if get_moe_runner_backend().is_flashinfer_cutlass():
714
530
  return FusedMoE
715
531
  return FusedMoE
716
-
717
-
718
- def copy_list_to_gpu_no_ce(arr: List[int]):
719
- from sgl_kernel.elementwise import copy_to_gpu_no_ce
720
-
721
- tensor_cpu = torch.tensor(arr, dtype=torch.int32, device="cpu")
722
- tensor_gpu = torch.empty_like(tensor_cpu, device="cuda")
723
- copy_to_gpu_no_ce(tensor_cpu, tensor_gpu)
724
- return tensor_gpu
@@ -172,7 +172,7 @@ class FusedMoE(torch.nn.Module):
172
172
  self.reduce_results = reduce_results
173
173
  self.use_presharded_weights = use_presharded_weights
174
174
 
175
- self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
175
+ self.use_triton_kernels = get_moe_runner_backend().is_triton_kernels()
176
176
 
177
177
  self.quant_config = quant_config
178
178
  self.use_flashinfer_mxfp4_moe = get_moe_runner_backend().is_flashinfer_mxfp4()
@@ -232,7 +232,7 @@ class FusedMoE(torch.nn.Module):
232
232
  self.quant_method, ModelOptNvFp4FusedMoEMethod
233
233
  ) or (
234
234
  isinstance(self.quant_method, Fp8MoEMethod)
235
- and self.quant_method.use_cutlass_fused_experts_fp8
235
+ and self.quant_method._should_use_cutlass_fused_experts()
236
236
  )
237
237
 
238
238
  def _load_per_tensor_weight_scale(
@@ -839,7 +839,7 @@ class FusedMoE(torch.nn.Module):
839
839
  dispatch_output=dispatch_output,
840
840
  **kwargs,
841
841
  )
842
- final_hidden_states = self.dispatcher.combine(combine_input)
842
+ final_hidden_states = self.dispatcher.combine(combine_input=combine_input)
843
843
 
844
844
  # TODO: should we add some conditions here?
845
845
  final_hidden_states = final_hidden_states[
@@ -47,7 +47,7 @@ def triton_kernel_moe_forward(
47
47
 
48
48
  from sglang.srt.layers.moe.topk import TopKOutputChecker
49
49
 
50
- assert TopKOutputChecker.format_is_triton_kernel(topk_output)
50
+ assert TopKOutputChecker.format_is_triton_kernels(topk_output)
51
51
 
52
52
  routing_data, gather_idx, scatter_idx = topk_output
53
53
 
@@ -172,6 +172,7 @@ def triton_kernel_moe_with_bias_forward(
172
172
  b2: torch.Tensor,
173
173
  topk_output: TopKOutput,
174
174
  moe_runner_config: MoeRunnerConfig,
175
+ apply_router_weight_on_input: bool = False,
175
176
  use_fp8_w8a8: bool = False,
176
177
  per_channel_quant: bool = False,
177
178
  global_num_experts: int = -1,
@@ -184,7 +185,7 @@ def triton_kernel_moe_with_bias_forward(
184
185
  ) -> torch.Tensor:
185
186
  from sglang.srt.layers.moe.topk import TopKOutputChecker
186
187
 
187
- assert TopKOutputChecker.format_is_triton_kernel(topk_output)
188
+ assert TopKOutputChecker.format_is_triton_kernels(topk_output)
188
189
 
189
190
  routing_data, gather_idx, scatter_idx = topk_output
190
191
 
@@ -201,6 +202,7 @@ def triton_kernel_moe_with_bias_forward(
201
202
  scatter_indx=scatter_idx,
202
203
  inplace=False, # triton kernel doesn't support inplace
203
204
  activation=moe_runner_config.activation,
205
+ apply_router_weight_on_input=apply_router_weight_on_input,
204
206
  use_fp8_w8a8=use_fp8_w8a8,
205
207
  per_channel_quant=per_channel_quant,
206
208
  global_num_experts=global_num_experts,
@@ -228,6 +230,7 @@ def triton_kernel_fused_experts_with_bias(
228
230
  scatter_indx: ScatterIndx,
229
231
  inplace: bool = False,
230
232
  activation: str = "silu",
233
+ apply_router_weight_on_input: bool = False,
231
234
  use_fp8_w8a8: bool = False,
232
235
  per_channel_quant: bool = False,
233
236
  global_num_experts: int = -1,
@@ -296,7 +299,7 @@ def triton_kernel_fused_experts_with_bias(
296
299
  routing_data,
297
300
  gather_indx=gather_indx,
298
301
  precision_config=w1_pcg,
299
- gammas=None,
302
+ gammas=routing_data.gate_scal if apply_router_weight_on_input else None,
300
303
  fused_activation=act,
301
304
  )
302
305
 
@@ -307,5 +310,5 @@ def triton_kernel_fused_experts_with_bias(
307
310
  routing_data,
308
311
  scatter_indx=scatter_indx,
309
312
  precision_config=w2_pcg,
310
- gammas=routing_data.gate_scal,
313
+ gammas=None if apply_router_weight_on_input else routing_data.gate_scal,
311
314
  )