sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175)
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +119 -17
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +42 -7
  6. sglang/srt/conversation.py +9 -5
  7. sglang/srt/disaggregation/base/conn.py +5 -2
  8. sglang/srt/disaggregation/decode.py +14 -4
  9. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
  10. sglang/srt/disaggregation/mooncake/conn.py +286 -160
  11. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  12. sglang/srt/disaggregation/prefill.py +2 -0
  13. sglang/srt/distributed/parallel_state.py +15 -11
  14. sglang/srt/entrypoints/context.py +227 -0
  15. sglang/srt/entrypoints/engine.py +15 -9
  16. sglang/srt/entrypoints/harmony_utils.py +372 -0
  17. sglang/srt/entrypoints/http_server.py +74 -4
  18. sglang/srt/entrypoints/openai/protocol.py +218 -1
  19. sglang/srt/entrypoints/openai/serving_chat.py +41 -11
  20. sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
  21. sglang/srt/entrypoints/openai/tool_server.py +175 -0
  22. sglang/srt/entrypoints/tool.py +87 -0
  23. sglang/srt/eplb/expert_location.py +5 -1
  24. sglang/srt/function_call/ebnf_composer.py +1 -0
  25. sglang/srt/function_call/function_call_parser.py +2 -0
  26. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  27. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  28. sglang/srt/function_call/kimik2_detector.py +3 -3
  29. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  30. sglang/srt/hf_transformers_utils.py +30 -3
  31. sglang/srt/jinja_template_utils.py +14 -1
  32. sglang/srt/layers/attention/aiter_backend.py +375 -115
  33. sglang/srt/layers/attention/ascend_backend.py +3 -0
  34. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
  35. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  36. sglang/srt/layers/attention/flashinfer_backend.py +52 -13
  37. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  38. sglang/srt/layers/attention/triton_backend.py +85 -14
  39. sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
  40. sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
  41. sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
  42. sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
  43. sglang/srt/layers/attention/vision.py +22 -6
  44. sglang/srt/layers/attention/wave_backend.py +627 -0
  45. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  46. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  47. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  48. sglang/srt/layers/communicator.py +29 -14
  49. sglang/srt/layers/dp_attention.py +12 -0
  50. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  51. sglang/srt/layers/linear.py +3 -7
  52. sglang/srt/layers/moe/cutlass_moe.py +12 -3
  53. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  54. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  55. sglang/srt/layers/moe/ep_moe/layer.py +135 -73
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
  59. sglang/srt/layers/moe/fused_moe_triton/layer.py +412 -33
  60. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
  61. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  62. sglang/srt/layers/moe/topk.py +16 -4
  63. sglang/srt/layers/moe/utils.py +16 -0
  64. sglang/srt/layers/quantization/__init__.py +27 -3
  65. sglang/srt/layers/quantization/fp4.py +557 -0
  66. sglang/srt/layers/quantization/fp8.py +3 -6
  67. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  68. sglang/srt/layers/quantization/fp8_utils.py +51 -10
  69. sglang/srt/layers/quantization/modelopt_quant.py +258 -68
  70. sglang/srt/layers/quantization/mxfp4.py +654 -0
  71. sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
  72. sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
  73. sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
  74. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
  75. sglang/srt/layers/quantization/quark/utils.py +107 -0
  76. sglang/srt/layers/quantization/unquant.py +60 -6
  77. sglang/srt/layers/quantization/w4afp8.py +21 -12
  78. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  79. sglang/srt/layers/rotary_embedding.py +506 -3
  80. sglang/srt/layers/utils.py +9 -0
  81. sglang/srt/layers/vocab_parallel_embedding.py +8 -3
  82. sglang/srt/lora/backend/base_backend.py +3 -23
  83. sglang/srt/lora/layers.py +60 -114
  84. sglang/srt/lora/lora.py +17 -62
  85. sglang/srt/lora/lora_manager.py +82 -62
  86. sglang/srt/lora/lora_registry.py +23 -11
  87. sglang/srt/lora/mem_pool.py +63 -68
  88. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  89. sglang/srt/lora/utils.py +25 -58
  90. sglang/srt/managers/cache_controller.py +75 -58
  91. sglang/srt/managers/detokenizer_manager.py +1 -1
  92. sglang/srt/managers/io_struct.py +20 -8
  93. sglang/srt/managers/mm_utils.py +6 -13
  94. sglang/srt/managers/multimodal_processor.py +1 -1
  95. sglang/srt/managers/schedule_batch.py +61 -25
  96. sglang/srt/managers/schedule_policy.py +6 -6
  97. sglang/srt/managers/scheduler.py +41 -19
  98. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
  99. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  100. sglang/srt/managers/scheduler_recv_skipper.py +37 -0
  101. sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
  102. sglang/srt/managers/template_manager.py +35 -1
  103. sglang/srt/managers/tokenizer_manager.py +47 -30
  104. sglang/srt/managers/tp_worker.py +3 -0
  105. sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
  106. sglang/srt/mem_cache/allocator.py +61 -87
  107. sglang/srt/mem_cache/hicache_storage.py +1 -1
  108. sglang/srt/mem_cache/hiradix_cache.py +80 -22
  109. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  110. sglang/srt/mem_cache/memory_pool_host.py +34 -36
  111. sglang/srt/mem_cache/multimodal_cache.py +33 -13
  112. sglang/srt/mem_cache/radix_cache.py +2 -5
  113. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
  114. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  115. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  116. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  117. sglang/srt/model_executor/cuda_graph_runner.py +29 -9
  118. sglang/srt/model_executor/forward_batch_info.py +61 -19
  119. sglang/srt/model_executor/model_runner.py +148 -37
  120. sglang/srt/model_loader/loader.py +18 -6
  121. sglang/srt/model_loader/weight_utils.py +10 -0
  122. sglang/srt/models/bailing_moe.py +425 -0
  123. sglang/srt/models/deepseek_v2.py +137 -59
  124. sglang/srt/models/ernie4.py +426 -0
  125. sglang/srt/models/ernie4_eagle.py +203 -0
  126. sglang/srt/models/gemma2.py +0 -34
  127. sglang/srt/models/gemma3n_mm.py +38 -0
  128. sglang/srt/models/glm4.py +6 -0
  129. sglang/srt/models/glm4_moe.py +28 -16
  130. sglang/srt/models/glm4v.py +589 -0
  131. sglang/srt/models/glm4v_moe.py +400 -0
  132. sglang/srt/models/gpt_oss.py +1251 -0
  133. sglang/srt/models/granite.py +0 -25
  134. sglang/srt/models/llama.py +0 -25
  135. sglang/srt/models/llama4.py +1 -1
  136. sglang/srt/models/qwen2.py +6 -0
  137. sglang/srt/models/qwen2_5_vl.py +7 -3
  138. sglang/srt/models/qwen2_audio.py +10 -9
  139. sglang/srt/models/qwen2_moe.py +6 -0
  140. sglang/srt/models/qwen3.py +0 -24
  141. sglang/srt/models/qwen3_moe.py +32 -6
  142. sglang/srt/models/registry.py +1 -1
  143. sglang/srt/models/step3_vl.py +9 -0
  144. sglang/srt/models/torch_native_llama.py +0 -24
  145. sglang/srt/models/transformers.py +2 -5
  146. sglang/srt/multimodal/processors/base_processor.py +23 -13
  147. sglang/srt/multimodal/processors/glm4v.py +132 -0
  148. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  149. sglang/srt/multimodal/processors/step3_vl.py +3 -1
  150. sglang/srt/reasoning_parser.py +332 -37
  151. sglang/srt/server_args.py +186 -75
  152. sglang/srt/speculative/eagle_worker.py +16 -0
  153. sglang/srt/two_batch_overlap.py +169 -9
  154. sglang/srt/utils.py +41 -5
  155. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  156. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  157. sglang/test/doc_patch.py +59 -0
  158. sglang/test/few_shot_gsm8k.py +1 -1
  159. sglang/test/few_shot_gsm8k_engine.py +1 -1
  160. sglang/test/run_eval.py +4 -1
  161. sglang/test/runners.py +2 -2
  162. sglang/test/simple_eval_common.py +6 -0
  163. sglang/test/simple_eval_gpqa.py +2 -0
  164. sglang/test/test_fp4_moe.py +118 -36
  165. sglang/test/test_utils.py +1 -1
  166. sglang/utils.py +1 -1
  167. sglang/version.py +1 -1
  168. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +36 -38
  169. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +174 -141
  170. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  171. /sglang/{api.py → lang/api.py} +0 -0
  172. /sglang/{lang/backend → srt/layers/quantization/quark}/__init__.py +0 -0
  173. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
  174. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
  175. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py

@@ -6,15 +6,50 @@ from typing import TYPE_CHECKING, Optional
 
 import torch
 from sgl_kernel import gelu_and_mul, silu_and_mul
-from triton_kernels.matmul_ogs import matmul_ogs
+from triton_kernels.matmul_ogs import (
+    FlexCtx,
+    FnSpecs,
+    FusedActivation,
+    PrecisionConfig,
+    matmul_ogs,
+)
+from triton_kernels.numerics import InFlexData
 from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx
-
-from sglang.srt.utils import direct_register_custom_op
+from triton_kernels.swiglu import swiglu_fn
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.topk import TopKOutput
 
 
+def quantize(w, dtype, dev, **opt):
+    if dtype == "bf16":
+        return w.to(torch.bfloat16), InFlexData()
+    elif dtype == "fp8":
+        wq = w.to(torch.float8_e4m3fn).transpose(-1, -2).contiguous().transpose(-1, -2)
+        return (
+            wq,
+            InFlexData(dtype=wq.dtype, scale=w.abs().max().unsqueeze(0)),
+            MicroscalingCtx(),
+        )
+    else:
+        assert dtype == "mx4", f"{dtype=}"
+        swizzle_mx_scale = opt["swizzle_mx_scale"]
+        swizzle_axis = 2 if swizzle_mx_scale else None
+        w = w.to(torch.bfloat16)
+        w, mx_scales, weight_scale_shape = downcast_to_mxfp(
+            w, torch.uint8, axis=1, swizzle_axis=swizzle_axis
+        )
+        return (
+            w,
+            InFlexData(),
+            MicroscalingCtx(
+                weight_scale=mx_scales,
+                swizzle_mx=swizzle_mx_scale,
+                actual_weight_scale_shape=weight_scale_shape,
+            ),
+        )
+
+
 def triton_kernel_moe_forward(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
@@ -146,3 +181,153 @@ def triton_kernel_fused_experts(
     )
 
     return intermediate_cache3
+
+
+def triton_kernel_moe_with_bias_forward(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w1_pcg,
+    b1: torch.Tensor,
+    w2: torch.Tensor,
+    w2_pcg,
+    b2: torch.Tensor,
+    topk_output: TopKOutput,
+    inplace: bool = False,
+    activation: str = "silu",
+    use_fp8_w8a8: bool = False,
+    per_channel_quant: bool = False,
+    global_num_experts: int = -1,
+    expert_map: Optional[torch.Tensor] = None,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[list[int]] = None,
+    activation_alpha: Optional[float] = None,
+    swiglu_limit: Optional[int] = None,
+) -> torch.Tensor:
+    assert topk_output.format.is_triton_kernel()
+    routing_data, gather_idx, scatter_idx = topk_output
+
+    return triton_kernel_fused_experts_with_bias(
+        hidden_states,
+        w1=w1,
+        w1_pcg=w1_pcg,
+        b1=b1,
+        w2=w2,
+        w2_pcg=w2_pcg,
+        b2=b2,
+        routing_data=routing_data,
+        gather_indx=gather_idx,
+        scatter_indx=scatter_idx,
+        inplace=inplace,
+        activation=activation,
+        use_fp8_w8a8=use_fp8_w8a8,
+        per_channel_quant=per_channel_quant,
+        global_num_experts=global_num_experts,
+        expert_map=expert_map,
+        w1_scale=w1_scale,
+        w2_scale=w2_scale,
+        a1_scale=a1_scale,
+        a2_scale=a2_scale,
+        block_shape=block_shape,
+        activation_alpha=activation_alpha,
+        swiglu_limit=swiglu_limit,
+    )
+
+
+def triton_kernel_fused_experts_with_bias(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w1_pcg,
+    b1: torch.Tensor,
+    w2: torch.Tensor,
+    w2_pcg,
+    b2: torch.Tensor,
+    routing_data: RoutingData,
+    gather_indx: GatherIndx,
+    scatter_indx: ScatterIndx,
+    inplace: bool = False,
+    activation: str = "silu",
+    use_fp8_w8a8: bool = False,
+    per_channel_quant: bool = False,
+    global_num_experts: int = -1,
+    expert_map: Optional[torch.Tensor] = None,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[list[int]] = None,
+    activation_alpha: Optional[float] = None,
+    swiglu_limit: Optional[int] = None,
+) -> torch.Tensor:
+    # print(f"here in triton moe with bias", b1.shape, b1.dtype, b2.shape, b2.dtype)
+    assert use_fp8_w8a8 == False, "use_fp8_w8a8 is not supported"
+    assert per_channel_quant == False, "per_channel_quant is not supported"
+    assert expert_map == None, "expert_map is not supported"
+    assert w1_scale == None, "w1_scale is not supported"
+    assert w2_scale == None, "w2_scale is not supported"
+    assert a1_scale == None, "a1_scale is not supported"
+    assert a2_scale == None, "a2_scale is not supported"
+    assert block_shape == None, "block_shape is not supported"
+
+    # type check
+    assert hidden_states.dtype == torch.bfloat16, "hidden_states must be bfloat16"
+    for w in (w1, w2):
+        # TODO assert bf16 or mxfp4
+        # assert (w.dtype == torch.bfloat16) or check-is-mxfp4, f"w must be bfloat16 or mxfp4 {w1.dtype=}"
+        pass
+
+    # Shape check
+    assert hidden_states.ndim == 2, "hidden_states must be 2D"
+    assert (
+        hidden_states.shape[-1] == w1.shape[-2]
+    ), f"hidden_states shape[-1] {hidden_states.shape} must be equal to w1 shape[-2] {w1.shape}"
+    assert (
+        w2.shape[-1] == w1.shape[1]
+    ), f"w2 shape[-1] {w2.shape[-1]} must be equal to w1 shape[1] {w1.shape[1]}"
+
+    # feature check
+    assert inplace == False, "Inplace is not supported in new triton MoE kernel"
+
+    E, _, _ = w1.shape
+
+    if global_num_experts == -1:
+        global_num_experts = E
+
+    # TODO maybe completely remove this branch
+    if w1.dtype == torch.bfloat16:
+        device = "cuda"
+        optg = dict()
+        w1, w1_flex = quantize(w1, "bf16", device, **optg)
+        w1_pcg = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w1_flex))
+
+        w2, w2_flex = quantize(w2, "bf16", device, **optg)
+        w2_pcg = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w2_flex))
+
+    act = FusedActivation(
+        FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")),
+        (activation_alpha, swiglu_limit),
+        2,
+    )
+
+    intermediate_cache = matmul_ogs(
+        hidden_states,
+        w1,
+        b1,
+        routing_data,
+        gather_indx=gather_indx,
+        precision_config=w1_pcg,
+        gammas=None,
+        fused_activation=act,
+    )
+
+    return matmul_ogs(
+        intermediate_cache,
+        w2,
+        b2,
+        routing_data,
+        scatter_indx=scatter_indx,
+        precision_config=w2_pcg,
+        gammas=routing_data.gate_scal,
+    )
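For readers skimming the diff, the following is an illustrative pure-PyTorch sketch (not part of the package) of what the new `triton_kernel_fused_experts_with_bias` path computes: each token goes to its top-k experts, through `x @ w1[e] + b1[e]`, a SwiGLU-style activation, then `@ w2[e] + b2[e]`, and the per-expert results are combined with the router weights. The tensor layout (`w1: [E, H, 2I]`, `w2: [E, I, H]`) is inferred from the shape asserts above, and the plain `F.silu(gate) * up` stands in for the fused `swiglu_fn(alpha, limit)` used by the kernel; the gate/up split order is an assumption.

    import torch
    import torch.nn.functional as F

    def moe_with_bias_reference(x, w1, b1, w2, b2, topk_ids, topk_weights):
        # x: [T, H], w1: [E, H, 2I], b1: [E, 2I], w2: [E, I, H], b2: [E, H]
        # topk_ids / topk_weights: [T, K]; loops over experts for clarity, not speed.
        out = torch.zeros_like(x)
        for e in range(w1.shape[0]):
            token_idx, slot_idx = (topk_ids == e).nonzero(as_tuple=True)
            if token_idx.numel() == 0:
                continue
            h = x[token_idx] @ w1[e] + b1[e]  # [n_e, 2I]
            gate, up = h.chunk(2, dim=-1)
            h = F.silu(gate) * up             # stand-in for the fused swiglu_fn
            y = h @ w2[e] + b2[e]             # [n_e, H]
            out.index_add_(0, token_idx, y * topk_weights[token_idx, slot_idx].unsqueeze(-1))
        return out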
sglang/srt/layers/moe/token_dispatcher/deepep.py

@@ -23,14 +23,23 @@ from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
 from sglang.srt.layers.moe.utils import DeepEPMode
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.utils import get_bool_env_var, get_int_env_var, is_hip, load_json_config
+from sglang.srt.utils import (
+    get_bool_env_var,
+    get_int_env_var,
+    is_hip,
+    is_npu,
+    load_json_config,
+)
+
+_is_npu = is_npu()
 
 try:
     from deep_ep import Buffer, Config
 
-    from sglang.srt.layers.quantization.fp8_kernel import (
-        sglang_per_token_group_quant_fp8,
-    )
+    if not _is_npu:
+        from sglang.srt.layers.quantization.fp8_kernel import (
+            sglang_per_token_group_quant_fp8,
+        )
 
     use_deepep = True
 except ImportError:
@@ -80,8 +89,24 @@ class DeepEPLLOutput(NamedTuple):
         return DispatchOutputFormat.deepep_ll
 
 
+class AscendDeepEPLLOutput(NamedTuple):
+    """AscendDeepEP low latency dispatch output."""
+
+    hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor]
+    topk_idx: torch.Tensor
+    topk_weights: torch.Tensor
+    masked_m: torch.Tensor
+    seg_indptr: torch.Tensor
+    expected_m: int
+
+    @property
+    def format(self) -> DispatchOutputFormat:
+        return DispatchOutputFormat.deepep_ll
+
+
 assert isinstance(DeepEPNormalOutput, DispatchOutput)
 assert isinstance(DeepEPLLOutput, DispatchOutput)
+assert isinstance(AscendDeepEPLLOutput, DispatchOutput)
 
 
 class DeepEPDispatchMode(IntEnum):
@@ -150,19 +175,20 @@ class DeepEPBuffer:
         else:
             raise NotImplementedError
 
-        total_num_sms = torch.cuda.get_device_properties(
-            device="cuda"
-        ).multi_processor_count
-        if (
-            (deepep_mode != DeepEPMode.LOW_LATENCY)
-            and not global_server_args_dict["enable_two_batch_overlap"]
-            and (DeepEPConfig.get_instance().num_sms < total_num_sms // 2)
-        ):
-            logger.warning(
-                f"Only use {DeepEPConfig.get_instance().num_sms} SMs for DeepEP communication. "
-                f"This may result in highly suboptimal performance. "
-                f"Consider using --deepep-config to change the behavior."
-            )
+        if not _is_npu:
+            total_num_sms = torch.cuda.get_device_properties(
+                device="cuda"
+            ).multi_processor_count
+            if (
+                (deepep_mode != DeepEPMode.LOW_LATENCY)
+                and not global_server_args_dict["enable_two_batch_overlap"]
+                and (DeepEPConfig.get_instance().num_sms < total_num_sms // 2)
+            ):
+                logger.warning(
+                    f"Only use {DeepEPConfig.get_instance().num_sms} SMs for DeepEP communication. "
+                    f"This may result in highly suboptimal performance. "
+                    f"Consider using --deepep-config to change the behavior."
+                )
 
         cls._buffer = Buffer(
             group,
@@ -507,13 +533,24 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
             masked_m
         )
 
-        return DeepEPLLOutput(
-            hidden_states,
-            topk_idx,
-            topk_weights,
-            masked_m,
-            expected_m,
-        )
+        if _is_npu:
+            deepep_output = AscendDeepEPLLOutput(
+                hidden_states,
+                topk_idx,
+                topk_weights,
+                masked_m,
+                self.handle[1],
+                expected_m,
+            )
+        else:
+            deepep_output = DeepEPLLOutput(
+                hidden_states,
+                topk_idx,
+                topk_weights,
+                masked_m,
+                expected_m,
+            )
+        return deepep_output
 
     def _dispatch_core(
         self,
sglang/srt/layers/moe/topk.py

@@ -185,8 +185,9 @@ class TopK(CustomOp):
         expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
     ) -> TopKOutput:
         if self.use_triton_kernels:
+            # renormalize=True is equivalent to sm_first=False
             routing_data, gather_idx, scatter_idx = routing(
-                router_logits, self.top_k, self.renormalize
+                router_logits, self.top_k, sm_first=not self.renormalize
             )
             return TritonKernelTopKOutput(routing_data, gather_idx, scatter_idx)
         else:
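The inline comment in the hunk above ("renormalize=True is equivalent to sm_first=False") is easy to misread. Under the assumed semantics that `sm_first` decides whether softmax runs before or after top-k selection, the two conventions look like this in plain PyTorch (illustrative only, not the `triton_kernels.routing` implementation):

    import torch

    logits = torch.randn(4, 8)  # [tokens, experts]
    k = 2

    # renormalize=True (assumed to map to sm_first=False): pick top-k on the raw
    # logits, then softmax only over the selected experts.
    vals, ids = logits.topk(k, dim=-1)
    weights_renormalized = torch.softmax(vals, dim=-1)

    # renormalize=False (assumed to map to sm_first=True): softmax over all
    # experts first, then keep the top-k probabilities as-is.
    probs = torch.softmax(logits, dim=-1)
    weights_softmax_first, ids_softmax_first = probs.topk(k, dim=-1)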
@@ -244,10 +245,11 @@ class TopK(CustomOp):
 
         # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
         if global_num_experts == 256:
+            router_logits = router_logits.to(torch.float32)
             return torch_npu.npu_moe_gating_top_k(
                 router_logits,
                 k=self.top_k,
-                bias=self.correction_bias,
+                bias=self.correction_bias.to(torch.float32),
                 k_group=self.topk_group,
                 group_count=self.num_expert_group,
                 group_select_mode=1,
@@ -397,8 +399,12 @@ def grouped_topk_gpu(
         .reshape(num_token, -1)
     )  # [n, e]
     tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
+    # TODO: NPU can't support directly evaluating a comparison for now
     topk_weights, topk_ids = torch.topk(
-        tmp_scores, k=topk, dim=-1, sorted=num_fused_shared_experts > 0
+        tmp_scores,
+        k=topk,
+        dim=-1,
+        sorted=(True if num_fused_shared_experts > 0 else False),
     )
     if num_fused_shared_experts:
         topk_ids[:, -1] = torch.randint(
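As context for the `grouped_topk_gpu` change above, here is a minimal pure-PyTorch sketch of group-limited top-k routing matching the masking pattern in the hunk (illustrative only; the per-group score here is a plain max, and the `sorted=` nuance for fused shared experts is omitted):

    import torch

    def group_limited_topk_reference(scores, topk, num_expert_group, topk_group):
        # scores: [n_tokens, n_experts]; keep only experts from the best
        # `topk_group` groups, then take the top-`topk` experts overall.
        n, e = scores.shape
        group_scores = scores.view(n, num_expert_group, -1).max(dim=-1).values
        group_idx = group_scores.topk(topk_group, dim=-1).indices
        group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1.0)
        score_mask = (
            group_mask.unsqueeze(-1)
            .expand(n, num_expert_group, e // num_expert_group)
            .reshape(n, -1)
        )
        tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)
        return torch.topk(tmp_scores, k=topk, dim=-1)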
@@ -435,7 +441,9 @@ def grouped_topk_cpu(
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+    apply_routed_scaling_factor_on_output: Optional[bool] = False,
 ):
+    assert not apply_routed_scaling_factor_on_output
     assert expert_location_dispatch_info is None
     return torch.ops.sgl_kernel.grouped_topk_cpu(
         hidden_states,
@@ -488,8 +496,12 @@ def biased_grouped_topk_impl(
     tmp_scores = scores_for_choice.masked_fill(
         ~score_mask.bool(), float("-inf")
     )  # [n, e]
+    # TODO: NPU can't support directly evaluating a comparison for now
     _, topk_ids = torch.topk(
-        tmp_scores, k=topk, dim=-1, sorted=num_fused_shared_experts > 0
+        tmp_scores,
+        k=topk,
+        dim=-1,
+        sorted=(True if num_fused_shared_experts > 0 else False),
     )
     topk_weights = scores.gather(1, topk_ids)
 
sglang/srt/layers/moe/utils.py

@@ -1,4 +1,20 @@
+import importlib.util
 from enum import Enum
+from functools import lru_cache
+
+from packaging import version as pkg_version
+
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+
+
+@lru_cache(maxsize=1)
+def should_use_flashinfer_trtllm_moe():
+    result = global_server_args_dict["enable_flashinfer_trtllm_moe"] and (
+        not importlib.util.find_spec("flashinfer")
+        or pkg_version.parse(__import__("flashinfer").__version__)
+        >= pkg_version.parse("0.2.9rc1")
+    )
+    return result
 
 
 class MoeA2ABackend(Enum):
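The new `should_use_flashinfer_trtllm_moe` helper above uses a check-once-and-cache pattern (`functools.lru_cache` plus `packaging.version`). A generic, hedged sketch of the same pattern, with a hypothetical `example_pkg` standing in for the real dependency:

    import importlib.util
    from functools import lru_cache

    from packaging import version as pkg_version


    @lru_cache(maxsize=1)
    def has_recent_example_pkg(min_version: str = "1.0.0") -> bool:
        # Hypothetical helper: True only when `example_pkg` is installed at >= min_version.
        # The result is computed once per process and cached thereafter.
        if importlib.util.find_spec("example_pkg") is None:
            return False
        import example_pkg
        return pkg_version.parse(example_pkg.__version__) >= pkg_version.parse(min_version)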
sglang/srt/layers/quantization/__init__.py

@@ -3,7 +3,7 @@ from __future__ import annotations
 
 import builtins
 import inspect
-from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, Union
+from typing import TYPE_CHECKING, Dict, Optional, Type
 
 import torch
 
@@ -26,8 +26,9 @@ try:
     from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
 
     VLLM_AVAILABLE = True
-except ImportError:
+except ImportError as e:
     VLLM_AVAILABLE = False
+    VLLM_IMPORT_ERROR = e
 
     # Define empty classes as placeholders when vllm is not available
     class DummyConfig:
@@ -47,6 +48,12 @@ from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
 from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
     CompressedTensorsConfig,
 )
+from sglang.srt.utils import is_cuda, is_hip, mxfp_supported
+
+is_mxfp_supported = mxfp_supported()
+if is_mxfp_supported:
+    from sglang.srt.layers.quantization.fp4 import MxFp4Config
+
 from sglang.srt.layers.quantization.fp8 import Fp8Config
 from sglang.srt.layers.quantization.gptq import (
     GPTQConfig,
@@ -60,6 +67,7 @@ from sglang.srt.layers.quantization.modelopt_quant import (
     ModelOptFp8Config,
 )
 from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
+from sglang.srt.layers.quantization.mxfp4 import Mxfp4Config
 from sglang.srt.layers.quantization.petit import PetitNvFp4Config
 from sglang.srt.layers.quantization.qoq import QoQConfig
 from sglang.srt.layers.quantization.utils import get_linear_quant_method
@@ -85,6 +93,21 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "petit_nvfp4": PetitNvFp4Config,
 }
 
+
+if is_cuda():
+    BASE_QUANTIZATION_METHODS.update(
+        {
+            "quark": Mxfp4Config,
+            "mxfp4": Mxfp4Config,
+        }
+    )
+elif is_mxfp_supported and is_hip():
+    BASE_QUANTIZATION_METHODS.update(
+        {
+            "quark": MxFp4Config,
+            "mxfp4": MxFp4Config,
+        }
+    )
 # VLLM-dependent quantization methods
 VLLM_QUANTIZATION_METHODS = {
     "aqlm": AQLMConfig,
@@ -115,7 +138,8 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     if quantization in VLLM_QUANTIZATION_METHODS and not VLLM_AVAILABLE:
         raise ValueError(
             f"{quantization} quantization requires some operators from vllm. "
-            "Please install vllm by `pip install vllm==0.9.0.1`"
+            f"Please install vllm by `pip install vllm==0.9.0.1`\n"
+            f"Import error: {VLLM_IMPORT_ERROR}"
         )
 
     return QUANTIZATION_METHODS[quantization]