sglang 0.4.10.post1__py3-none-any.whl → 0.5.0rc0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported public registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in that registry.
Files changed (143)
  1. sglang/bench_one_batch.py +113 -17
  2. sglang/compile_deep_gemm.py +8 -1
  3. sglang/global_config.py +5 -1
  4. sglang/srt/configs/model_config.py +35 -0
  5. sglang/srt/conversation.py +9 -117
  6. sglang/srt/disaggregation/base/conn.py +5 -2
  7. sglang/srt/disaggregation/decode.py +6 -1
  8. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -0
  9. sglang/srt/disaggregation/mooncake/conn.py +243 -135
  10. sglang/srt/disaggregation/prefill.py +3 -0
  11. sglang/srt/distributed/device_communicators/pynccl.py +7 -0
  12. sglang/srt/distributed/device_communicators/pynccl_allocator.py +133 -0
  13. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +42 -3
  14. sglang/srt/distributed/parallel_state.py +22 -9
  15. sglang/srt/entrypoints/context.py +244 -0
  16. sglang/srt/entrypoints/engine.py +8 -5
  17. sglang/srt/entrypoints/harmony_utils.py +370 -0
  18. sglang/srt/entrypoints/http_server.py +106 -15
  19. sglang/srt/entrypoints/openai/protocol.py +227 -1
  20. sglang/srt/entrypoints/openai/serving_chat.py +278 -42
  21. sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
  22. sglang/srt/entrypoints/openai/tool_server.py +174 -0
  23. sglang/srt/entrypoints/tool.py +87 -0
  24. sglang/srt/eplb/expert_distribution.py +4 -2
  25. sglang/srt/eplb/expert_location.py +5 -1
  26. sglang/srt/function_call/harmony_tool_parser.py +130 -0
  27. sglang/srt/hf_transformers_utils.py +55 -13
  28. sglang/srt/jinja_template_utils.py +8 -1
  29. sglang/srt/layers/attention/aiter_backend.py +5 -8
  30. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  31. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
  32. sglang/srt/layers/attention/flashattention_backend.py +7 -11
  33. sglang/srt/layers/attention/triton_backend.py +85 -14
  34. sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
  35. sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
  36. sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
  37. sglang/srt/layers/attention/trtllm_mla_backend.py +6 -6
  38. sglang/srt/layers/attention/vision.py +40 -15
  39. sglang/srt/layers/communicator.py +35 -8
  40. sglang/srt/layers/dp_attention.py +12 -0
  41. sglang/srt/layers/linear.py +9 -8
  42. sglang/srt/layers/logits_processor.py +9 -1
  43. sglang/srt/layers/moe/cutlass_moe.py +20 -6
  44. sglang/srt/layers/moe/ep_moe/layer.py +87 -107
  45. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  46. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
  47. sglang/srt/layers/moe/fused_moe_triton/layer.py +442 -58
  48. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +169 -15
  49. sglang/srt/layers/moe/token_dispatcher/__init__.py +23 -0
  50. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +12 -1
  51. sglang/srt/layers/moe/{ep_moe/token_dispatcher.py → token_dispatcher/deepep.py} +8 -15
  52. sglang/srt/layers/moe/topk.py +12 -3
  53. sglang/srt/layers/moe/utils.py +59 -0
  54. sglang/srt/layers/quantization/__init__.py +22 -0
  55. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +3 -2
  56. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +1 -1
  57. sglang/srt/layers/quantization/fp4.py +557 -0
  58. sglang/srt/layers/quantization/fp8.py +8 -7
  59. sglang/srt/layers/quantization/fp8_kernel.py +0 -4
  60. sglang/srt/layers/quantization/fp8_utils.py +29 -0
  61. sglang/srt/layers/quantization/modelopt_quant.py +259 -64
  62. sglang/srt/layers/quantization/mxfp4.py +651 -0
  63. sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
  64. sglang/srt/layers/quantization/quark/__init__.py +0 -0
  65. sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
  66. sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
  67. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
  68. sglang/srt/layers/quantization/quark/utils.py +107 -0
  69. sglang/srt/layers/quantization/unquant.py +60 -6
  70. sglang/srt/layers/quantization/w4afp8.py +1 -1
  71. sglang/srt/layers/rotary_embedding.py +225 -1
  72. sglang/srt/layers/utils.py +9 -0
  73. sglang/srt/layers/vocab_parallel_embedding.py +15 -4
  74. sglang/srt/lora/lora_manager.py +70 -14
  75. sglang/srt/lora/lora_registry.py +10 -2
  76. sglang/srt/lora/mem_pool.py +43 -5
  77. sglang/srt/managers/cache_controller.py +61 -32
  78. sglang/srt/managers/data_parallel_controller.py +52 -2
  79. sglang/srt/managers/detokenizer_manager.py +1 -1
  80. sglang/srt/managers/io_struct.py +21 -4
  81. sglang/srt/managers/mm_utils.py +5 -11
  82. sglang/srt/managers/schedule_batch.py +30 -8
  83. sglang/srt/managers/schedule_policy.py +3 -1
  84. sglang/srt/managers/scheduler.py +170 -18
  85. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
  86. sglang/srt/managers/scheduler_recv_skipper.py +37 -0
  87. sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
  88. sglang/srt/managers/template_manager.py +59 -22
  89. sglang/srt/managers/tokenizer_manager.py +137 -67
  90. sglang/srt/managers/tp_worker.py +3 -0
  91. sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
  92. sglang/srt/managers/utils.py +45 -1
  93. sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py +182 -0
  94. sglang/srt/mem_cache/hicache_storage.py +13 -21
  95. sglang/srt/mem_cache/hiradix_cache.py +53 -5
  96. sglang/srt/mem_cache/memory_pool_host.py +1 -1
  97. sglang/srt/mem_cache/multimodal_cache.py +33 -13
  98. sglang/srt/mem_cache/radix_cache_cpp.py +229 -0
  99. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
  100. sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp +35 -0
  101. sglang/srt/model_executor/cuda_graph_runner.py +24 -9
  102. sglang/srt/model_executor/forward_batch_info.py +48 -17
  103. sglang/srt/model_executor/model_runner.py +24 -2
  104. sglang/srt/model_loader/weight_utils.py +10 -0
  105. sglang/srt/models/bailing_moe.py +425 -0
  106. sglang/srt/models/deepseek_v2.py +95 -50
  107. sglang/srt/models/ernie4.py +426 -0
  108. sglang/srt/models/ernie4_eagle.py +203 -0
  109. sglang/srt/models/gemma3n_mm.py +39 -0
  110. sglang/srt/models/glm4_moe.py +102 -27
  111. sglang/srt/models/gpt_oss.py +1134 -0
  112. sglang/srt/models/grok.py +3 -3
  113. sglang/srt/models/llama4.py +13 -2
  114. sglang/srt/models/mixtral.py +3 -3
  115. sglang/srt/models/mllama4.py +428 -19
  116. sglang/srt/models/qwen2.py +6 -0
  117. sglang/srt/models/qwen2_moe.py +7 -4
  118. sglang/srt/models/qwen3_moe.py +39 -14
  119. sglang/srt/models/step3_vl.py +10 -1
  120. sglang/srt/models/transformers.py +2 -5
  121. sglang/srt/multimodal/processors/base_processor.py +4 -3
  122. sglang/srt/multimodal/processors/gemma3n.py +0 -7
  123. sglang/srt/multimodal/processors/step3_vl.py +3 -1
  124. sglang/srt/operations_strategy.py +1 -1
  125. sglang/srt/reasoning_parser.py +18 -39
  126. sglang/srt/server_args.py +218 -23
  127. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +18 -0
  128. sglang/srt/two_batch_overlap.py +163 -9
  129. sglang/srt/utils.py +41 -26
  130. sglang/srt/weight_sync/utils.py +1 -1
  131. sglang/test/runners.py +4 -4
  132. sglang/test/test_utils.py +4 -4
  133. sglang/version.py +1 -1
  134. {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/METADATA +18 -15
  135. {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/RECORD +143 -116
  136. /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/mooncake_store.py +0 -0
  137. /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/unit_test.py +0 -0
  138. /sglang/srt/mem_cache/{nixl → storage/nixl}/hicache_nixl.py +0 -0
  139. /sglang/srt/mem_cache/{nixl → storage/nixl}/nixl_utils.py +0 -0
  140. /sglang/srt/mem_cache/{nixl → storage/nixl}/test_hicache_nixl_storage.py +0 -0
  141. {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/WHEEL +0 -0
  142. {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/licenses/LICENSE +0 -0
  143. {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py
@@ -6,15 +6,50 @@ from typing import TYPE_CHECKING, Optional
 
 import torch
 from sgl_kernel import gelu_and_mul, silu_and_mul
-from triton_kernels.matmul_ogs import matmul_ogs
+from triton_kernels.matmul_ogs import (
+    FlexCtx,
+    FnSpecs,
+    FusedActivation,
+    PrecisionConfig,
+    matmul_ogs,
+)
+from triton_kernels.numerics import InFlexData
 from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx
-
-from sglang.srt.utils import direct_register_custom_op
+from triton_kernels.swiglu import swiglu_fn
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.topk import TopKOutput
 
 
+def quantize(w, dtype, dev, **opt):
+    if dtype == "bf16":
+        return w.to(torch.bfloat16), InFlexData()
+    elif dtype == "fp8":
+        wq = w.to(torch.float8_e4m3fn).transpose(-1, -2).contiguous().transpose(-1, -2)
+        return (
+            wq,
+            InFlexData(dtype=wq.dtype, scale=w.abs().max().unsqueeze(0)),
+            MicroscalingCtx(),
+        )
+    else:
+        assert dtype == "mx4", f"{dtype=}"
+        swizzle_mx_scale = opt["swizzle_mx_scale"]
+        swizzle_axis = 2 if swizzle_mx_scale else None
+        w = w.to(torch.bfloat16)
+        w, mx_scales, weight_scale_shape = downcast_to_mxfp(
+            w, torch.uint8, axis=1, swizzle_axis=swizzle_axis
+        )
+        return (
+            w,
+            InFlexData(),
+            MicroscalingCtx(
+                weight_scale=mx_scales,
+                swizzle_mx=swizzle_mx_scale,
+                actual_weight_scale_shape=weight_scale_shape,
+            ),
+        )
+
+
 def triton_kernel_moe_forward(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
@@ -148,16 +183,17 @@ def triton_kernel_fused_experts(
     return intermediate_cache3
 
 
-def triton_kernel_moe_forward_fake(
+def triton_kernel_moe_with_bias_forward(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
+    w1_pcg,
+    b1: torch.Tensor,
     w2: torch.Tensor,
-    gating_output: torch.Tensor,
-    topk: int,
-    renormalize: bool,
+    w2_pcg,
+    b2: torch.Tensor,
+    topk_output: TopKOutput,
     inplace: bool = False,
     activation: str = "silu",
-    apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     per_channel_quant: bool = False,
     global_num_experts: int = -1,
@@ -167,13 +203,131 @@ def triton_kernel_moe_forward_fake(
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[list[int]] = None,
+    activation_alpha: Optional[float] = None,
+    swiglu_limit: Optional[int] = None,
 ) -> torch.Tensor:
-    return torch.empty_like(hidden_states)
+    assert topk_output.format.is_triton_kernel()
+    routing_data, gather_idx, scatter_idx = topk_output
+
+    return triton_kernel_fused_experts_with_bias(
+        hidden_states,
+        w1=w1,
+        w1_pcg=w1_pcg,
+        b1=b1,
+        w2=w2,
+        w2_pcg=w2_pcg,
+        b2=b2,
+        routing_data=routing_data,
+        gather_indx=gather_idx,
+        scatter_indx=scatter_idx,
+        inplace=inplace,
+        activation=activation,
+        use_fp8_w8a8=use_fp8_w8a8,
+        per_channel_quant=per_channel_quant,
+        global_num_experts=global_num_experts,
+        expert_map=expert_map,
+        w1_scale=w1_scale,
+        w2_scale=w2_scale,
+        a1_scale=a1_scale,
+        a2_scale=a2_scale,
+        block_shape=block_shape,
+        activation_alpha=activation_alpha,
+        swiglu_limit=swiglu_limit,
+    )
 
 
-direct_register_custom_op(
-    op_name="forward_cuda_triton",
-    op_func=triton_kernel_moe_forward,
-    mutates_args=[],
-    fake_impl=triton_kernel_moe_forward_fake,
-)
+def triton_kernel_fused_experts_with_bias(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w1_pcg,
+    b1: torch.Tensor,
+    w2: torch.Tensor,
+    w2_pcg,
+    b2: torch.Tensor,
+    routing_data: RoutingData,
+    gather_indx: GatherIndx,
+    scatter_indx: ScatterIndx,
+    inplace: bool = False,
+    activation: str = "silu",
+    use_fp8_w8a8: bool = False,
+    per_channel_quant: bool = False,
+    global_num_experts: int = -1,
+    expert_map: Optional[torch.Tensor] = None,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[list[int]] = None,
+    activation_alpha: Optional[float] = None,
+    swiglu_limit: Optional[int] = None,
+) -> torch.Tensor:
+    # print(f"here in triton moe with bias", b1.shape, b1.dtype, b2.shape, b2.dtype)
+    assert use_fp8_w8a8 == False, "use_fp8_w8a8 is not supported"
+    assert per_channel_quant == False, "per_channel_quant is not supported"
+    assert expert_map == None, "expert_map is not supported"
+    assert w1_scale == None, "w1_scale is not supported"
+    assert w2_scale == None, "w2_scale is not supported"
+    assert a1_scale == None, "a1_scale is not supported"
+    assert a2_scale == None, "a2_scale is not supported"
+    assert block_shape == None, "block_shape is not supported"
+
+    # type check
+    assert hidden_states.dtype == torch.bfloat16, "hidden_states must be bfloat16"
+    for w in (w1, w2):
+        # TODO assert bf16 or mxfp4
+        # assert (w.dtype == torch.bfloat16) or check-is-mxfp4, f"w must be bfloat16 or mxfp4 {w1.dtype=}"
+        pass
+
+    # Shape check
+    assert hidden_states.ndim == 2, "hidden_states must be 2D"
+    assert (
+        hidden_states.shape[-1] == w1.shape[-2]
+    ), f"hidden_states shape[-1] {hidden_states.shape} must be equal to w1 shape[-2] {w1.shape}"
+    assert (
+        w2.shape[-1] == w1.shape[1]
+    ), f"w2 shape[-1] {w2.shape[-1]} must be equal to w1 shape[1] {w1.shape[1]}"
+
+    # feature check
+    assert inplace == False, "Inplace is not supported in new triton MoE kernel"
+
+    E, _, _ = w1.shape
+
+    if global_num_experts == -1:
+        global_num_experts = E
+
+    # TODO maybe completely remove this branch
+    if w1.dtype == torch.bfloat16:
+        device = "cuda"
+        optg = dict()
+        w1, w1_flex = quantize(w1, "bf16", device, **optg)
+        w1_pcg = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w1_flex))
+
+        w2, w2_flex = quantize(w2, "bf16", device, **optg)
+        w2_pcg = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w2_flex))
+
+    act = FusedActivation(
+        FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")),
+        (activation_alpha, swiglu_limit),
+        2,
+    )
+
+    intermediate_cache = matmul_ogs(
+        hidden_states,
+        w1,
+        b1,
+        routing_data,
+        gather_indx=gather_indx,
+        precision_config=w1_pcg,
+        gammas=None,
+        fused_activation=act,
+    )
+
+    return matmul_ogs(
+        intermediate_cache,
+        w2,
+        b2,
+        routing_data,
+        scatter_indx=scatter_indx,
+        precision_config=w2_pcg,
+        gammas=routing_data.gate_scal,
+    )
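For orientation, the new with-bias path chains two matmul_ogs calls: the first applies w1/b1 with a fused SwiGLU activation, the second applies w2/b2 and scales the scattered outputs by the routing weights (gammas). Below is a rough, dependency-free sketch of the per-expert math for tokens already routed to a single expert; it assumes w1 packs the gate and up projections along its output dimension and omits the alpha/limit handling that swiglu_fn performs, so it is an illustration rather than the kernel's exact behavior.

import torch
import torch.nn.functional as F

def swiglu_expert_reference(x, w1, b1, w2, b2):
    # First projection: packed gate/up matmul plus bias (first matmul_ogs call).
    h = x @ w1 + b1                # [tokens, 2 * intermediate]
    gate, up = h.chunk(2, dim=-1)  # assumed packing of the two projections
    act = F.silu(gate) * up        # SwiGLU (swiglu_fn additionally applies alpha/limit)
    # Second projection: down matmul plus bias (second matmul_ogs call).
    return act @ w2 + b2

# Hypothetical shapes, just to exercise the sketch.
x = torch.randn(4, 8)
w1, b1 = torch.randn(8, 32), torch.randn(32)
w2, b2 = torch.randn(16, 8), torch.randn(8)
print(swiglu_expert_reference(x, w1, b1, w2, b2).shape)  # torch.Size([4, 8])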
sglang/srt/layers/moe/token_dispatcher/__init__.py
@@ -0,0 +1,23 @@
+from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
+    BaseDispatcher,
+    BaseDispatcherConfig,
+    DispatchOutput,
+    DispatchOutputFormat,
+)
+from sglang.srt.layers.moe.token_dispatcher.deepep import (
+    DeepEPConfig,
+    DeepEPDispatcher,
+    DeepEPLLOutput,
+    DeepEPNormalOutput,
+)
+
+__all__ = [
+    "BaseDispatcher",
+    "BaseDispatcherConfig",
+    "DispatchOutput",
+    "DispatchOutputFormat",
+    "DeepEPConfig",
+    "DeepEPDispatcher",
+    "DeepEPNormalOutput",
+    "DeepEPLLOutput",
+]
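With this new package __init__, callers can import the dispatcher types from the package root instead of reaching into the deepep module; an illustrative import:

from sglang.srt.layers.moe.token_dispatcher import DeepEPDispatcher, DispatchOutputFormat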
sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py
@@ -2,11 +2,22 @@ from __future__ import annotations
 
 from abc import ABC, abstractmethod
 from enum import Enum, auto
-from typing import TYPE_CHECKING, NamedTuple, Protocol, runtime_checkable
+from typing import Protocol, runtime_checkable
 
 import torch
 
 
+class MoEA2ABackend(Enum):
+    none = "none"
+    deepep = "deepep"
+
+    def is_none(self):
+        return self == MoEA2ABackend.none
+
+    def is_deepep(self):
+        return self == MoEA2ABackend.deepep
+
+
 class DispatchOutputFormat(Enum):
     standard = auto()
     deepep_normal = auto()
sglang/srt/layers/moe/{ep_moe/token_dispatcher.py → token_dispatcher/deepep.py}
@@ -1,5 +1,3 @@
-# TODO(ch-wan): this file will be moved to sglang/srt/layers/moe/token_dispatcher/deepep.py
-
 from __future__ import annotations
 
 import logging
@@ -22,15 +20,10 @@ from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
     DispatchOutput,
     DispatchOutputFormat,
 )
+from sglang.srt.layers.moe.utils import DeepEPMode
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.utils import (
-    DeepEPMode,
-    get_bool_env_var,
-    get_int_env_var,
-    is_hip,
-    load_json_config,
-)
+from sglang.srt.utils import get_bool_env_var, get_int_env_var, is_hip, load_json_config
 
 try:
     from deep_ep import Buffer, Config
@@ -150,9 +143,9 @@ class DeepEPBuffer:
                 num_rdma_bytes,
             )
 
-        if deepep_mode == DeepEPMode.normal:
+        if deepep_mode == DeepEPMode.NORMAL:
             num_qps_per_rank = DeepEPConfig.get_instance().num_sms // 2
-        elif deepep_mode in [DeepEPMode.low_latency, DeepEPMode.auto]:
+        elif deepep_mode in [DeepEPMode.LOW_LATENCY, DeepEPMode.AUTO]:
            num_qps_per_rank = num_experts // group.size()
         else:
            raise NotImplementedError
@@ -161,7 +154,7 @@ class DeepEPBuffer:
            device="cuda"
        ).multi_processor_count
        if (
-            (deepep_mode != DeepEPMode.low_latency)
+            (deepep_mode != DeepEPMode.LOW_LATENCY)
            and not global_server_args_dict["enable_two_batch_overlap"]
            and (DeepEPConfig.get_instance().num_sms < total_num_sms // 2)
        ):
@@ -611,7 +604,7 @@ class DeepEPDispatcher(BaseDispatcher):
         num_local_experts: int = None,
         hidden_size: int = None,
         params_dtype: torch.dtype = None,
-        deepep_mode: DeepEPMode = DeepEPMode.auto,
+        deepep_mode: DeepEPMode = DeepEPMode.AUTO,
         async_finish: bool = False,
         return_recv_hook: bool = False,
     ):
@@ -697,9 +690,9 @@ class DeepEPDispatcher(BaseDispatcher):
         resolved_deepep_mode = self.deepep_mode.resolve(
             forward_batch.is_extend_in_batch
         )
-        if resolved_deepep_mode == DeepEPMode.normal:
+        if resolved_deepep_mode == DeepEPMode.NORMAL:
             return self._normal_dispatcher
-        elif resolved_deepep_mode == DeepEPMode.low_latency:
+        elif resolved_deepep_mode == DeepEPMode.LOW_LATENCY:
             return self._low_latency_dispatcher
         else:
             raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
sglang/srt/layers/moe/topk.py
@@ -185,8 +185,9 @@ class TopK(CustomOp):
         expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
     ) -> TopKOutput:
         if self.use_triton_kernels:
+            # renormalize=True is equivalent to sm_first=False
             routing_data, gather_idx, scatter_idx = routing(
-                router_logits, self.top_k, self.renormalize
+                router_logits, self.top_k, sm_first=not self.renormalize
             )
             return TritonKernelTopKOutput(routing_data, gather_idx, scatter_idx)
         else:
@@ -397,8 +398,12 @@ def grouped_topk_gpu(
         .reshape(num_token, -1)
     )  # [n, e]
     tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
+    # TODO: NPU can't support directly evaluating a comparison for now
     topk_weights, topk_ids = torch.topk(
-        tmp_scores, k=topk, dim=-1, sorted=num_fused_shared_experts > 0
+        tmp_scores,
+        k=topk,
+        dim=-1,
+        sorted=(True if num_fused_shared_experts > 0 else False),
     )
     if num_fused_shared_experts:
         topk_ids[:, -1] = torch.randint(
@@ -488,8 +493,12 @@ def biased_grouped_topk_impl(
     tmp_scores = scores_for_choice.masked_fill(
         ~score_mask.bool(), float("-inf")
     )  # [n, e]
+    # TODO: NPU can't support directly evaluating a comparison for now
     _, topk_ids = torch.topk(
-        tmp_scores, k=topk, dim=-1, sorted=num_fused_shared_experts > 0
+        tmp_scores,
+        k=topk,
+        dim=-1,
+        sorted=(True if num_fused_shared_experts > 0 else False),
     )
     topk_weights = scores.gather(1, topk_ids)
 
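On the sm_first=not self.renormalize change above: assuming sm_first means "apply softmax before selecting the top-k experts", renormalize=True corresponds to selecting on the raw logits first and then softmaxing only over the selected experts, so the kept weights sum to 1. A small illustration of the two orderings:

import torch

logits = torch.tensor([2.0, 1.0, 0.5, -1.0])
k = 2

# sm_first=True: softmax over all experts, then take the top-k probabilities.
weights_sm_first, _ = torch.topk(torch.softmax(logits, dim=-1), k)

# sm_first=False (renormalize=True): take the top-k logits, then softmax over
# just those, so the selected weights sum to 1.
top_logits, _ = torch.topk(logits, k)
weights_renorm = torch.softmax(top_logits, dim=-1)

print(weights_sm_first.sum())  # less than 1
print(weights_renorm.sum())    # exactly 1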
sglang/srt/layers/moe/utils.py
@@ -0,0 +1,59 @@
+import importlib.util
+from enum import Enum
+from functools import lru_cache
+
+from packaging import version as pkg_version
+
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+
+
+@lru_cache(maxsize=1)
+def should_use_flashinfer_trtllm_moe():
+    result = global_server_args_dict["enable_flashinfer_trtllm_moe"] and (
+        not importlib.util.find_spec("flashinfer")
+        or pkg_version.parse(__import__("flashinfer").__version__)
+        >= pkg_version.parse("0.2.9rc1")
+    )
+    return result
+
+
+class MoeA2ABackend(Enum):
+
+    STANDARD = ("standard", "none")
+    DEEPEP = "deepep"
+
+    @classmethod
+    def _missing_(cls, value):
+        if value is None:
+            return cls.STANDARD
+        for member in cls:
+            if value in member.value:
+                return member
+        raise ValueError(f"No {cls.__name__} member for value {value}")
+
+    def is_deepep(self):
+        return self == MoeA2ABackend.DEEPEP
+
+    def is_standard(self):
+        return self == MoeA2ABackend.STANDARD
+
+
+class DeepEPMode(Enum):
+    NORMAL = "normal"
+    LOW_LATENCY = "low_latency"
+    AUTO = "auto"
+
+    def enable_normal(self):
+        return self in [DeepEPMode.NORMAL, DeepEPMode.AUTO]
+
+    def enable_low_latency(self):
+        return self in [DeepEPMode.LOW_LATENCY, DeepEPMode.AUTO]
+
+    def resolve(self, is_extend_in_batch: bool):
+        if self != DeepEPMode.AUTO:
+            return self
+
+        if is_extend_in_batch:
+            return DeepEPMode.NORMAL
+        else:
+            return DeepEPMode.LOW_LATENCY
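A short sketch of how the new enums behave, based only on the definitions above (the import path is the module added here; running it assumes an installed sglang environment):

from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend

# _missing_ maps both None and the legacy "none" string onto STANDARD.
assert MoeA2ABackend(None) is MoeA2ABackend.STANDARD
assert MoeA2ABackend("none") is MoeA2ABackend.STANDARD
assert MoeA2ABackend("deepep").is_deepep()

# AUTO resolves per batch: extend (prefill) batches use NORMAL, otherwise LOW_LATENCY.
assert DeepEPMode.AUTO.resolve(is_extend_in_batch=True) is DeepEPMode.NORMAL
assert DeepEPMode.AUTO.resolve(is_extend_in_batch=False) is DeepEPMode.LOW_LATENCY
assert DeepEPMode.NORMAL.resolve(is_extend_in_batch=False) is DeepEPMode.NORMAL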
sglang/srt/layers/quantization/__init__.py
@@ -47,6 +47,12 @@ from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
 from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
     CompressedTensorsConfig,
 )
+from sglang.srt.utils import is_cuda, is_hip, mxfp_supported
+
+is_mxfp_supported = mxfp_supported()
+if is_mxfp_supported:
+    from sglang.srt.layers.quantization.fp4 import MxFp4Config
+
 from sglang.srt.layers.quantization.fp8 import Fp8Config
 from sglang.srt.layers.quantization.gptq import (
     GPTQConfig,
@@ -60,6 +66,7 @@ from sglang.srt.layers.quantization.modelopt_quant import (
     ModelOptFp8Config,
 )
 from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
+from sglang.srt.layers.quantization.mxfp4 import Mxfp4Config
 from sglang.srt.layers.quantization.petit import PetitNvFp4Config
 from sglang.srt.layers.quantization.qoq import QoQConfig
 from sglang.srt.layers.quantization.utils import get_linear_quant_method
@@ -85,6 +92,21 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "petit_nvfp4": PetitNvFp4Config,
 }
 
+
+if is_cuda():
+    BASE_QUANTIZATION_METHODS.update(
+        {
+            "quark": Mxfp4Config,
+            "mxfp4": Mxfp4Config,
+        }
+    )
+elif is_mxfp_supported and is_hip():
+    BASE_QUANTIZATION_METHODS.update(
+        {
+            "quark": MxFp4Config,
+            "mxfp4": MxFp4Config,
+        }
+    )
 # VLLM-dependent quantization methods
 VLLM_QUANTIZATION_METHODS = {
     "aqlm": AQLMConfig,
sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -23,6 +23,7 @@ from sglang.srt.layers.quantization.utils import (
 from sglang.srt.utils import is_cpu, is_cuda, is_hip, is_npu, set_weight_attrs
 
 if TYPE_CHECKING:
+    from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
     from sglang.srt.layers.moe.topk import TopKOutput
     from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
         CompressedTensorsConfig,
@@ -189,7 +190,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         layer.w13_input_scale = None
         layer.w2_input_scale = None
 
-    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+    def process_weights_after_loading(self, layer: FusedMoE) -> None:
         # Fp8 moe kernels require a single activation scale.
         # We take the max of all the scales in case they differ.
         if self.static_input_scales:
@@ -246,7 +247,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             assert layer.w13_weight_scale is not None
             shard_size = layer.intermediate_size_per_partition
             max_w13_scales = layer.w13_weight_scale.max(dim=1).values
-            for expert_id in range(layer.local_num_experts):
+            for expert_id in range(layer.num_local_experts):
                 start = 0
                 for shard_id in range(2):
                     dq_weight = per_tensor_dequantize(
sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py
@@ -148,7 +148,7 @@ def _compile_grouped_gemm_nt_f8f8bf16_masked_one(
         "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
         "N": n,
         "K": k,
-        "NUM_GROUPS": 1,
+        "NUM_GROUPS": num_groups,
         "BLOCK_M": block_m,
         "BLOCK_N": block_n,
         "BLOCK_K": block_k,