sglang 0.5.2rc1__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (265)
  1. sglang/bench_one_batch_server.py +10 -1
  2. sglang/bench_serving.py +257 -29
  3. sglang/lang/interpreter.py +1 -1
  4. sglang/srt/configs/__init__.py +4 -0
  5. sglang/srt/configs/device_config.py +3 -1
  6. sglang/srt/configs/dots_vlm.py +139 -0
  7. sglang/srt/configs/internvl.py +6 -0
  8. sglang/srt/configs/load_config.py +1 -0
  9. sglang/srt/configs/model_config.py +50 -6
  10. sglang/srt/configs/qwen3_next.py +326 -0
  11. sglang/srt/connector/__init__.py +8 -1
  12. sglang/srt/connector/remote_instance.py +82 -0
  13. sglang/srt/constrained/base_grammar_backend.py +48 -12
  14. sglang/srt/constrained/llguidance_backend.py +0 -1
  15. sglang/srt/constrained/outlines_backend.py +0 -1
  16. sglang/srt/constrained/xgrammar_backend.py +28 -9
  17. sglang/srt/custom_op.py +11 -1
  18. sglang/srt/debug_utils/dump_comparator.py +81 -44
  19. sglang/srt/debug_utils/dump_loader.py +97 -0
  20. sglang/srt/debug_utils/dumper.py +11 -3
  21. sglang/srt/debug_utils/text_comparator.py +73 -11
  22. sglang/srt/disaggregation/base/conn.py +1 -1
  23. sglang/srt/disaggregation/common/conn.py +15 -12
  24. sglang/srt/disaggregation/decode.py +21 -10
  25. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
  26. sglang/srt/disaggregation/fake/conn.py +1 -1
  27. sglang/srt/disaggregation/mini_lb.py +6 -445
  28. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  29. sglang/srt/disaggregation/nixl/conn.py +180 -16
  30. sglang/srt/disaggregation/prefill.py +5 -3
  31. sglang/srt/disaggregation/utils.py +5 -50
  32. sglang/srt/distributed/parallel_state.py +67 -43
  33. sglang/srt/entrypoints/engine.py +38 -17
  34. sglang/srt/entrypoints/grpc_request_manager.py +580 -0
  35. sglang/srt/entrypoints/grpc_server.py +680 -0
  36. sglang/srt/entrypoints/http_server.py +88 -53
  37. sglang/srt/entrypoints/openai/protocol.py +7 -4
  38. sglang/srt/entrypoints/openai/serving_base.py +46 -3
  39. sglang/srt/entrypoints/openai/serving_chat.py +39 -19
  40. sglang/srt/entrypoints/openai/serving_completions.py +15 -4
  41. sglang/srt/entrypoints/openai/serving_embedding.py +9 -4
  42. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  43. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  44. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  45. sglang/srt/eplb/eplb_manager.py +2 -2
  46. sglang/srt/eplb/expert_distribution.py +26 -13
  47. sglang/srt/eplb/expert_location.py +8 -3
  48. sglang/srt/eplb/expert_location_updater.py +1 -1
  49. sglang/srt/function_call/base_format_detector.py +3 -6
  50. sglang/srt/function_call/ebnf_composer.py +11 -9
  51. sglang/srt/function_call/function_call_parser.py +6 -0
  52. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  53. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  54. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  55. sglang/srt/grpc/__init__.py +1 -0
  56. sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
  57. sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
  58. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
  59. sglang/srt/hf_transformers_utils.py +4 -0
  60. sglang/srt/layers/activation.py +142 -9
  61. sglang/srt/layers/attention/aiter_backend.py +93 -68
  62. sglang/srt/layers/attention/ascend_backend.py +11 -4
  63. sglang/srt/layers/attention/fla/chunk.py +242 -0
  64. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  65. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  66. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  67. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  68. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  69. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  70. sglang/srt/layers/attention/fla/index.py +37 -0
  71. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  72. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  73. sglang/srt/layers/attention/fla/op.py +66 -0
  74. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  75. sglang/srt/layers/attention/fla/utils.py +331 -0
  76. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  77. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  78. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  79. sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
  80. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  81. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  82. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  83. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  84. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  85. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  86. sglang/srt/layers/attention/triton_backend.py +18 -1
  87. sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
  88. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  89. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  90. sglang/srt/layers/communicator.py +45 -7
  91. sglang/srt/layers/dp_attention.py +30 -1
  92. sglang/srt/layers/layernorm.py +32 -15
  93. sglang/srt/layers/linear.py +34 -3
  94. sglang/srt/layers/logits_processor.py +29 -10
  95. sglang/srt/layers/moe/__init__.py +2 -1
  96. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  97. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  98. sglang/srt/layers/moe/ep_moe/layer.py +182 -62
  99. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
  100. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  101. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  102. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  103. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  104. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
  105. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  106. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  107. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  108. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  109. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  110. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  111. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  112. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  113. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
  114. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  115. sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
  116. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  117. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  118. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  119. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  120. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  121. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  122. sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
  123. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  124. sglang/srt/layers/moe/topk.py +30 -9
  125. sglang/srt/layers/moe/utils.py +12 -7
  126. sglang/srt/layers/quantization/awq.py +19 -7
  127. sglang/srt/layers/quantization/base_config.py +11 -6
  128. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  129. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  130. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  131. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  132. sglang/srt/layers/quantization/fp8.py +76 -47
  133. sglang/srt/layers/quantization/fp8_utils.py +50 -31
  134. sglang/srt/layers/quantization/gptq.py +25 -17
  135. sglang/srt/layers/quantization/modelopt_quant.py +182 -49
  136. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  137. sglang/srt/layers/quantization/mxfp4.py +68 -41
  138. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  139. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  140. sglang/srt/layers/quantization/quark/utils.py +97 -0
  141. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  142. sglang/srt/layers/quantization/unquant.py +135 -47
  143. sglang/srt/layers/quantization/w4afp8.py +30 -17
  144. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  145. sglang/srt/layers/quantization/w8a8_int8.py +76 -38
  146. sglang/srt/layers/rocm_linear_utils.py +44 -0
  147. sglang/srt/layers/rotary_embedding.py +0 -18
  148. sglang/srt/layers/sampler.py +162 -18
  149. sglang/srt/lora/backend/base_backend.py +50 -8
  150. sglang/srt/lora/backend/triton_backend.py +90 -2
  151. sglang/srt/lora/layers.py +32 -0
  152. sglang/srt/lora/lora.py +4 -1
  153. sglang/srt/lora/lora_manager.py +35 -112
  154. sglang/srt/lora/mem_pool.py +24 -10
  155. sglang/srt/lora/utils.py +18 -9
  156. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  157. sglang/srt/managers/cache_controller.py +200 -199
  158. sglang/srt/managers/data_parallel_controller.py +105 -35
  159. sglang/srt/managers/detokenizer_manager.py +8 -4
  160. sglang/srt/managers/disagg_service.py +46 -0
  161. sglang/srt/managers/io_struct.py +199 -12
  162. sglang/srt/managers/mm_utils.py +1 -0
  163. sglang/srt/managers/multi_tokenizer_mixin.py +351 -397
  164. sglang/srt/managers/schedule_batch.py +77 -56
  165. sglang/srt/managers/schedule_policy.py +4 -3
  166. sglang/srt/managers/scheduler.py +191 -139
  167. sglang/srt/managers/scheduler_metrics_mixin.py +116 -9
  168. sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
  169. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  170. sglang/srt/managers/template_manager.py +3 -3
  171. sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
  172. sglang/srt/managers/tokenizer_manager.py +260 -519
  173. sglang/srt/managers/tp_worker.py +53 -4
  174. sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
  175. sglang/srt/mem_cache/allocator.py +1 -1
  176. sglang/srt/mem_cache/hicache_storage.py +18 -33
  177. sglang/srt/mem_cache/hiradix_cache.py +108 -48
  178. sglang/srt/mem_cache/memory_pool.py +347 -48
  179. sglang/srt/mem_cache/memory_pool_host.py +121 -57
  180. sglang/srt/mem_cache/radix_cache.py +0 -2
  181. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  182. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  183. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +95 -5
  184. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  185. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  186. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +81 -20
  187. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  188. sglang/srt/mem_cache/swa_radix_cache.py +0 -2
  189. sglang/srt/metrics/collector.py +502 -77
  190. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  191. sglang/srt/metrics/utils.py +48 -0
  192. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  193. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  194. sglang/srt/model_executor/forward_batch_info.py +75 -19
  195. sglang/srt/model_executor/model_runner.py +357 -30
  196. sglang/srt/model_loader/__init__.py +9 -3
  197. sglang/srt/model_loader/loader.py +128 -4
  198. sglang/srt/model_loader/weight_utils.py +2 -1
  199. sglang/srt/models/apertus.py +686 -0
  200. sglang/srt/models/bailing_moe.py +798 -218
  201. sglang/srt/models/bailing_moe_nextn.py +168 -0
  202. sglang/srt/models/deepseek_v2.py +346 -48
  203. sglang/srt/models/dots_vlm.py +174 -0
  204. sglang/srt/models/dots_vlm_vit.py +337 -0
  205. sglang/srt/models/ernie4.py +1 -1
  206. sglang/srt/models/gemma3n_mm.py +1 -1
  207. sglang/srt/models/glm4_moe.py +11 -2
  208. sglang/srt/models/glm4v.py +4 -2
  209. sglang/srt/models/glm4v_moe.py +3 -0
  210. sglang/srt/models/gpt_oss.py +1 -1
  211. sglang/srt/models/internvl.py +28 -0
  212. sglang/srt/models/llama4.py +9 -0
  213. sglang/srt/models/llama_eagle3.py +13 -0
  214. sglang/srt/models/longcat_flash.py +2 -2
  215. sglang/srt/models/minicpmv.py +165 -3
  216. sglang/srt/models/mllama4.py +25 -0
  217. sglang/srt/models/opt.py +637 -0
  218. sglang/srt/models/qwen2.py +7 -0
  219. sglang/srt/models/qwen2_5_vl.py +27 -3
  220. sglang/srt/models/qwen2_moe.py +60 -13
  221. sglang/srt/models/qwen3.py +8 -2
  222. sglang/srt/models/qwen3_moe.py +40 -9
  223. sglang/srt/models/qwen3_next.py +1042 -0
  224. sglang/srt/models/qwen3_next_mtp.py +112 -0
  225. sglang/srt/models/step3_vl.py +1 -1
  226. sglang/srt/models/torch_native_llama.py +1 -1
  227. sglang/srt/multimodal/processors/dots_vlm.py +99 -0
  228. sglang/srt/multimodal/processors/glm4v.py +9 -9
  229. sglang/srt/multimodal/processors/internvl.py +141 -129
  230. sglang/srt/multimodal/processors/qwen_vl.py +15 -5
  231. sglang/srt/offloader.py +27 -3
  232. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  233. sglang/srt/remote_instance_weight_loader_utils.py +69 -0
  234. sglang/srt/sampling/sampling_batch_info.py +18 -15
  235. sglang/srt/server_args.py +355 -37
  236. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  237. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  238. sglang/srt/speculative/eagle_utils.py +0 -2
  239. sglang/srt/speculative/eagle_worker.py +197 -112
  240. sglang/srt/speculative/spec_info.py +5 -0
  241. sglang/srt/speculative/standalone_worker.py +109 -0
  242. sglang/srt/tracing/trace.py +552 -0
  243. sglang/srt/utils.py +46 -3
  244. sglang/srt/weight_sync/utils.py +1 -1
  245. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  246. sglang/test/few_shot_gsm8k.py +1 -0
  247. sglang/test/runners.py +4 -0
  248. sglang/test/test_cutlass_moe.py +24 -6
  249. sglang/test/test_disaggregation_utils.py +66 -0
  250. sglang/test/test_fp4_moe.py +370 -1
  251. sglang/test/test_utils.py +28 -1
  252. sglang/utils.py +12 -0
  253. sglang/version.py +1 -1
  254. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
  255. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +263 -200
  256. sglang/srt/disaggregation/launch_lb.py +0 -118
  257. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  258. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  259. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  260. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  261. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  262. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  263. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
  264. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
  265. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
sglang/test/test_fp4_moe.py CHANGED
@@ -3,12 +3,15 @@ from typing import Callable
 
 import pytest
 import torch
+from flashinfer import fp4_quantize
 from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe
-from sgl_kernel import scaled_fp4_quant
+from sgl_kernel import scaled_fp4_grouped_quant, scaled_fp4_quant
+from torch.nn import functional as F
 
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
 from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
+from sglang.srt.layers.moe.flashinfer_cutedsl_moe import flashinfer_cutedsl_moe_masked
 from sglang.srt.layers.moe.topk import TopKConfig, select_experts
 
 if torch.cuda.get_device_capability() < (10, 0):
@@ -78,6 +81,37 @@ def break_fp4_bytes(a, dtype):
     return values.reshape(m, n * 2).to(dtype=dtype)
 
 
+def compute_routing(router_logits: torch.Tensor, top_k: int):
+    routing_weights = torch.softmax(router_logits, dim=1, dtype=torch.float)
+    routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
+    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+    routing_weights = routing_weights.float()
+    return routing_weights, selected_experts
+
+
+def prepare_inputs(
+    hidden_states: torch.Tensor,
+    router_logits: torch.Tensor,
+    num_experts: int,
+    topk: int,
+):
+    routing_weights, topk_idx = compute_routing(router_logits, topk)
+
+    masked_m = []
+    for i in range(num_experts):
+        mask = topk_idx.view(-1) == i
+        masked_m.append(mask.sum())
+
+    masked_m = torch.tensor(masked_m, dtype=torch.int32)
+    hidden_states_3d = torch.empty(
+        (num_experts, max(masked_m), hidden_states.shape[1]), dtype=hidden_states.dtype
+    )
+    for i in range(num_experts):
+        hidden_states_3d[i, : masked_m[i], :] = hidden_states[topk_idx.view(-1) == i]
+
+    return hidden_states_3d, masked_m, topk_idx, routing_weights
+
+
 MNK_FACTORS = [
     (2, 1024, 1024),
     (2, 1024, 1536),
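
Note: to illustrate the packing contract these two helpers establish (and which the masked tests later in this file rely on), here is a small CPU-only sketch with toy sizes; the numbers are illustrative and not part of the diff.

import torch

# 4 tokens, 3 experts, top-1 routing (toy example).
router_logits = torch.tensor(
    [[2.0, 0.1, 0.1], [0.1, 2.0, 0.1], [2.0, 0.1, 0.1], [0.1, 0.1, 2.0]]
)
topk_idx = router_logits.topk(1, dim=-1).indices  # -> [[0], [1], [0], [2]]

# masked_m[e] = number of tokens routed to expert e -> [2, 1, 1]
masked_m = torch.tensor(
    [(topk_idx.view(-1) == e).sum() for e in range(3)], dtype=torch.int32
)

# prepare_inputs then packs tokens into [num_experts, max(masked_m), hidden];
# rows beyond masked_m[e] stay uninitialized, and the masked kernels only
# touch the first masked_m[e] rows per expert (hence the [: masked_m[i]]
# slicing in the assertions further down).
print(masked_m)  # tensor([2, 1, 1], dtype=torch.int32)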
@@ -114,6 +148,99 @@ def torch_moe(a, w1, w2, score, topk, expert_map):
     ).sum(dim=1)
 
 
+def torch_moe_nvfp4(a, w1, w2, topk, topk_weight, topk_ids):
+    B, D = a.shape
+    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
+    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
+
+    topk_weight = topk_weight.view(-1)
+    topk_ids = topk_ids.view(-1)
+
+    for i in range(w1.shape[0]):
+        mask = topk_ids == i
+        if mask.sum():
+            m = w1[i].shape[0]
+            assert m % 2 == 0
+            # Note: w1 and w3 are swapped!
+            w3_expert, w1_expert = w1[i][m // 2 :, :], w1[i][: m // 2, :]
+            inter = F.silu(a[mask] @ w1_expert.t()) * (a[mask] @ w3_expert.t())
+            inter_gs = torch.tensor(1.0).cuda()
+            inter_q, inter_blockscale = fp4_quantize(inter, inter_gs)
+            inter = dequantize_nvfp4_to_dtype(
+                inter_q,
+                inter_blockscale,
+                inter_gs,
+                dtype=inter.dtype,
+                device=inter.device,
+                block_size=16,
+            ).cuda()
+            out[mask] = inter @ w2[i].transpose(0, 1)
+    return (
+        out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
+    ).sum(dim=1)
+
+
+def flashinfer_cutedsl_grouped_gemm_nt_masked(
+    hidden_states: torch.Tensor,  # 3d
+    input_global_scale: torch.Tensor,  # (l,)
+    weights: torch.Tensor,
+    w_global_scale: torch.Tensor,  # (l,)
+    masked_m: torch.Tensor,
+):
+    from flashinfer.cute_dsl.blockscaled_gemm import grouped_gemm_nt_masked
+
+    # hidden_states: [l, m, k]
+    # weights: [l, n, k]
+    aq, aq_sf = scaled_fp4_grouped_quant(
+        hidden_states,
+        input_global_scale,
+        masked_m.to(hidden_states.device),
+    )
+    num_experts, n, k = weights.shape
+    bq, bq_sf = scaled_fp4_grouped_quant(
+        weights,
+        w_global_scale,
+        torch.ones(num_experts, device=weights.device, dtype=torch.int32) * n,
+    )
+
+    out = torch.zeros(
+        (num_experts, max(masked_m), n), dtype=weights.dtype, device=aq.device
+    )
+    out = out.permute(1, 2, 0)  # requirement of kernel
+    sf_vec_size = 16
+    ab_dtype = "float4_e2m1fn"
+    sf_dtype = "float8_e4m3fn"
+    c_dtype = "bfloat16"
+    alpha = 1.0 / (input_global_scale * w_global_scale).to(out.dtype).view(
+        1, 1, num_experts
+    )
+
+    def get_cute_dtype(input: torch.Tensor) -> str:
+        if input.dtype == torch.bfloat16:
+            return "bfloat16"
+        elif input.dtype == torch.float16:
+            return "float16"
+        elif input.dtype == torch.float32:
+            return "float32"
+        else:
+            raise ValueError(f"Unsupported cute dtype {input.dtype}")
+
+    grouped_gemm_nt_masked(
+        (aq, aq_sf),
+        (bq, bq_sf),
+        out,
+        masked_m.to(aq.device),
+        ab_dtype=ab_dtype,
+        sf_dtype=sf_dtype,
+        c_dtype=c_dtype,
+        sf_vec_size=sf_vec_size,
+        alpha=alpha,
+        alpha_dtype=get_cute_dtype(alpha),
+    )
+
+    return out
+
+
 def check_moe(
     m: int,
     n: int,
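
Note: as background for the FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / amax expressions used by these helpers and the tests below — NVFP4 stores fp4 (e2m1) elements with per-16-element fp8 (e4m3) block scales, and the global scale is chosen so the tensor amax lands on the largest representable value. A short sketch of that arithmetic, assuming the two constants are the formats' max finite values (448.0 and 6.0); the module defines them outside the hunks shown here.

import torch

FLOAT8_E4M3_MAX = 448.0  # max finite e4m3 value (block scales)
FLOAT4_E2M1_MAX = 6.0    # max finite e2m1 value (fp4 elements)

w = torch.randn(256, 128)
w_amax = w.abs().amax().float()

# Quantization global scale: w_amax * w_global_scale == 448 * 6, the largest
# product of an fp4 element and its fp8 block scale.
w_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w_amax

# Dequantization uses the reciprocal, which is why the kernels above take
# alpha = 1.0 / (input_global_scale * w_global_scale).
alpha = 1.0 / (1.0 * w_global_scale)  # input_global_scale of 1.0 for brevity
print(w_global_scale.item(), alpha.item())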
@@ -324,6 +451,248 @@ def test_flashinfer_fp4_moe_no_graph(
     check_moe(m, n, k, e, topk, dtype, flashinfer_moe_impl, flip_w13=True)
 
 
+@pytest.mark.parametrize("bs, hidden_dim, inter_dim", [(2, 128, 256), (16, 128, 512)])
+@pytest.mark.parametrize("topk", [1, 2, 4])
+@torch.inference_mode()
+def test_flashinfer_cutedsl_moe_masked(
+    bs: int, hidden_dim: int, inter_dim: int, topk: int
+):
+    torch.manual_seed(42)
+    device = "cuda"
+    dtype = torch.bfloat16
+    num_experts = 8
+    hidden_states = (
+        torch.randn(bs, hidden_dim, dtype=torch.bfloat16, device=device) / 5.0
+    )
+    w1 = (
+        torch.randn(
+            num_experts, 2 * inter_dim, hidden_dim, dtype=torch.bfloat16, device=device
+        )
+        / 10.0
+    )
+    w2 = (
+        torch.randn(
+            num_experts, hidden_dim, inter_dim, dtype=torch.bfloat16, device=device
+        )
+        / 10.0
+    )
+    router_logits = torch.randn(bs, num_experts, dtype=torch.float32)
+
+    hidden_states_expanded = (
+        hidden_states.view(bs, -1, hidden_dim)
+        .repeat(1, topk, 1)
+        .reshape(-1, hidden_dim)
+    )
+    hidden_states_3d, masked_m, topk_idx, routing_weights = prepare_inputs(
+        hidden_states_expanded, router_logits, num_experts, topk
+    )
+
+    w1_amax = w1.abs().amax(dim=(1, 2)).to(torch.float32).to(w1.device)
+    w2_amax = w2.abs().amax(dim=(1, 2)).to(torch.float32).to(w2.device)
+    input_global_scale = torch.ones(
+        (num_experts,), dtype=torch.float32, device=hidden_states.device
+    )
+
+    w1_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax
+    w2_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax
+    a2_global_scale = torch.ones(
+        (num_experts,), dtype=torch.float32, device=hidden_states.device
+    )  # assume intermediate scale is 1.0
+
+    w1_fp4, w1_blockscale = scaled_fp4_grouped_quant(
+        w1,
+        w1_global_scale,
+        torch.ones(num_experts, dtype=torch.int32, device=w1.device) * 2 * inter_dim,
+    )
+    w2_fp4, w2_blockscale = scaled_fp4_grouped_quant(
+        w2,
+        w2_global_scale,
+        torch.ones(num_experts, dtype=torch.int32, device=w2.device) * hidden_dim,
+    )
+
+    w1_alpha = 1.0 / (input_global_scale * w1_global_scale)
+    w2_alpha = 1.0 / (a2_global_scale * w2_global_scale)
+
+    out = flashinfer_cutedsl_moe_masked(
+        hidden_states_3d.to(hidden_states.device),
+        input_global_scale,
+        w1_fp4.permute(2, 0, 1),
+        w1_blockscale,
+        w1_alpha,
+        w2_fp4.permute(2, 0, 1),
+        a2_global_scale,
+        w2_blockscale,
+        w2_alpha,
+        masked_m.to(hidden_states.device),
+    )
+
+    # reference
+    a_fp4, a_scale_interleaved = fp4_quantize(hidden_states, input_global_scale)
+    a_in_dtype = dequantize_nvfp4_to_dtype(
+        a_fp4,
+        a_scale_interleaved,
+        input_global_scale,
+        dtype=hidden_states.dtype,
+        device=hidden_states.device,
+        block_size=16,
+    )
+    w1_d = torch.empty(
+        (num_experts, 2 * inter_dim, hidden_dim), device=w1.device, dtype=w1.dtype
+    )
+    w2_d = torch.empty(
+        (num_experts, hidden_dim, inter_dim), device=w2.device, dtype=w2.dtype
+    )
+
+    for idx in range(0, num_experts):
+        w1_fp4_sliced, w1_blockscale_sliced = fp4_quantize(
+            w1[idx], w1_global_scale[idx]
+        )
+        w2_fp4_sliced, w2_blockscale_sliced = fp4_quantize(
+            w2[idx], w2_global_scale[idx]
+        )
+        w1_d[idx] = dequantize_nvfp4_to_dtype(
+            w1_fp4_sliced,
+            w1_blockscale_sliced,
+            w1_global_scale[idx],
+            dtype=w1.dtype,
+            device=w1.device,
+            block_size=16,
+        )
+        w2_d[idx] = dequantize_nvfp4_to_dtype(
+            w2_fp4_sliced,
+            w2_blockscale_sliced,
+            w2_global_scale[idx],
+            dtype=w2.dtype,
+            device=w2.device,
+            block_size=16,
+        )
+
+    ref_output = torch_moe_nvfp4(
+        a_in_dtype,
+        w1_d,
+        w2_d,
+        topk,
+        routing_weights.to(a_in_dtype.device),
+        topk_idx.to(a_in_dtype.device),
+    )
+    out_weighted = torch.zeros_like(ref_output, device=out.device, dtype=out.dtype)
+
+    positions = torch.nonzero(masked_m[topk_idx], as_tuple=False)
+    rows, cols = positions[:, 0], positions[:, 1]
+    experts = topk_idx[rows, cols]
+    for i in range(num_experts):
+        mask = experts == i
+        if mask.any():
+            idx = torch.nonzero(mask, as_tuple=False).squeeze(-1)
+            r, c = rows[idx], cols[idx]
+            out_weighted[r] += out[i, : len(r), :] * routing_weights[r, c].to(
+                out.device
+            ).unsqueeze(-1)
+    torch.testing.assert_close(
+        out_weighted.cpu(), ref_output.cpu(), atol=5e-2, rtol=5e-2
+    )
+
+
+@pytest.mark.parametrize(
+    "bs, hidden_dim, inter_dim, topk", [(2, 128, 256, 2), (16, 128, 512, 5)]
+)
+@torch.inference_mode()
+def test_grouped_gemm_nt_masked(
+    bs: int, hidden_dim: int, inter_dim: int, topk: int
+) -> None:
+    torch.manual_seed(42)
+    B = bs
+    D = hidden_dim
+    N = inter_dim
+    num_experts = 8
+    hidden_states = torch.randn(B, D, dtype=torch.bfloat16, device="cuda")
+    weights = torch.randn(num_experts, N, D, dtype=torch.bfloat16, device="cuda")
+    router_logits = torch.randn(B, num_experts, dtype=torch.float32)
+
+    hidden_states_expanded = (
+        hidden_states.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
+    )
+    hidden_states_3d, masked_m, topk_idx, _ = prepare_inputs(
+        hidden_states_expanded, router_logits, num_experts, topk
+    )
+
+    # reference
+    out = torch.zeros(
+        (B * topk, weights.shape[1]), dtype=weights.dtype, device=weights.device
+    )
+    for i in range(num_experts):
+        mask = topk_idx.view(-1) == i
+        if mask.sum():
+            lhs = hidden_states_expanded[mask]
+            rhs = weights[i]
+            a_amax = lhs.abs().max().to(torch.float32).to(hidden_states.device)
+            b_amax = rhs.abs().amax().to(torch.float32).to(weights.device)
+            a_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / a_amax
+            b_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b_amax
+
+            lhsq, lhsq_sf = fp4_quantize(
+                lhs,
+                a_gs,
+            )
+            rhsq, rhsq_sf = fp4_quantize(
+                rhs,
+                b_gs,
+            )
+
+            lhs_in_dtype = dequantize_nvfp4_to_dtype(
+                lhsq,
+                lhsq_sf,
+                a_gs,
+                dtype=hidden_states.dtype,
+                device=hidden_states.device,
+                block_size=16,
+            )
+
+            rhs_in_dtype = dequantize_nvfp4_to_dtype(
+                rhsq,
+                rhsq_sf,
+                b_gs,
+                dtype=hidden_states.dtype,
+                device=hidden_states.device,
+                block_size=16,
+            )
+            out[mask] = lhs_in_dtype @ rhs_in_dtype.t()
+
+    a_amax = (
+        hidden_states_3d.abs()
+        .amax(dim=(1, 2))
+        .to(torch.float32)
+        .to(hidden_states.device)
+    )
+    b_amax = weights.abs().amax(dim=(1, 2)).to(torch.float32).to(weights.device)
+    a_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / a_amax
+    b_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b_amax
+    out_flashinfer = flashinfer_cutedsl_grouped_gemm_nt_masked(
+        hidden_states_3d.to(hidden_states.device), a_gs, weights, b_gs, masked_m
+    )
+
+    # re-pack out into [num_experts, max_m, n]
+    out_ref = torch.zeros(
+        (num_experts, max(masked_m), weights.shape[1]), dtype=out.dtype
+    )
+    expert_slot = [0] * num_experts
+    for i, expert_id in enumerate(topk_idx.view(-1).tolist()):
+        out_ref[expert_id, expert_slot[expert_id], :] = out[i]
+        expert_slot[expert_id] += 1
+
+    # Note: only compare the masked positions, since CuteDSL may write NaN
+    # into the unmasked positions.
+    for i in range(num_experts):
+        torch.testing.assert_close(
+            out_flashinfer.permute(2, 0, 1)[i, : masked_m[i]],
+            out_ref.to(out_flashinfer.device)[i, : masked_m[i]],
+            atol=1e-1,
+            rtol=5e-2,
+        )
+
+
 if __name__ == "__main__":
     test_cutlass_fp4_moe_no_graph(224, 1024, 1024, 256, 8, torch.half)
     test_flashinfer_fp4_moe_no_graph(224, 1024, 1024, 256, 8, torch.half)
+    test_flashinfer_cutedsl_moe_masked(16, 128, 512, 4)
+    test_grouped_gemm_nt_masked(16, 128, 512, 4)
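
Note: the two new tests are also wired into the module's __main__ block above. A hypothetical standalone runner for this file — the capability check mirrors the module-level skip guard visible in the first hunk:

import pytest
import torch

if __name__ == "__main__":
    if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (10, 0):
        raise SystemExit(pytest.main(["-q", "sglang/test/test_fp4_moe.py"]))
    print("Skipping: the NVFP4 CuteDSL tests need compute capability >= 10.0")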
sglang/test/test_utils.py CHANGED
@@ -42,7 +42,8 @@ DEFAULT_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.1-8B-Instruct"
 DEFAULT_SMALL_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.2-1B-Instruct"
 DEFAULT_SMALL_MODEL_NAME_FOR_TEST_BASE = "meta-llama/Llama-3.2-1B"
 DEFAULT_MOE_MODEL_NAME_FOR_TEST = "mistralai/Mixtral-8x7B-Instruct-v0.1"
-DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST = "Qwen/Qwen1.5-MoE-A2.7B"
+DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_BASE = "Qwen/Qwen1.5-MoE-A2.7B"
+DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_CHAT = "Qwen/Qwen1.5-MoE-A2.7B-Chat"
 
 # MLA test models
 DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST = "Alibaba-NLP/gte-Qwen2-1.5B-instruct"
@@ -52,6 +53,9 @@ DEFAULT_MLA_FP8_MODEL_NAME_FOR_TEST = "neuralmagic/DeepSeek-Coder-V2-Lite-Instru
 DEFAULT_MODEL_NAME_FOR_TEST_MLA = "lmsys/sglang-ci-dsv3-test"
 DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN = "lmsys/sglang-ci-dsv3-test-NextN"
 
+# NVFP4 models
+DEFAULT_DEEPSEEK_NVFP4_MODEL_FOR_TEST = "nvidia/DeepSeek-R1-0528-FP4"
+
 # FP8 models
 DEFAULT_MODEL_NAME_FOR_TEST_FP8 = "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8"
 DEFAULT_MODEL_NAME_FOR_ACCURACY_TEST_FP8 = "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8"
@@ -72,6 +76,10 @@ DEFAULT_MODEL_NAME_FOR_TEST_W8A8_WITH_MOE = "nytopop/Qwen3-30B-A3B.w8a8"
 DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST = "meta-llama/Llama-2-7b-chat-hf"
 DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST = "lmsys/sglang-EAGLE-llama2-chat-7B"
 DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3 = "jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B"
+DEFAULT_STANDALONE_SPECULATIVE_TARGET_MODEL_FOR_TEST = (
+    "meta-llama/Llama-3.1-8B-Instruct"
+)
+DEFAULT_STANDALONE_SPECULATIVE_DRAFT_MODEL_FOR_TEST = "meta-llama/Llama-3.2-1B-Instruct"
 
 # Other use cases
 DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION = (
@@ -466,6 +474,25 @@ def try_cached_model(model_repo: str):
    return model_dir if model_dir else model_repo
 
 
+def popen_with_error_check(command: list[str], allow_exit: bool = False):
+    process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+
+    def _run_and_check():
+        stdout, stderr = process.communicate()
+
+        while process.poll() is None:
+            time.sleep(5)
+
+        if not allow_exit or process.returncode != 0:
+            raise Exception(
+                f"{command} exited with code {process.returncode}\n{stdout=}\n{stderr=}"
+            )
+
+    t = threading.Thread(target=_run_and_check)
+    t.start()
+    return process
+
+
 def popen_launch_server(
     model: str,
     base_url: str,
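
Note: a sketch of how the new popen_with_error_check helper might be used in a test; the command below is purely illustrative, not from the diff. With allow_exit=False, the background thread raises as soon as the child process exits for any reason.

from sglang.test.test_utils import popen_with_error_check

# Launch a process that should stay alive for the whole test.
proc = popen_with_error_check(
    ["python", "-m", "http.server", "18080"],  # illustrative command
    allow_exit=False,
)
# ... exercise the server under test ...
proc.terminate()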
sglang/utils.py CHANGED
@@ -457,6 +457,7 @@ def wait_for_server(base_url: str, timeout: int = None) -> None:
             NOTE: Typically, the server runs in a separate terminal.
             In this notebook, we run the server and notebook code together, so their outputs are combined.
             To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
+            To reduce the log length, we set the server's log level to warning; the default log level is info.
             We are running those notebooks in a CI environment, so the throughput is not representative of the actual performance.
             """
         )
@@ -471,11 +472,22 @@ def wait_for_server(base_url: str, timeout: int = None) -> None:
 class TypeBasedDispatcher:
     def __init__(self, mapping: List[Tuple[Type, Callable]]):
         self._mapping = mapping
+        self._fallback_fn = None
+
+    def add_fallback_fn(self, fallback_fn: Callable):
+        self._fallback_fn = fallback_fn
+
+    def __iadd__(self, other: "TypeBasedDispatcher"):
+        self._mapping.extend(other._mapping)
+        return self
 
     def __call__(self, obj: Any):
         for ty, fn in self._mapping:
             if isinstance(obj, ty):
                 return fn(obj)
+
+        if self._fallback_fn is not None:
+            return self._fallback_fn(obj)
         raise ValueError(f"Invalid object: {obj}")
 
 
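
Note: a minimal sketch of the new TypeBasedDispatcher behavior — __iadd__ merges the mappings of two dispatchers, and add_fallback_fn replaces the final ValueError with a handler. The handlers below are toy examples, not real sglang request types.

from sglang.utils import TypeBasedDispatcher

d = TypeBasedDispatcher([(int, lambda x: f"int:{x}")])
d += TypeBasedDispatcher([(str, lambda s: f"str:{s}")])  # __iadd__ merges mappings
d.add_fallback_fn(lambda obj: f"fallback:{obj!r}")

assert d(3) == "int:3"
assert d("hi") == "str:hi"
assert d(2.5) == "fallback:2.5"  # would have raised ValueError before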
sglang/version.py CHANGED
@@ -1 +1 @@
-__version__ = "0.5.2rc1"
+__version__ = "0.5.3rc0"