sglang 0.4.10.post1__py3-none-any.whl → 0.5.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (143)
  1. sglang/bench_one_batch.py +113 -17
  2. sglang/compile_deep_gemm.py +8 -1
  3. sglang/global_config.py +5 -1
  4. sglang/srt/configs/model_config.py +35 -0
  5. sglang/srt/conversation.py +9 -117
  6. sglang/srt/disaggregation/base/conn.py +5 -2
  7. sglang/srt/disaggregation/decode.py +6 -1
  8. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -0
  9. sglang/srt/disaggregation/mooncake/conn.py +243 -135
  10. sglang/srt/disaggregation/prefill.py +3 -0
  11. sglang/srt/distributed/device_communicators/pynccl.py +7 -0
  12. sglang/srt/distributed/device_communicators/pynccl_allocator.py +133 -0
  13. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +42 -3
  14. sglang/srt/distributed/parallel_state.py +22 -9
  15. sglang/srt/entrypoints/context.py +244 -0
  16. sglang/srt/entrypoints/engine.py +8 -5
  17. sglang/srt/entrypoints/harmony_utils.py +370 -0
  18. sglang/srt/entrypoints/http_server.py +106 -15
  19. sglang/srt/entrypoints/openai/protocol.py +227 -1
  20. sglang/srt/entrypoints/openai/serving_chat.py +278 -42
  21. sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
  22. sglang/srt/entrypoints/openai/tool_server.py +174 -0
  23. sglang/srt/entrypoints/tool.py +87 -0
  24. sglang/srt/eplb/expert_distribution.py +4 -2
  25. sglang/srt/eplb/expert_location.py +5 -1
  26. sglang/srt/function_call/harmony_tool_parser.py +130 -0
  27. sglang/srt/hf_transformers_utils.py +55 -13
  28. sglang/srt/jinja_template_utils.py +8 -1
  29. sglang/srt/layers/attention/aiter_backend.py +5 -8
  30. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  31. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
  32. sglang/srt/layers/attention/flashattention_backend.py +7 -11
  33. sglang/srt/layers/attention/triton_backend.py +85 -14
  34. sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
  35. sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
  36. sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
  37. sglang/srt/layers/attention/trtllm_mla_backend.py +6 -6
  38. sglang/srt/layers/attention/vision.py +40 -15
  39. sglang/srt/layers/communicator.py +35 -8
  40. sglang/srt/layers/dp_attention.py +12 -0
  41. sglang/srt/layers/linear.py +9 -8
  42. sglang/srt/layers/logits_processor.py +9 -1
  43. sglang/srt/layers/moe/cutlass_moe.py +20 -6
  44. sglang/srt/layers/moe/ep_moe/layer.py +87 -107
  45. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  46. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
  47. sglang/srt/layers/moe/fused_moe_triton/layer.py +442 -58
  48. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +169 -15
  49. sglang/srt/layers/moe/token_dispatcher/__init__.py +23 -0
  50. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +12 -1
  51. sglang/srt/layers/moe/{ep_moe/token_dispatcher.py → token_dispatcher/deepep.py} +8 -15
  52. sglang/srt/layers/moe/topk.py +12 -3
  53. sglang/srt/layers/moe/utils.py +59 -0
  54. sglang/srt/layers/quantization/__init__.py +22 -0
  55. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +3 -2
  56. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +1 -1
  57. sglang/srt/layers/quantization/fp4.py +557 -0
  58. sglang/srt/layers/quantization/fp8.py +8 -7
  59. sglang/srt/layers/quantization/fp8_kernel.py +0 -4
  60. sglang/srt/layers/quantization/fp8_utils.py +29 -0
  61. sglang/srt/layers/quantization/modelopt_quant.py +259 -64
  62. sglang/srt/layers/quantization/mxfp4.py +651 -0
  63. sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
  64. sglang/srt/layers/quantization/quark/__init__.py +0 -0
  65. sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
  66. sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
  67. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
  68. sglang/srt/layers/quantization/quark/utils.py +107 -0
  69. sglang/srt/layers/quantization/unquant.py +60 -6
  70. sglang/srt/layers/quantization/w4afp8.py +1 -1
  71. sglang/srt/layers/rotary_embedding.py +225 -1
  72. sglang/srt/layers/utils.py +9 -0
  73. sglang/srt/layers/vocab_parallel_embedding.py +15 -4
  74. sglang/srt/lora/lora_manager.py +70 -14
  75. sglang/srt/lora/lora_registry.py +10 -2
  76. sglang/srt/lora/mem_pool.py +43 -5
  77. sglang/srt/managers/cache_controller.py +61 -32
  78. sglang/srt/managers/data_parallel_controller.py +52 -2
  79. sglang/srt/managers/detokenizer_manager.py +1 -1
  80. sglang/srt/managers/io_struct.py +21 -4
  81. sglang/srt/managers/mm_utils.py +5 -11
  82. sglang/srt/managers/schedule_batch.py +30 -8
  83. sglang/srt/managers/schedule_policy.py +3 -1
  84. sglang/srt/managers/scheduler.py +170 -18
  85. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
  86. sglang/srt/managers/scheduler_recv_skipper.py +37 -0
  87. sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
  88. sglang/srt/managers/template_manager.py +59 -22
  89. sglang/srt/managers/tokenizer_manager.py +137 -67
  90. sglang/srt/managers/tp_worker.py +3 -0
  91. sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
  92. sglang/srt/managers/utils.py +45 -1
  93. sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py +182 -0
  94. sglang/srt/mem_cache/hicache_storage.py +13 -21
  95. sglang/srt/mem_cache/hiradix_cache.py +53 -5
  96. sglang/srt/mem_cache/memory_pool_host.py +1 -1
  97. sglang/srt/mem_cache/multimodal_cache.py +33 -13
  98. sglang/srt/mem_cache/radix_cache_cpp.py +229 -0
  99. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
  100. sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp +35 -0
  101. sglang/srt/model_executor/cuda_graph_runner.py +24 -9
  102. sglang/srt/model_executor/forward_batch_info.py +48 -17
  103. sglang/srt/model_executor/model_runner.py +24 -2
  104. sglang/srt/model_loader/weight_utils.py +10 -0
  105. sglang/srt/models/bailing_moe.py +425 -0
  106. sglang/srt/models/deepseek_v2.py +95 -50
  107. sglang/srt/models/ernie4.py +426 -0
  108. sglang/srt/models/ernie4_eagle.py +203 -0
  109. sglang/srt/models/gemma3n_mm.py +39 -0
  110. sglang/srt/models/glm4_moe.py +102 -27
  111. sglang/srt/models/gpt_oss.py +1134 -0
  112. sglang/srt/models/grok.py +3 -3
  113. sglang/srt/models/llama4.py +13 -2
  114. sglang/srt/models/mixtral.py +3 -3
  115. sglang/srt/models/mllama4.py +428 -19
  116. sglang/srt/models/qwen2.py +6 -0
  117. sglang/srt/models/qwen2_moe.py +7 -4
  118. sglang/srt/models/qwen3_moe.py +39 -14
  119. sglang/srt/models/step3_vl.py +10 -1
  120. sglang/srt/models/transformers.py +2 -5
  121. sglang/srt/multimodal/processors/base_processor.py +4 -3
  122. sglang/srt/multimodal/processors/gemma3n.py +0 -7
  123. sglang/srt/multimodal/processors/step3_vl.py +3 -1
  124. sglang/srt/operations_strategy.py +1 -1
  125. sglang/srt/reasoning_parser.py +18 -39
  126. sglang/srt/server_args.py +218 -23
  127. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +18 -0
  128. sglang/srt/two_batch_overlap.py +163 -9
  129. sglang/srt/utils.py +41 -26
  130. sglang/srt/weight_sync/utils.py +1 -1
  131. sglang/test/runners.py +4 -4
  132. sglang/test/test_utils.py +4 -4
  133. sglang/version.py +1 -1
  134. {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/METADATA +18 -15
  135. {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/RECORD +143 -116
  136. /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/mooncake_store.py +0 -0
  137. /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/unit_test.py +0 -0
  138. /sglang/srt/mem_cache/{nixl → storage/nixl}/hicache_nixl.py +0 -0
  139. /sglang/srt/mem_cache/{nixl → storage/nixl}/nixl_utils.py +0 -0
  140. /sglang/srt/mem_cache/{nixl → storage/nixl}/test_hicache_nixl_storage.py +0 -0
  141. {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/WHEEL +0 -0
  142. {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/licenses/LICENSE +0 -0
  143. {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/ep_moe/layer.py
@@ -1,41 +1,24 @@
 from __future__ import annotations
 
 import logging
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import TYPE_CHECKING, Optional
 
 import torch
 
-from sglang.srt.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
-from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
+from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
 from sglang.srt.layers.moe.ep_moe.kernels import (
     ep_gather,
     ep_scatter,
-    gelu_and_mul_triton_kernel,
-    grouped_gemm_triton,
     moe_ep_deepgemm_preprocess,
     post_reorder_triton_kernel,
-    pre_reorder_triton_kernel,
-    pre_reorder_triton_kernel_for_cutlass_moe,
-    run_cutlass_moe_ep_preproess,
-    run_moe_ep_preproess,
     silu_and_mul_masked_post_quant_fwd,
-    silu_and_mul_triton_kernel,
     tma_align_input_scale,
 )
-from sglang.srt.layers.moe.fused_moe_triton.layer import (
-    FlashInferFusedMoE,
-    FusedMoE,
-    should_use_flashinfer_trtllm_moe,
-)
+from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE
 from sglang.srt.layers.moe.topk import TopKOutput
+from sglang.srt.layers.moe.utils import DeepEPMode, should_use_flashinfer_trtllm_moe
 from sglang.srt.layers.quantization import deep_gemm_wrapper
-from sglang.srt.layers.quantization.base_config import (
-    QuantizationConfig,
-    QuantizeMethodBase,
-)
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8 import (
     Fp8Config,
     Fp8MoEMethod,
@@ -44,23 +27,13 @@ from sglang.srt.layers.quantization.fp8 import (
 from sglang.srt.layers.quantization.fp8_kernel import (
     is_fp8_fnuz,
     sglang_per_token_group_quant_fp8,
-    sglang_per_token_quant_fp8,
 )
-from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
-from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import (
-    DeepEPMode,
-    ceil_div,
-    dispose_tensor,
-    get_bool_env_var,
-    is_hip,
-    is_npu,
-)
+from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu
 
 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.ep_moe.token_dispatcher import (
+    from sglang.srt.layers.moe.token_dispatcher import (
         DeepEPLLOutput,
         DeepEPNormalOutput,
         DispatchOutput,
@@ -71,7 +44,6 @@ _is_npu = is_npu()
 _is_fp8_fnuz = is_fp8_fnuz()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 
-
 if not (_is_npu or _is_hip):
     from sgl_kernel import silu_and_mul
 
@@ -83,6 +55,22 @@ if _use_aiter:
 logger = logging.getLogger(__name__)
 
 
+# TODO(kaixih@nvidia): ideally we should merge this logic into
+# `fill_gateup_input_triton_kernel` to directly generate e8m0 scale.
+@torch.compile
+def _cast_to_e8m0_with_rounding_up(x: torch.Tensor) -> torch.Tensor:
+    temp = x.to(torch.float32).view(torch.int32)
+    exp = torch.bitwise_right_shift(temp, 23)
+    mant = torch.bitwise_and(temp, 0x7FFFFF)
+    is_ru = torch.logical_and(
+        torch.logical_and((mant > 0), (exp != 0xFE)),
+        ~torch.logical_and((exp == 0), (mant <= 0x400000)),
+    )
+    exp = torch.where(is_ru, exp + 1, exp)
+    new_x = exp.to(torch.uint8).view(torch.int)
+    return new_x.transpose(1, 2).contiguous().transpose(1, 2)
+
+
 class EPMoE(FusedMoE):
     """
     MoE Expert Parallel Impl
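The `_cast_to_e8m0_with_rounding_up` helper added above packs per-group FP8 scales into DeepGEMM's UE8M0 format by keeping only the float32 exponent and rounding up whenever mantissa bits would otherwise be dropped, i.e. each scale becomes the nearest power of two at or above its value. A standalone scalar sketch of the same bit manipulation (illustration only, not part of the package):

```python
# Scalar reference for the rounding-up UE8M0 cast above (illustration only).
import struct

def e8m0_round_up(scale: float) -> int:
    bits = struct.unpack("<I", struct.pack("<f", scale))[0]
    exp = bits >> 23          # biased float32 exponent (bias 127)
    mant = bits & 0x7FFFFF    # 23 mantissa bits
    # Mirror the is_ru condition: round up unless it would overflow the
    # exponent field (0xFE) or the value is a small subnormal.
    if mant > 0 and exp != 0xFE and not (exp == 0 and mant <= 0x400000):
        exp += 1
    return exp

for s in (0.75, 1.0, 1.5, 3.0):
    print(s, "->", 2.0 ** (e8m0_round_up(s) - 127))  # 1.0, 1.0, 2.0, 4.0
```

Rounding up makes the power-of-two scale an upper bound on the original scale, so quantized values never grow past what the original scale allowed.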
@@ -104,6 +92,9 @@ class EPMoE(FusedMoE):
         prefix: str = "",
         activation: str = "silu",
         routed_scaling_factor: Optional[float] = None,
+        activation_alpha: Optional[float] = None,
+        swiglu_limit: Optional[float] = None,
+        with_bias: bool = False,
     ):
         super().__init__(
             num_experts=num_experts,
@@ -119,7 +110,9 @@
             activation=activation,
             # apply_router_weight_on_input=apply_router_weight_on_input,
             routed_scaling_factor=routed_scaling_factor,
-            enable_ep_moe=True,
+            activation_alpha=activation_alpha,
+            swiglu_limit=swiglu_limit,
+            with_bias=with_bias,
         )
 
         self.start_expert_id = self.moe_ep_rank * self.num_local_experts
@@ -227,10 +220,22 @@
 
         dispose_tensor(hidden_states)
 
+        if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
+            b, s_mn, s_k = gateup_input_scale.shape
+            assert (
+                s_mn % 4 == 0 and s_k % 4 == 0
+            ), f"scales must be aligned to 4, but got ({b}, {s_mn}, {s_k})"
+
         # GroupGemm-0
         gateup_input_fp8 = (
             gateup_input,
-            deep_gemm_wrapper.get_col_major_tma_aligned_tensor(gateup_input_scale),
+            (
+                _cast_to_e8m0_with_rounding_up(gateup_input_scale)
+                if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
+                else deep_gemm_wrapper.get_col_major_tma_aligned_tensor(
+                    gateup_input_scale
+                )
+            ),
         )
         num_groups, m, k = gateup_input_fp8[0].size()
         n = self.w13_weight.size(1)
@@ -238,7 +243,12 @@
             (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
         )
         deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
-            gateup_input_fp8, self.w13_weight_fp8, gateup_output, masked_m, expected_m
+            gateup_input_fp8,
+            self.w13_weight_fp8,
+            gateup_output,
+            masked_m,
+            expected_m,
+            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
         )
         del gateup_input
         del gateup_input_fp8
@@ -269,6 +279,7 @@
             down_input_scale,
             scale_block_size,
             masked_m,
+            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
         )
         del gateup_output
 
@@ -276,13 +287,24 @@
         n = self.w2_weight.size(1)
         down_input_fp8 = (
             down_input,
-            deep_gemm_wrapper.get_col_major_tma_aligned_tensor(down_input_scale),
+            (
+                down_input_scale
+                if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
+                else deep_gemm_wrapper.get_col_major_tma_aligned_tensor(
+                    down_input_scale
+                )
+            ),
         )
         down_output = torch.empty(
             (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
         )
         deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
-            down_input_fp8, self.w2_weight_fp8, down_output, masked_m, expected_m
+            down_input_fp8,
+            self.w2_weight_fp8,
+            down_output,
+            masked_m,
+            expected_m,
+            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
         )
         del down_input
         del down_input_fp8
@@ -304,6 +326,8 @@
             m_max * self.start_expert_id,
             BLOCK_SIZE=512,
         )
+        if self.routed_scaling_factor is not None:
+            output *= self.routed_scaling_factor
         return output
 
 
@@ -328,7 +352,7 @@ class DeepEPMoE(EPMoE):
         prefix: str = "",
         activation: str = "silu",
         routed_scaling_factor: Optional[float] = None,
-        deepep_mode: DeepEPMode = DeepEPMode.auto,
+        deepep_mode: DeepEPMode = DeepEPMode.AUTO,
     ):
         super().__init__(
             num_experts=num_experts,
@@ -348,7 +372,6 @@
 
         # TODO: move to the beginning of the file
        from sglang.srt.distributed.parallel_state import get_tp_group
-        from sglang.srt.managers.schedule_batch import global_server_args_dict
        from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
 
        self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
@@ -701,72 +724,29 @@
         return down_output
 
 
-class FlashInferEPMoE(EPMoE):
-    def __init__(self, *args, **kwargs):
-        renormalize = kwargs.pop("renormalize", True)
-        num_fused_shared_experts = kwargs.pop("num_fused_shared_experts", 0)
-        use_grouped_topk = kwargs.pop("use_grouped_topk", False)
-        num_expert_group = kwargs.pop("num_expert_group", None)
-        topk_group = kwargs.pop("topk_group", None)
-        correction_bias = kwargs.pop("correction_bias", None)
-        super().__init__(*args, **kwargs)
-        self.renormalize = renormalize
-        self.num_fused_shared_experts = num_fused_shared_experts
-        self.use_grouped_topk = use_grouped_topk
-        if self.use_grouped_topk:
-            assert num_expert_group is not None and topk_group is not None
-        self.num_expert_group = num_expert_group
-        self.topk_group = topk_group
-        self.correction_bias = correction_bias
-        self.use_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe()
-
-    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
-        assert self.use_flashinfer_trtllm_moe
-        assert (
-            self.activation == "silu"
-        ), "Only silu is supported for flashinfer blockscale fp8 moe"
-        assert (
-            self.renormalize
-        ), "Renormalize is required for flashinfer blockscale fp8 moe"
-        assert (
-            self.num_fused_shared_experts == 0
-        ), "Fused shared experts are not supported for flashinfer blockscale fp8 moe"
-        a_q, a_sf = sglang_per_token_group_quant_fp8(hidden_states, self.block_shape[1])
-        # NOTE: scales of hidden states have to be transposed!
-        a_sf_t = a_sf.t().contiguous()
-        from flashinfer.fused_moe import trtllm_fp8_block_scale_moe
-
-        return trtllm_fp8_block_scale_moe(
-            routing_logits=router_logits.to(torch.float32),
-            routing_bias=self.correction_bias.to(hidden_states.dtype),
-            hidden_states=a_q,
-            hidden_states_scale=a_sf_t,
-            gemm1_weights=self.w13_weight,
-            gemm1_weights_scale=self.w13_weight_scale_inv,
-            gemm2_weights=self.w2_weight,
-            gemm2_weights_scale=self.w2_weight_scale_inv,
-            num_experts=self.num_experts,
-            top_k=self.top_k,
-            n_group=self.num_expert_group,
-            topk_group=self.topk_group,
-            intermediate_size=self.w2_weight.shape[2],
-            local_expert_offset=self.start_expert_id,
-            local_num_experts=self.num_local_experts,
-            routed_scaling_factor=self.routed_scaling_factor,
-            tile_tokens_dim=get_tile_tokens_dim(
-                hidden_states.shape[0], self.top_k, self.num_experts
-            ),
-            routing_method_type=2,  # DeepSeek-styled routing method
-            use_shuffled_weight=False,
-        )
-
-
 def get_moe_impl_class():
-    if global_server_args_dict["enable_deepep_moe"]:
+    if global_server_args_dict["moe_a2a_backend"].is_deepep():
         return DeepEPMoE
+
+    # NEW: Direct FP4 detection (bypasses EP requirements)
+    # Check for FP4 quantization with TRTLLM flag, regardless of EP
+    if global_server_args_dict.get("enable_flashinfer_trtllm_moe", False):
+        try:
+            # Check the quantization argument directly
+            quantization = global_server_args_dict.get("quantization")
+            if quantization == "modelopt_fp4":
+                from sglang.srt.layers.moe.fused_moe_triton.layer import (
+                    FlashInferFP4MoE,
+                )
+
+                return FlashInferFP4MoE
+        except:
+            pass
+
+    if should_use_flashinfer_trtllm_moe():
+        return FlashInferFusedMoE
     if global_server_args_dict["enable_flashinfer_cutlass_moe"]:
-        # Must come before EPMoE because FusedMoE also supports enable_ep_moe
         return FusedMoE
-    if global_server_args_dict["enable_ep_moe"]:
-        return FlashInferEPMoE if should_use_flashinfer_trtllm_moe() else EPMoE
-    return FlashInferFusedMoE if should_use_flashinfer_trtllm_moe() else FusedMoE
+    if get_moe_expert_parallel_world_size() > 1:
+        return EPMoE
+    return FusedMoE
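With `FlashInferEPMoE` removed, `get_moe_impl_class()` now chooses an implementation purely from server args and the MoE expert-parallel world size. A simplified, flag-level restatement of the new priority order (illustration only; the real function returns classes, `should_use_flashinfer_trtllm_moe()` checks more than the bare flag, and the `moe_ep_size` key below is a hypothetical stand-in for `get_moe_expert_parallel_world_size()`):

```python
# Illustrative restatement of the new get_moe_impl_class() priority.
# Not sglang API: plain dict in, class name out.
def select_moe_impl(args: dict) -> str:
    if args.get("moe_a2a_backend") == "deepep":         # moe_a2a_backend.is_deepep()
        return "DeepEPMoE"
    if args.get("enable_flashinfer_trtllm_moe"):
        if args.get("quantization") == "modelopt_fp4":  # FP4 path, independent of EP
            return "FlashInferFP4MoE"
        return "FlashInferFusedMoE"                     # should_use_flashinfer_trtllm_moe()
    if args.get("enable_flashinfer_cutlass_moe"):
        return "FusedMoE"
    if args.get("moe_ep_size", 1) > 1:                  # expert parallelism enabled
        return "EPMoE"
    return "FusedMoE"

print(select_moe_impl({"moe_ep_size": 4}))  # -> EPMoE
print(select_moe_impl({"enable_flashinfer_trtllm_moe": True,
                       "quantization": "modelopt_fp4"}))  # -> FlashInferFP4MoE
```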
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "16": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 2
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 2
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "256": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "512": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 3
+    }
+}
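The new config file above maps a benchmarked token count M (the JSON keys) to Triton launch parameters for the fp8_w8a8 fused MoE kernel with E=128 experts and N=352 on an RTX 6000 Ada. A minimal sketch of how such a tuning table is typically consumed (the actual lookup lives in fused_moe.py and may differ in detail; the helper below is illustrative only):

```python
# Illustrative lookup for a fused-MoE tuning table like the JSON above.
import json

def pick_config(path: str, num_tokens: int) -> dict:
    with open(path) as f:
        configs = {int(k): v for k, v in json.load(f).items()}
    # Use the kernel parameters benchmarked at the M closest to the live batch.
    return configs[min(configs, key=lambda m: abs(m - num_tokens))]

cfg = pick_config(
    "E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json",
    num_tokens=50,
)
print(cfg)  # picks the "48" entry: BLOCK_SIZE_M=16, BLOCK_SIZE_N=32, ...
```

The sweep follows the usual shape: small M values keep BLOCK_SIZE_M at 16, while the 1024-and-up entries move to 64/128-wide tiles.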