sglang 0.5.3rc2__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (419)
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +378 -160
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +10 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +105 -10
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +136 -25
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +63 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +83 -80
  69. sglang/srt/entrypoints/grpc_server.py +430 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +195 -102
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +58 -6
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +33 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +20 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/minimax_m2.py +367 -0
  93. sglang/srt/function_call/utils.py +2 -2
  94. sglang/srt/grpc/compile_proto.py +3 -3
  95. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  96. sglang/srt/grpc/health_servicer.py +189 -0
  97. sglang/srt/grpc/scheduler_launcher.py +181 -0
  98. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  99. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  100. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  101. sglang/srt/layers/activation.py +10 -1
  102. sglang/srt/layers/attention/aiter_backend.py +3 -3
  103. sglang/srt/layers/attention/ascend_backend.py +17 -1
  104. sglang/srt/layers/attention/attention_registry.py +43 -23
  105. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  106. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  107. sglang/srt/layers/attention/fla/chunk.py +0 -1
  108. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  109. sglang/srt/layers/attention/fla/index.py +0 -2
  110. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  111. sglang/srt/layers/attention/fla/utils.py +0 -3
  112. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  113. sglang/srt/layers/attention/flashattention_backend.py +24 -10
  114. sglang/srt/layers/attention/flashinfer_backend.py +258 -22
  115. sglang/srt/layers/attention/flashinfer_mla_backend.py +38 -28
  116. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  117. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  118. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  119. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  121. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  122. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  123. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  124. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  125. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  127. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  128. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  129. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  130. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  131. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  132. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  133. sglang/srt/layers/attention/nsa/utils.py +0 -1
  134. sglang/srt/layers/attention/nsa_backend.py +404 -90
  135. sglang/srt/layers/attention/triton_backend.py +208 -34
  136. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  137. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  138. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  139. sglang/srt/layers/attention/trtllm_mla_backend.py +362 -43
  140. sglang/srt/layers/attention/utils.py +89 -7
  141. sglang/srt/layers/attention/vision.py +3 -3
  142. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  143. sglang/srt/layers/communicator.py +12 -7
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +5 -9
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  146. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  147. sglang/srt/layers/dp_attention.py +17 -0
  148. sglang/srt/layers/layernorm.py +64 -19
  149. sglang/srt/layers/linear.py +9 -1
  150. sglang/srt/layers/logits_processor.py +152 -17
  151. sglang/srt/layers/modelopt_utils.py +11 -0
  152. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  153. sglang/srt/layers/moe/cutlass_w4a8_moe.py +351 -21
  154. sglang/srt/layers/moe/ep_moe/kernels.py +229 -457
  155. sglang/srt/layers/moe/ep_moe/layer.py +154 -625
  156. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  160. sglang/srt/layers/moe/fused_moe_triton/layer.py +79 -73
  161. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +25 -46
  162. sglang/srt/layers/moe/moe_runner/deep_gemm.py +569 -0
  163. sglang/srt/layers/moe/moe_runner/runner.py +6 -0
  164. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  165. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  166. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  167. sglang/srt/layers/moe/router.py +51 -15
  168. sglang/srt/layers/moe/token_dispatcher/__init__.py +14 -4
  169. sglang/srt/layers/moe/token_dispatcher/base.py +12 -6
  170. sglang/srt/layers/moe/token_dispatcher/deepep.py +127 -110
  171. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  172. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  173. sglang/srt/layers/moe/topk.py +7 -6
  174. sglang/srt/layers/moe/utils.py +20 -5
  175. sglang/srt/layers/quantization/__init__.py +5 -58
  176. sglang/srt/layers/quantization/awq.py +183 -9
  177. sglang/srt/layers/quantization/awq_triton.py +29 -0
  178. sglang/srt/layers/quantization/base_config.py +27 -1
  179. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  180. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  185. sglang/srt/layers/quantization/fp8.py +152 -81
  186. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  187. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  188. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  189. sglang/srt/layers/quantization/gguf.py +566 -0
  190. sglang/srt/layers/quantization/gptq.py +0 -1
  191. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  192. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  193. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  194. sglang/srt/layers/quantization/mxfp4.py +35 -68
  195. sglang/srt/layers/quantization/petit.py +1 -1
  196. sglang/srt/layers/quantization/quark/quark.py +3 -1
  197. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  199. sglang/srt/layers/quantization/unquant.py +23 -48
  200. sglang/srt/layers/quantization/utils.py +0 -1
  201. sglang/srt/layers/quantization/w4afp8.py +87 -20
  202. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  203. sglang/srt/layers/radix_attention.py +62 -9
  204. sglang/srt/layers/rotary_embedding.py +686 -17
  205. sglang/srt/layers/sampler.py +47 -16
  206. sglang/srt/layers/sparse_pooler.py +98 -0
  207. sglang/srt/layers/utils.py +0 -1
  208. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  209. sglang/srt/lora/backend/triton_backend.py +0 -1
  210. sglang/srt/lora/eviction_policy.py +139 -0
  211. sglang/srt/lora/lora_manager.py +24 -9
  212. sglang/srt/lora/lora_registry.py +1 -1
  213. sglang/srt/lora/mem_pool.py +40 -16
  214. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  215. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  216. sglang/srt/managers/cache_controller.py +48 -17
  217. sglang/srt/managers/data_parallel_controller.py +146 -42
  218. sglang/srt/managers/detokenizer_manager.py +40 -13
  219. sglang/srt/managers/io_struct.py +69 -16
  220. sglang/srt/managers/mm_utils.py +20 -18
  221. sglang/srt/managers/multi_tokenizer_mixin.py +83 -82
  222. sglang/srt/managers/overlap_utils.py +96 -19
  223. sglang/srt/managers/schedule_batch.py +241 -511
  224. sglang/srt/managers/schedule_policy.py +15 -2
  225. sglang/srt/managers/scheduler.py +420 -514
  226. sglang/srt/managers/scheduler_metrics_mixin.py +73 -18
  227. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  228. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  229. sglang/srt/managers/scheduler_profiler_mixin.py +60 -14
  230. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  231. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  232. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  233. sglang/srt/managers/tokenizer_manager.py +375 -95
  234. sglang/srt/managers/tp_worker.py +212 -161
  235. sglang/srt/managers/utils.py +78 -2
  236. sglang/srt/mem_cache/allocator.py +7 -2
  237. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  238. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  239. sglang/srt/mem_cache/chunk_cache.py +13 -2
  240. sglang/srt/mem_cache/common.py +480 -0
  241. sglang/srt/mem_cache/evict_policy.py +16 -1
  242. sglang/srt/mem_cache/hicache_storage.py +11 -2
  243. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  244. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  245. sglang/srt/mem_cache/memory_pool.py +517 -219
  246. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  247. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  248. sglang/srt/mem_cache/radix_cache.py +53 -19
  249. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  250. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  251. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  252. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  253. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  254. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  255. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  256. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  257. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  259. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  260. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  261. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  262. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  263. sglang/srt/metrics/collector.py +31 -0
  264. sglang/srt/metrics/func_timer.py +1 -1
  265. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  266. sglang/srt/model_executor/forward_batch_info.py +71 -25
  267. sglang/srt/model_executor/model_runner.py +362 -270
  268. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  269. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +549 -0
  270. sglang/srt/model_loader/__init__.py +1 -1
  271. sglang/srt/model_loader/loader.py +424 -27
  272. sglang/srt/model_loader/utils.py +0 -1
  273. sglang/srt/model_loader/weight_utils.py +47 -28
  274. sglang/srt/models/apertus.py +2 -3
  275. sglang/srt/models/arcee.py +2 -2
  276. sglang/srt/models/bailing_moe.py +13 -52
  277. sglang/srt/models/bailing_moe_nextn.py +3 -4
  278. sglang/srt/models/bert.py +1 -1
  279. sglang/srt/models/deepseek_nextn.py +19 -3
  280. sglang/srt/models/deepseek_ocr.py +1516 -0
  281. sglang/srt/models/deepseek_v2.py +418 -140
  282. sglang/srt/models/dots_ocr.py +0 -2
  283. sglang/srt/models/dots_vlm.py +0 -1
  284. sglang/srt/models/dots_vlm_vit.py +1 -1
  285. sglang/srt/models/falcon_h1.py +13 -19
  286. sglang/srt/models/gemma3_mm.py +16 -0
  287. sglang/srt/models/gemma3n_mm.py +1 -2
  288. sglang/srt/models/glm4_moe.py +327 -382
  289. sglang/srt/models/glm4_moe_nextn.py +6 -16
  290. sglang/srt/models/glm4v.py +2 -1
  291. sglang/srt/models/glm4v_moe.py +32 -199
  292. sglang/srt/models/gpt_oss.py +5 -5
  293. sglang/srt/models/grok.py +10 -23
  294. sglang/srt/models/hunyuan.py +2 -7
  295. sglang/srt/models/interns1.py +0 -1
  296. sglang/srt/models/kimi_vl.py +1 -7
  297. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  298. sglang/srt/models/llama.py +2 -2
  299. sglang/srt/models/llama_eagle3.py +1 -1
  300. sglang/srt/models/longcat_flash.py +5 -22
  301. sglang/srt/models/longcat_flash_nextn.py +3 -14
  302. sglang/srt/models/mimo.py +2 -13
  303. sglang/srt/models/mimo_mtp.py +1 -2
  304. sglang/srt/models/minicpmo.py +7 -5
  305. sglang/srt/models/minimax_m2.py +922 -0
  306. sglang/srt/models/mixtral.py +1 -4
  307. sglang/srt/models/mllama.py +1 -1
  308. sglang/srt/models/mllama4.py +13 -3
  309. sglang/srt/models/nemotron_h.py +511 -0
  310. sglang/srt/models/nvila.py +355 -0
  311. sglang/srt/models/nvila_lite.py +184 -0
  312. sglang/srt/models/olmo2.py +31 -4
  313. sglang/srt/models/opt.py +5 -5
  314. sglang/srt/models/phi.py +1 -1
  315. sglang/srt/models/phi4mm.py +1 -1
  316. sglang/srt/models/phimoe.py +0 -1
  317. sglang/srt/models/pixtral.py +0 -3
  318. sglang/srt/models/points_v15_chat.py +186 -0
  319. sglang/srt/models/qwen.py +0 -1
  320. sglang/srt/models/qwen2.py +22 -1
  321. sglang/srt/models/qwen2_5_vl.py +3 -3
  322. sglang/srt/models/qwen2_audio.py +2 -15
  323. sglang/srt/models/qwen2_moe.py +15 -12
  324. sglang/srt/models/qwen2_vl.py +5 -2
  325. sglang/srt/models/qwen3.py +34 -4
  326. sglang/srt/models/qwen3_moe.py +19 -37
  327. sglang/srt/models/qwen3_next.py +7 -12
  328. sglang/srt/models/qwen3_next_mtp.py +3 -4
  329. sglang/srt/models/qwen3_omni_moe.py +661 -0
  330. sglang/srt/models/qwen3_vl.py +37 -33
  331. sglang/srt/models/qwen3_vl_moe.py +57 -185
  332. sglang/srt/models/roberta.py +55 -3
  333. sglang/srt/models/sarashina2_vision.py +0 -1
  334. sglang/srt/models/step3_vl.py +3 -5
  335. sglang/srt/models/utils.py +11 -1
  336. sglang/srt/multimodal/processors/base_processor.py +7 -2
  337. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  338. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  339. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  340. sglang/srt/multimodal/processors/glm4v.py +2 -6
  341. sglang/srt/multimodal/processors/internvl.py +0 -2
  342. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  343. sglang/srt/multimodal/processors/mllama4.py +0 -8
  344. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  345. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  346. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  347. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  348. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  349. sglang/srt/parser/conversation.py +41 -0
  350. sglang/srt/parser/reasoning_parser.py +28 -2
  351. sglang/srt/sampling/custom_logit_processor.py +77 -2
  352. sglang/srt/sampling/sampling_batch_info.py +17 -22
  353. sglang/srt/sampling/sampling_params.py +70 -2
  354. sglang/srt/server_args.py +846 -163
  355. sglang/srt/server_args_config_parser.py +1 -1
  356. sglang/srt/single_batch_overlap.py +36 -31
  357. sglang/srt/speculative/base_spec_worker.py +34 -0
  358. sglang/srt/speculative/draft_utils.py +226 -0
  359. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  360. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  361. sglang/srt/speculative/eagle_info.py +57 -18
  362. sglang/srt/speculative/eagle_info_v2.py +458 -0
  363. sglang/srt/speculative/eagle_utils.py +138 -0
  364. sglang/srt/speculative/eagle_worker.py +83 -280
  365. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  366. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  367. sglang/srt/speculative/ngram_worker.py +12 -11
  368. sglang/srt/speculative/spec_info.py +2 -0
  369. sglang/srt/speculative/spec_utils.py +38 -3
  370. sglang/srt/speculative/standalone_worker.py +4 -14
  371. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  372. sglang/srt/two_batch_overlap.py +28 -14
  373. sglang/srt/utils/__init__.py +1 -1
  374. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  375. sglang/srt/utils/common.py +272 -82
  376. sglang/srt/utils/hf_transformers_utils.py +44 -17
  377. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  378. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  379. sglang/srt/utils/profile_merger.py +199 -0
  380. sglang/test/attention/test_flashattn_backend.py +1 -1
  381. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  382. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  383. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  384. sglang/test/few_shot_gsm8k_engine.py +2 -4
  385. sglang/test/kit_matched_stop.py +157 -0
  386. sglang/test/longbench_v2/__init__.py +1 -0
  387. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  388. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  389. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  390. sglang/test/run_eval.py +41 -0
  391. sglang/test/runners.py +2 -0
  392. sglang/test/send_one.py +42 -7
  393. sglang/test/simple_eval_common.py +3 -0
  394. sglang/test/simple_eval_gpqa.py +0 -1
  395. sglang/test/simple_eval_humaneval.py +0 -3
  396. sglang/test/simple_eval_longbench_v2.py +344 -0
  397. sglang/test/test_block_fp8.py +1 -2
  398. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  399. sglang/test/test_cutlass_moe.py +1 -2
  400. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  401. sglang/test/test_deterministic.py +463 -107
  402. sglang/test/test_deterministic_utils.py +74 -0
  403. sglang/test/test_disaggregation_utils.py +81 -0
  404. sglang/test/test_marlin_moe.py +0 -1
  405. sglang/test/test_utils.py +85 -20
  406. sglang/version.py +1 -1
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +48 -35
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +414 -350
  409. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  410. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  411. sglang/srt/models/vila.py +0 -306
  412. sglang/srt/speculative/build_eagle_tree.py +0 -427
  413. sglang/test/test_block_fp8_ep.py +0 -358
  414. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  415. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  416. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  417. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
  418. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
  419. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
@@ -1,12 +1,9 @@
1
1
  import logging
2
- from typing import List, Optional
3
2
 
4
3
  import torch
5
4
  import triton
6
5
 
7
- from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
8
- from sglang.srt.utils import ceil_div, dispose_tensor, is_cuda
9
- from sglang.utils import is_in_ci
6
+ from sglang.srt.utils import ceil_div, is_cuda
10
7
 
11
8
  logger = logging.getLogger(__name__)
12
9
 
@@ -130,28 +127,30 @@ def deepep_run_moe_deep_preprocess(topk_ids: torch.Tensor, num_experts: int):
130
127
 
131
128
  @triton.jit
132
129
  def compute_seg_indptr_triton_kernel(reorder_topk_ids, seg_indptr, num_toks):
133
- expert = tl.program_id(0)
130
+ expert_id_minus_1 = tl.program_id(0) - 1
134
131
  low = 0
135
132
  high = num_toks - 1
136
133
  target_location = -1
137
134
  while low <= high:
138
135
  mid = (low + high) // 2
139
136
 
140
- if tl.load(reorder_topk_ids + mid) > expert:
137
+ if tl.load(reorder_topk_ids + mid) > expert_id_minus_1:
141
138
  high = mid - 1
142
139
  else:
143
140
  low = mid + 1
144
141
  target_location = mid
145
- tl.store(seg_indptr + expert + 1, target_location + 1)
142
+ tl.store(seg_indptr + expert_id_minus_1 + 1, target_location + 1)
146
143
 
147
144
 
148
- def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int):
145
+ def run_moe_ep_preproess(topk_ids: torch.Tensor, num_local_experts: int):
149
146
  reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
150
147
 
151
- seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
148
+ seg_indptr = torch.zeros(
149
+ num_local_experts + 1, device=topk_ids.device, dtype=torch.int64
150
+ )
152
151
  src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)
153
152
 
154
- compute_seg_indptr_triton_kernel[(num_experts,)](
153
+ compute_seg_indptr_triton_kernel[(num_local_experts,)](
155
154
  reorder_topk_ids, seg_indptr, topk_ids.numel()
156
155
  )
157
156
 
@@ -164,25 +163,6 @@ def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int):
164
163
  return reorder_topk_ids, src2dst, seg_indptr
165
164
 
166
165
 
167
- def run_cutlass_moe_ep_preproess(local_topk_ids: torch.Tensor, local_num_experts: int):
168
- reorder_topk_ids, reorder_ids = torch.sort(local_topk_ids.view(-1), stable=True)
169
-
170
- seg_indptr = torch.zeros(
171
- local_num_experts + 1, device=local_topk_ids.device, dtype=torch.int64
172
- )
173
- src2dst = torch.empty(
174
- local_topk_ids.numel(), device=local_topk_ids.device, dtype=torch.int32
175
- )
176
-
177
- BLOCK_SIZE = 512
178
- grid = (triton.cdiv(local_topk_ids.numel(), BLOCK_SIZE),)
179
- compute_src2dst_triton_kernel[grid](
180
- reorder_ids, src2dst, local_topk_ids.numel(), BLOCK_SIZE
181
- )
182
-
183
- return reorder_topk_ids, src2dst, seg_indptr
184
-
185
-
186
166
  @triton.jit
187
167
  def pre_reorder_triton_kernel_for_cutlass_moe(
188
168
  input_ptr,
@@ -190,52 +170,13 @@ def pre_reorder_triton_kernel_for_cutlass_moe(
190
170
  src2dst_ptr,
191
171
  topk_ids_ptr,
192
172
  a1_scales_ptr,
193
- num_experts,
173
+ num_local_experts,
194
174
  topk,
195
175
  hidden_size,
196
176
  BLOCK_SIZE: tl.constexpr,
197
177
  ):
198
178
  OutDtype = gateup_input_ptr.dtype.element_ty
199
179
 
200
- src_idx = tl.program_id(0)
201
- src2dst_ptr = src2dst_ptr + src_idx * topk
202
- topk_ids_ptr = topk_ids_ptr + src_idx * topk
203
-
204
- src_ptr = input_ptr + src_idx * hidden_size
205
- for idx in range(topk):
206
- expert_id = tl.load(topk_ids_ptr + idx)
207
- if expert_id != num_experts:
208
- if a1_scales_ptr is not None:
209
- scale = 1.0 / tl.load(a1_scales_ptr)
210
- else:
211
- scale = 1.0
212
-
213
- dst_idx = tl.load(src2dst_ptr + idx)
214
- dst_ptr = gateup_input_ptr + dst_idx * hidden_size
215
- for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
216
- offset = start_offset + tl.arange(0, BLOCK_SIZE)
217
- mask = offset < hidden_size
218
- in_data = tl.load(src_ptr + offset, mask=mask).to(tl.float32)
219
- out_data = (in_data * scale).to(OutDtype)
220
- tl.store(dst_ptr + offset, out_data, mask=mask)
221
-
222
-
223
- @triton.jit
224
- def pre_reorder_triton_kernel(
225
- input_ptr,
226
- gateup_input_ptr,
227
- src2dst_ptr,
228
- topk_ids_ptr,
229
- a1_scales_ptr,
230
- start_expert_id,
231
- end_expert_id,
232
- topk,
233
- hidden_size,
234
- BLOCK_SIZE: tl.constexpr,
235
- use_per_token_if_dynamic: tl.constexpr,
236
- ):
237
- OutDtype = gateup_input_ptr.dtype.element_ty
238
-
239
180
  src_idx_int32 = tl.program_id(0)
240
181
  src_idx = src_idx_int32.to(tl.int64)
241
182
  src2dst_ptr = src2dst_ptr + src_idx * topk
@@ -244,15 +185,11 @@ def pre_reorder_triton_kernel(
244
185
 
245
186
  vec = tl.arange(0, BLOCK_SIZE)
246
187
 
247
- if a1_scales_ptr is not None and use_per_token_if_dynamic:
248
- scale = 1.0 / tl.load(a1_scales_ptr + src_idx)
249
-
250
188
  for idx in range(topk):
251
189
  expert_id = tl.load(topk_ids_ptr + idx)
252
- if expert_id >= start_expert_id and expert_id <= end_expert_id:
190
+ if expert_id != num_local_experts:
253
191
  if a1_scales_ptr is not None:
254
- if not use_per_token_if_dynamic:
255
- scale = 1.0 / tl.load(a1_scales_ptr + expert_id - start_expert_id)
192
+ scale = 1.0 / tl.load(a1_scales_ptr)
256
193
  else:
257
194
  scale = 1.0
258
195
 
@@ -267,52 +204,6 @@ def pre_reorder_triton_kernel(
267
204
  tl.store(dst_ptr + offset, out_data, mask=mask)
268
205
 
269
206
 
270
- @triton.jit
271
- def silu_and_mul_triton_kernel(
272
- gateup_output,
273
- down_input,
274
- hidden_size,
275
- reorder_topk_ids,
276
- scales,
277
- start_expert_id,
278
- end_expert_id,
279
- BLOCK_SIZE: tl.constexpr,
280
- ):
281
- InDtype = gateup_output.dtype.element_ty
282
- OutDtype = down_input.dtype.element_ty
283
-
284
- half_hidden_size = hidden_size // 2
285
-
286
- pid = tl.program_id(0)
287
- expert_id = tl.load(reorder_topk_ids + pid)
288
- if expert_id >= start_expert_id and expert_id <= end_expert_id:
289
- gateup_output_ptr = gateup_output + pid * hidden_size
290
- gate_output_ptr = gateup_output_ptr
291
- up_output_ptr = gateup_output_ptr + half_hidden_size
292
- down_input_ptr = down_input + pid * half_hidden_size
293
-
294
- if scales is not None:
295
- scale = tl.load(scales + expert_id - start_expert_id)
296
- scale = (1 / scale).to(InDtype)
297
- else:
298
- scale = 1
299
-
300
- for start_offset in tl.range(0, half_hidden_size, BLOCK_SIZE):
301
- offset = start_offset + tl.arange(0, BLOCK_SIZE)
302
- mask = offset < half_hidden_size
303
-
304
- gate_output = tl.load(gate_output_ptr + offset, mask=mask).to(tl.float32)
305
- up_output = tl.load(up_output_ptr + offset, mask=mask)
306
-
307
- # silu & mul & quantize
308
- gate_output = gate_output * tl.sigmoid(gate_output)
309
- gate_output = gate_output.to(InDtype)
310
-
311
- silu_mul_output = gate_output * up_output * scale
312
- silu_mul_output = silu_mul_output.to(OutDtype)
313
- tl.store(down_input_ptr + offset, silu_mul_output, mask=mask)
314
-
315
-
316
207
  # copy from https://github.com/ModelTC/lightllm/blob/a000ab69098654df4731f5b12587dd4e7f0a4f41/lightllm/common/fused_moe/moe_silu_and_mul_mix_quant_ep.py
317
208
  @triton.jit
318
209
  def _silu_and_mul_post_quant_kernel(
@@ -461,84 +352,15 @@ def silu_and_mul_masked_post_quant_fwd(
461
352
 
462
353
 
463
354
  @triton.jit
464
- def tanh(x):
465
- return 2 * tl.sigmoid(2 * x) - 1
466
-
467
-
468
- @triton.jit
469
- def gelu_and_mul_triton_kernel(
470
- gateup_output,
471
- down_input,
472
- hidden_size,
473
- reorder_topk_ids,
474
- scales,
475
- start_expert_id,
476
- end_expert_id,
477
- BLOCK_SIZE: tl.constexpr,
478
- ):
479
- InDtype = gateup_output.dtype.element_ty
480
- OutDtype = down_input.dtype.element_ty
481
-
482
- half_hidden_size = hidden_size // 2
483
-
484
- pid = tl.program_id(0)
485
- expert_id = tl.load(reorder_topk_ids + pid)
486
- if expert_id >= start_expert_id and expert_id <= end_expert_id:
487
- gateup_output_ptr = gateup_output + pid * hidden_size
488
- gate_output_ptr = gateup_output_ptr
489
- up_output_ptr = gateup_output_ptr + half_hidden_size
490
- down_input_ptr = down_input + pid * half_hidden_size
491
-
492
- if scales is not None:
493
- scale = tl.load(scales + expert_id - start_expert_id)
494
- scale = (1 / scale).to(InDtype)
495
- else:
496
- scale = 1
497
-
498
- for start_offset in tl.range(0, half_hidden_size, BLOCK_SIZE):
499
- offset = start_offset + tl.arange(0, BLOCK_SIZE)
500
- mask = offset < half_hidden_size
501
-
502
- gate_output = tl.load(gate_output_ptr + offset, mask=mask).to(tl.float32)
503
- up_output = tl.load(up_output_ptr + offset, mask=mask)
504
-
505
- # gelu & mul & quantize
506
- # https://pytorch.org/docs/stable/generated/torch.nn.GELU.html
507
- # sqrt(2/pi)
508
- kAlpha = 0.7978845608028654
509
- gate_output = (
510
- 0.5
511
- * gate_output
512
- * (
513
- 1
514
- + tanh(
515
- kAlpha
516
- * (
517
- gate_output
518
- + 0.044715 * gate_output * gate_output * gate_output
519
- )
520
- )
521
- )
522
- )
523
- gate_output = gate_output.to(InDtype)
524
-
525
- gelu_mul_output = gate_output * up_output * scale
526
- gelu_mul_output = gelu_mul_output.to(OutDtype)
527
- tl.store(down_input_ptr + offset, gelu_mul_output, mask=mask)
528
-
529
-
530
- @triton.jit
531
- def post_reorder_triton_kernel(
355
+ def post_reorder_triton_kernel_for_cutlass_moe(
532
356
  down_output_ptr,
533
357
  output_ptr,
534
358
  src2dst_ptr,
535
359
  topk_ids_ptr,
536
360
  topk_weights_ptr,
537
- start_expert_id,
538
- end_expert_id,
539
361
  topk,
362
+ num_local_experts,
540
363
  hidden_size,
541
- dst_start,
542
364
  BLOCK_SIZE: tl.constexpr,
543
365
  ):
544
366
  InDtype = down_output_ptr.dtype.element_ty
@@ -549,7 +371,6 @@ def post_reorder_triton_kernel(
549
371
  topk_ids_ptr = topk_ids_ptr + src_idx * topk
550
372
  topk_weights_ptr = topk_weights_ptr + src_idx * topk
551
373
 
552
- computed = False
553
374
  store_ptr = output_ptr + src_idx * hidden_size
554
375
 
555
376
  vec = tl.arange(0, BLOCK_SIZE)
@@ -561,37 +382,25 @@ def post_reorder_triton_kernel(
561
382
  sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype)
562
383
  for idx in range(topk):
563
384
  expert_id = tl.load(topk_ids_ptr + idx)
564
- if expert_id >= start_expert_id and expert_id <= end_expert_id:
565
- computed = True
385
+ if expert_id != num_local_experts:
566
386
  dst_idx_int32 = tl.load(src2dst_ptr + idx)
567
387
  dst_idx = dst_idx_int32.to(tl.int64)
568
- dst_idx = dst_idx - dst_start
569
388
  weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype)
570
389
  load_ptr = down_output_ptr + dst_idx * hidden_size
571
390
  in_data = tl.load(load_ptr + offset, mask=mask)
572
391
  sum_vec += in_data * weigh_scale
573
392
  tl.store(store_ptr + offset, sum_vec, mask=mask)
574
393
 
575
- if computed == False:
576
- for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
577
- offset = start_offset + vec
578
- mask = offset < hidden_size
579
- tl.store(
580
- store_ptr + offset, tl.zeros([BLOCK_SIZE], dtype=InDtype), mask=mask
581
- )
582
-
583
394
 
584
395
  @triton.jit
585
- def post_reorder_triton_kernel_for_cutlass_moe(
396
+ def post_reorder_triton_kernel(
586
397
  down_output_ptr,
587
398
  output_ptr,
588
399
  src2dst_ptr,
589
400
  topk_ids_ptr,
590
401
  topk_weights_ptr,
591
- num_experts,
592
402
  topk,
593
403
  hidden_size,
594
- dst_start,
595
404
  BLOCK_SIZE: tl.constexpr,
596
405
  ):
597
406
  InDtype = down_output_ptr.dtype.element_ty
@@ -613,10 +422,9 @@ def post_reorder_triton_kernel_for_cutlass_moe(
613
422
  sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype)
614
423
  for idx in range(topk):
615
424
  expert_id = tl.load(topk_ids_ptr + idx)
616
- if expert_id != num_experts:
425
+ if expert_id > 0:
617
426
  dst_idx_int32 = tl.load(src2dst_ptr + idx)
618
427
  dst_idx = dst_idx_int32.to(tl.int64)
619
- dst_idx = dst_idx - dst_start
620
428
  weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype)
621
429
  load_ptr = down_output_ptr + dst_idx * hidden_size
622
430
  in_data = tl.load(load_ptr + offset, mask=mask)
@@ -624,232 +432,6 @@ def post_reorder_triton_kernel_for_cutlass_moe(
624
432
  tl.store(store_ptr + offset, sum_vec, mask=mask)
625
433
 
626
434
 
627
- @triton.jit
628
- def compute_m_range(
629
- pid,
630
- batch_size,
631
- seg_indptr,
632
- weight_indices,
633
- m_num_tiles_indptr,
634
- BLOCK_SIZE_M: tl.constexpr,
635
- ):
636
- idx = 0
637
- for bs in range(batch_size):
638
- tiles = tl.load(m_num_tiles_indptr + bs)
639
- if pid >= tiles:
640
- idx = bs
641
-
642
- idx_start = tl.load(m_num_tiles_indptr + idx)
643
-
644
- m_range_start = tl.load(seg_indptr + idx) + (pid - idx_start) * BLOCK_SIZE_M
645
- m_range_end = min(tl.load(seg_indptr + idx + 1), m_range_start + BLOCK_SIZE_M)
646
- expert_id = tl.load(weight_indices + idx)
647
- return m_range_start, m_range_end, expert_id
648
-
649
-
650
- @triton.jit
651
- def grouped_gemm_triton_kernel(
652
- a,
653
- b,
654
- c,
655
- batch_size,
656
- N,
657
- K,
658
- seg_indptr,
659
- weight_indices,
660
- m_num_tiles_indptr,
661
- scale_a,
662
- scale_b,
663
- use_fp8_w8a8: tl.constexpr,
664
- group_n: tl.constexpr,
665
- group_k: tl.constexpr,
666
- a_stride_0: tl.constexpr,
667
- b_stride_0: tl.constexpr,
668
- b_stride_1: tl.constexpr,
669
- as_stride_0: tl.constexpr,
670
- as_stride_1: tl.constexpr,
671
- bs_stride_0: tl.constexpr,
672
- bs_stride_2: tl.constexpr,
673
- bs_stride_1: tl.constexpr,
674
- use_per_token_if_dynamic: tl.constexpr,
675
- BLOCK_SIZE_M: tl.constexpr,
676
- BLOCK_SIZE_N: tl.constexpr,
677
- BLOCK_SIZE_K: tl.constexpr,
678
- ):
679
- c_dtype = c.dtype.element_ty
680
-
681
- pid_m = tl.program_id(0)
682
- pid_n = tl.program_id(1)
683
- total_m_block = tl.load(m_num_tiles_indptr + batch_size)
684
- if pid_m >= total_m_block:
685
- return
686
-
687
- m_range_start, m_range_end, expert_id = compute_m_range(
688
- pid_m, batch_size, seg_indptr, weight_indices, m_num_tiles_indptr, BLOCK_SIZE_M
689
- )
690
- if m_range_end - m_range_start == 0:
691
- return
692
-
693
- n_range_start = pid_n * BLOCK_SIZE_N
694
- n_range_end = min(n_range_start + BLOCK_SIZE_N, N)
695
-
696
- offs_am = tl.arange(0, BLOCK_SIZE_M)
697
- offs_bn = tl.arange(0, BLOCK_SIZE_N)
698
-
699
- offs_am = tl.where(offs_am < m_range_end - m_range_start, offs_am, 0)
700
- offs_bn = tl.where(offs_bn < n_range_end - n_range_start, offs_bn, 0)
701
- offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M)
702
- offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)
703
- offs_k = tl.arange(0, BLOCK_SIZE_K)
704
-
705
- a_ptr = a + (m_range_start + offs_am[:, None]) * a_stride_0 + offs_k[None, :]
706
- b_ptr = b + (
707
- (expert_id * b_stride_0)
708
- + (n_range_start + offs_bn[:, None]) * b_stride_1
709
- + offs_k[None, :]
710
- )
711
-
712
- if group_k > 0 and group_n > 0:
713
- a_scale_ptrs = scale_a + (m_range_start + offs_am[:, None]) * as_stride_0
714
- offs_bsn = (n_range_start + offs_bn) // group_n
715
- b_scale_ptrs = scale_b + (expert_id * bs_stride_0) + offs_bsn * bs_stride_1
716
-
717
- accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
718
- for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
719
- a_tile = tl.load(
720
- a_ptr, mask=offs_k[None, :] < (K - k * BLOCK_SIZE_K), other=0.0
721
- )
722
- b_tile = tl.load(
723
- b_ptr, mask=offs_k[None, :] < (K - k * BLOCK_SIZE_K), other=0.0
724
- )
725
-
726
- if group_k > 0 and group_n > 0:
727
- k_start = k * BLOCK_SIZE_K
728
- offs_ks = k_start // group_k
729
- a_scale = tl.load(a_scale_ptrs + offs_ks * as_stride_1)
730
- b_scale = tl.load(b_scale_ptrs + offs_ks * bs_stride_2)
731
- accumulator += tl.dot(a_tile, b_tile.T) * a_scale * b_scale[None, :]
732
- else:
733
- accumulator = tl.dot(a_tile, b_tile.T, accumulator)
734
- a_ptr += BLOCK_SIZE_K
735
- b_ptr += BLOCK_SIZE_K
736
-
737
- if use_fp8_w8a8 and not (group_k > 0 and group_n > 0):
738
- if use_per_token_if_dynamic:
739
- scale_a_value = tl.load(scale_a + (m_range_start + offs_am[:, None]))
740
- else:
741
- scale_a_value = tl.load(scale_a + expert_id)
742
- scale_b_value = tl.load(scale_b + expert_id)
743
- accumulator *= scale_a_value * scale_b_value
744
-
745
- c_tile = accumulator.to(c_dtype)
746
-
747
- offs_cm = m_range_start + tl.arange(0, BLOCK_SIZE_M)
748
- offs_cn = n_range_start + tl.arange(0, BLOCK_SIZE_N)
749
- c_ptr = c + offs_cm[:, None] * N + offs_cn[None, :]
750
- c_mask = (offs_cm[:, None] < m_range_end) & (offs_cn[None, :] < n_range_end)
751
- tl.store(c_ptr, c_tile, mask=c_mask)
752
-
753
-
754
- @triton.jit
755
- def compute_m_num_tiles_indptr(
756
- m_num_tiles_indptr, seg_indptr, batch_size: tl.constexpr, BLOCK_SIZE_M: tl.constexpr
757
- ):
758
- for bs in range(batch_size):
759
- m = tl.load(seg_indptr + bs + 1) - tl.load(seg_indptr + bs)
760
- cur_num_tiles = tl.cdiv(m, BLOCK_SIZE_M)
761
- pre_num_tiles = tl.load(m_num_tiles_indptr + bs)
762
- tl.store(m_num_tiles_indptr + bs + 1, pre_num_tiles + cur_num_tiles)
763
-
764
-
765
- def grouped_gemm_triton(
766
- a: torch.Tensor,
767
- b: torch.Tensor,
768
- c: torch.Tensor,
769
- batch_size: int,
770
- weight_column_major: bool,
771
- seg_indptr: Optional[torch.Tensor] = None,
772
- weight_indices: Optional[torch.Tensor] = None,
773
- use_fp8_w8a8: bool = False,
774
- scale_a: torch.Tensor = None,
775
- scale_b: torch.Tensor = None,
776
- block_shape: Optional[List[int]] = None,
777
- c_dtype=None,
778
- use_per_token_if_dynamic: bool = True,
779
- ):
780
- assert weight_column_major == True # TODO: more
781
- if use_fp8_w8a8 and block_shape is None:
782
- assert scale_a is not None and scale_b is not None
783
-
784
- if block_shape is not None:
785
- a_original = a
786
-
787
- assert len(block_shape) == 2
788
- block_n, block_k = block_shape[0], block_shape[1]
789
- a, scale_a = per_token_group_quant_fp8(a, block_k)
790
-
791
- assert triton.cdiv(a.shape[-1], block_k) == scale_a.shape[-1]
792
- assert triton.cdiv(b.shape[-2], block_n) == scale_b.shape[-2]
793
- assert triton.cdiv(b.shape[-1], block_k) == scale_b.shape[-1]
794
-
795
- dispose_tensor(a_original)
796
-
797
- # TODO: adjust config or tune kernel
798
- # Reduce block size to prevent L40 shared memory overflow.
799
- config = {
800
- "BLOCK_SIZE_M": 64,
801
- "BLOCK_SIZE_N": 32,
802
- "BLOCK_SIZE_K": 128,
803
- }
804
-
805
- m_num_tiles_indptr = torch.zeros(batch_size + 1, device=a.device, dtype=torch.int64)
806
- compute_m_num_tiles_indptr[(1,)](
807
- m_num_tiles_indptr, seg_indptr, batch_size, config["BLOCK_SIZE_M"]
808
- )
809
-
810
- if c is None:
811
- assert c_dtype is not None
812
- c = torch.empty(a.shape[0], b.shape[1], device=a.device, dtype=c_dtype)
813
-
814
- grid = lambda META: (
815
- triton.cdiv(a.size(0), META["BLOCK_SIZE_M"]) + batch_size,
816
- triton.cdiv(b.size(1), META["BLOCK_SIZE_N"]),
817
- )
818
-
819
- if use_fp8_w8a8 and block_shape is None and use_per_token_if_dynamic:
820
- assert (
821
- scale_a.shape[0] == a.shape[0]
822
- ), f"scale_a.shape: {scale_a.shape}, a.shape: {a.shape}"
823
-
824
- grouped_gemm_triton_kernel[grid](
825
- a,
826
- b,
827
- c,
828
- batch_size,
829
- b.size(1),
830
- b.size(2),
831
- seg_indptr,
832
- weight_indices,
833
- m_num_tiles_indptr,
834
- scale_a,
835
- scale_b,
836
- use_fp8_w8a8,
837
- 0 if block_shape is None else block_shape[0],
838
- 0 if block_shape is None else block_shape[1],
839
- a.stride(0),
840
- b.stride(0),
841
- b.stride(1),
842
- scale_a.stride(0) if scale_a is not None and scale_a.ndim == 2 else 0,
843
- scale_a.stride(1) if scale_a is not None and scale_a.ndim == 2 else 0,
844
- scale_b.stride(0) if scale_b is not None and scale_b.ndim >= 2 else 0,
845
- scale_b.stride(2) if scale_b is not None and scale_b.ndim == 3 else 0,
846
- scale_b.stride(1) if scale_b is not None and scale_b.ndim >= 2 else 0,
847
- use_per_token_if_dynamic,
848
- **config,
849
- )
850
- return c
851
-
852
-
853
435
  @triton.jit
854
436
  def _fwd_kernel_ep_scatter_1(
855
437
  num_recv_tokens_per_expert,
@@ -984,7 +566,9 @@ def ep_scatter(
984
566
  scale_hidden_size = ceil_div(scale_hidden_size, 4)
985
567
 
986
568
  assert m_indices.shape[0] % BLOCK_E == 0
987
- assert recv_x_scale.dtype == output_tensor_scale.dtype
569
+ assert (
570
+ recv_x_scale.dtype == output_tensor_scale.dtype
571
+ ), f"recv_x_scale.dtype: {recv_x_scale.dtype}, output_tensor_scale.dtype: {output_tensor_scale.dtype}"
988
572
  assert recv_x_scale.shape[1] == output_tensor_scale.shape[1] == scale_hidden_size
989
573
 
990
574
  _fwd_kernel_ep_scatter_1[(grid,)](
@@ -1234,7 +818,7 @@ def deepgemm_compute_src2dst_triton_kernel(
1234
818
  mask = dst_id < num_toks
1235
819
  src_id = tl.load(reorder_ids + dst_id, mask=mask)
1236
820
  expert_id = tl.load(topk_ids + src_id, mask=(src_id < num_toks))
1237
- expert_dst_start = tl.load(seg_indptr + expert_id)
821
+ expert_dst_start = tl.load(seg_indptr + expert_id, mask=(expert_id >= 0))
1238
822
  expert_dst_offset = dst_id - expert_dst_start
1239
823
  dst_id = expert_id * m_max + expert_dst_offset
1240
824
  tl.store(src2dst + src_id, dst_id, mask=mask)
@@ -1248,10 +832,7 @@ def fill_gateup_input_triton_kernel(
1248
832
  gateup_input_scale_ptr,
1249
833
  src2dst_ptr,
1250
834
  topk_ids_ptr,
1251
- start_expert_id,
1252
- end_expert_id,
1253
835
  topk,
1254
- m_max,
1255
836
  hidden_size,
1256
837
  scale_size,
1257
838
  BLOCK_SIZE: tl.constexpr,
@@ -1267,10 +848,9 @@ def fill_gateup_input_triton_kernel(
1267
848
  vec = tl.arange(0, BLOCK_SIZE)
1268
849
  for idx in range(topk):
1269
850
  expert_id = tl.load(topk_ids_ptr + idx)
1270
- if expert_id >= start_expert_id and expert_id <= end_expert_id:
851
+ if expert_id >= 0:
1271
852
  dst_idx_int32 = tl.load(src2dst_ptr + idx)
1272
853
  dst_idx = dst_idx_int32.to(tl.int64)
1273
- dst_idx = dst_idx - start_expert_id * m_max
1274
854
  dst_ptr = gateup_input_ptr + dst_idx * hidden_size
1275
855
  for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
1276
856
  offset = start_offset + vec
@@ -1287,31 +867,31 @@ def fill_gateup_input_triton_kernel(
1287
867
 
1288
868
  def moe_ep_deepgemm_preprocess(
1289
869
  topk_ids: torch.Tensor,
1290
- num_experts: int,
870
+ num_local_experts: int,
1291
871
  hidden_states: torch.Tensor,
1292
872
  top_k: int,
1293
- start_expert_id,
1294
- end_expert_id,
1295
873
  block_shape,
1296
874
  output_dtype: torch.dtype = torch.float8_e4m3fn,
1297
875
  ):
1298
876
  reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
1299
- seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
877
+ seg_indptr = torch.zeros(
878
+ num_local_experts + 1, device=topk_ids.device, dtype=torch.int64
879
+ )
1300
880
  src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)
1301
- masked_m = torch.zeros(num_experts, device=topk_ids.device, dtype=torch.int32)
881
+ masked_m = torch.empty(num_local_experts, device=topk_ids.device, dtype=torch.int32)
1302
882
 
1303
- compute_seg_indptr_triton_kernel[(num_experts,)](
883
+ compute_seg_indptr_triton_kernel[(num_local_experts + 1,)](
1304
884
  reorder_topk_ids, seg_indptr, topk_ids.numel()
1305
885
  )
1306
886
 
1307
887
  grid = lambda meta: (triton.cdiv(topk_ids.numel(), meta["BLOCK_SIZE"]),)
1308
- compute_masked_m_triton_kernel[(num_experts,)](seg_indptr, masked_m)
888
+ compute_masked_m_triton_kernel[(num_local_experts,)](seg_indptr, masked_m)
1309
889
 
1310
890
  # For masked grouped GEMM, shape M should be multiple of the block M (current block M: {block_m}) https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/jit_kernels/m_grouped_gemm.py#L165
1311
- m_max = (hidden_states.size(0) + 255) // 256 * 256
1312
- expected_m = (topk_ids.numel() + num_experts - 1) // num_experts
891
+ m_max = (hidden_states.size(0) // 256 + 1) * 256
892
+ expected_m = (topk_ids.numel() - 1) // num_local_experts + 1
1313
893
  gateup_input = torch.empty(
1314
- (int(end_expert_id - start_expert_id + 1), m_max, hidden_states.size(1)),
894
+ (num_local_experts, m_max, hidden_states.size(1)),
1315
895
  device=hidden_states.device,
1316
896
  dtype=output_dtype,
1317
897
  )
@@ -1330,6 +910,8 @@ def moe_ep_deepgemm_preprocess(
1330
910
  block_shape = [128, 128]
1331
911
  assert len(block_shape) == 2
1332
912
  block_n, block_k = block_shape[0], block_shape[1]
913
+
914
+ # TODO: fuse this with the preprocess
1333
915
  hidden_states, scale = per_token_group_quant_fp8(hidden_states, block_k)
1334
916
 
1335
917
  gateup_input_scale = torch.empty(
@@ -1345,18 +927,14 @@ def moe_ep_deepgemm_preprocess(
1345
927
  gateup_input_scale,
1346
928
  src2dst,
1347
929
  topk_ids,
1348
- start_expert_id,
1349
- end_expert_id,
1350
930
  top_k,
1351
- m_max,
1352
931
  hidden_states.size(1),
1353
932
  scale.size(1),
1354
933
  BLOCK_SIZE=1024,
1355
934
  )
1356
935
 
1357
936
  return (
1358
- m_max,
1359
- masked_m[start_expert_id : (end_expert_id + 1)],
937
+ masked_m,
1360
938
  expected_m,
1361
939
  src2dst,
1362
940
  gateup_input,
@@ -1436,3 +1014,197 @@ def zero_experts_compute_triton(
1436
1014
  )
1437
1015
 
1438
1016
  return output
1017
+
1018
+
1019
+ @triton.jit
1020
+ def compute_problem_sizes_w4a8_kernel(
1021
+ masked_m_ptr,
1022
+ problem_sizes1_ptr,
1023
+ problem_sizes2_ptr,
1024
+ n,
1025
+ k,
1026
+ num_experts,
1027
+ BLOCK_SIZE: tl.constexpr,
1028
+ ):
1029
+ pid = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
1030
+ mask = pid < num_experts
1031
+ final_occurrences = tl.load(masked_m_ptr + pid, mask=mask, other=0)
1032
+
1033
+ ps1_idx_0 = pid * 3
1034
+ ps1_idx_1 = ps1_idx_0 + 1
1035
+ ps1_idx_2 = ps1_idx_0 + 2
1036
+
1037
+ ps2_idx_0 = pid * 3
1038
+ ps2_idx_1 = ps2_idx_0 + 1
1039
+ ps2_idx_2 = ps2_idx_0 + 2
1040
+
1041
+ ps1_mask_0 = ps1_idx_0 < num_experts * 3
1042
+ ps1_mask_1 = ps1_idx_1 < num_experts * 3
1043
+ ps1_mask_2 = ps1_idx_2 < num_experts * 3
1044
+ ps2_mask_0 = ps2_idx_0 < num_experts * 3
1045
+ ps2_mask_1 = ps2_idx_1 < num_experts * 3
1046
+ ps2_mask_2 = ps2_idx_2 < num_experts * 3
1047
+
1048
+ tl.store(problem_sizes1_ptr + ps1_idx_0, 2 * n, mask=ps1_mask_0)
1049
+ tl.store(problem_sizes1_ptr + ps1_idx_1, final_occurrences, mask=ps1_mask_1)
1050
+ tl.store(problem_sizes1_ptr + ps1_idx_2, k, mask=ps1_mask_2)
1051
+
1052
+ tl.store(problem_sizes2_ptr + ps2_idx_0, k, mask=ps2_mask_0)
1053
+ tl.store(problem_sizes2_ptr + ps2_idx_1, final_occurrences, mask=ps2_mask_1)
1054
+ tl.store(problem_sizes2_ptr + ps2_idx_2, n, mask=ps2_mask_2)
1055
+
1056
+
1057
+ def compute_problem_sizes_w4a8(
1058
+ masked_m, problem_sizes1, problem_sizes2, n, k, num_experts
1059
+ ):
1060
+ BLOCK_SIZE = 256
1061
+ grid = lambda meta: (triton.cdiv(num_experts, meta["BLOCK_SIZE"]),)
1062
+ compute_problem_sizes_w4a8_kernel[grid](
1063
+ masked_m,
1064
+ problem_sizes1,
1065
+ problem_sizes2,
1066
+ n,
1067
+ k,
1068
+ num_experts,
1069
+ BLOCK_SIZE=BLOCK_SIZE,
1070
+ )
1071
+ return problem_sizes1, problem_sizes2
1072
+
1073
+
1074
+ def deepep_ll_get_cutlass_w4a8_moe_mm_data(
1075
+ masked_m,
1076
+ problem_sizes1,
1077
+ problem_sizes2,
1078
+ num_experts,
1079
+ n,
1080
+ k,
1081
+ ):
1082
+ problem_sizes1, problem_sizes2 = compute_problem_sizes_w4a8(
1083
+ masked_m, problem_sizes1, problem_sizes2, n, k, num_experts
1084
+ )
1085
+ return (
1086
+ problem_sizes1.to(torch.int32),
1087
+ problem_sizes2.to(torch.int32),
1088
+ )
1089
+
1090
+
1091
+ @triton.jit
1092
+ def _silu_and_mul_post_per_tensor_quant_kernel(
1093
+ input_ptr,
1094
+ stride_input_expert,
1095
+ stride_input_token,
1096
+ stride_input_dim,
1097
+ output_ptr,
1098
+ stride_output_expert,
1099
+ stride_output_token,
1100
+ stride_output_dim,
1101
+ scale_ptr,
1102
+ masked_m_ptr,
1103
+ inner_dim,
1104
+ fp8_max,
1105
+ fp8_min,
1106
+ BLOCK_N: tl.constexpr,
1107
+ NUM_STAGE: tl.constexpr,
1108
+ ):
1109
+ """
1110
+ Triton kernel: fused SiLU(gate) * up + per-tensor FP8 quantization.
1111
+
1112
+ Shape:
1113
+ input: [E, T_padded, 2*D] -> gate: [:,:,D], up: [:,:,D]
1114
+ output: [E, T_padded, D], dtype=float8_e4m3fn
1115
+ """
1116
+ expert_id = tl.program_id(2)
1117
+ block_id_token = tl.program_id(1)
1118
+ block_id_dim = tl.program_id(0)
1119
+
1120
+ num_token_blocks = tl.num_programs(1)
1121
+
1122
+ token_num_cur_expert = tl.load(masked_m_ptr + expert_id)
1123
+
1124
+ scale = 1.0 / tl.load(scale_ptr).to(tl.float32)
1125
+
1126
+ stride_input_expert = tl.cast(stride_input_expert, tl.int32)
1127
+ stride_output_expert = tl.cast(stride_output_expert, tl.int32)
1128
+ stride_input_token = tl.cast(stride_input_token, tl.int32)
1129
+ stride_output_token = tl.cast(stride_output_token, tl.int32)
1130
+
1131
+ offset_d = block_id_dim * BLOCK_N + tl.arange(0, BLOCK_N)
1132
+ mask_d = offset_d < inner_dim
1133
+
1134
+ # base pointers for current expert and dim block
1135
+ input_base_offs = input_ptr + expert_id * stride_input_expert + offset_d
1136
+ output_base_offs = output_ptr + expert_id * stride_output_expert + offset_d
1137
+
1138
+ for token_idx in tl.range(
1139
+ block_id_token, token_num_cur_expert, num_token_blocks, num_stages=NUM_STAGE
1140
+ ):
1141
+ gate_ptr = input_base_offs + token_idx * stride_input_token
1142
+ up_ptr = gate_ptr + inner_dim
1143
+ gate = tl.load(gate_ptr, mask=mask_d, other=0.0).to(tl.float32)
1144
+ up = tl.load(up_ptr, mask=mask_d, other=0.0).to(tl.float32)
1145
+
1146
+ # SiLU: x * sigmoid(x)
1147
+ gate = gate / (1 + tl.exp(-gate))
1148
+ gate = gate.to(input_ptr.dtype.element_ty)
1149
+ gate_up = up * gate
1150
+
1151
+ scaled = gate_up * scale
1152
+ output_q = tl.clamp(scaled, fp8_min, fp8_max).to(output_ptr.dtype.element_ty)
1153
+ out_ptr = output_base_offs + token_idx * stride_output_token
1154
+ tl.store(out_ptr, output_q, mask=mask_d)
1155
+
1156
+
1157
+ def silu_and_mul_masked_post_per_tensor_quant_fwd(
1158
+ input: torch.Tensor,
1159
+ output: torch.Tensor,
1160
+ masked_m: torch.Tensor,
1161
+ scale: torch.Tensor,
1162
+ ) -> torch.Tensor:
1163
+ """
1164
+ Fused SiLU + Mul + Per-Tensor Quantization to FP8.
1165
+
1166
+ Args:
1167
+ input: [expert_num, token_num_padded, 2 * inner_dim]
1168
+ output: [expert_num, token_num_padded, inner_dim], dtype=torch.float8_e4m3fn
1169
+ masked_m: [expert_num], actual token count for each expert
1170
+ scale: [1] or [expert_num], quantization scale (per-tensor or per-expert)
1171
+
1172
+ Returns:
1173
+ output tensor
1174
+ """
1175
+ assert input.is_contiguous()
1176
+ assert output.is_contiguous()
1177
+ assert output.dtype == torch.float8_e4m3fn
1178
+ assert input.ndim == 3
1179
+ assert input.shape[0] == masked_m.shape[0]
1180
+ assert input.shape[-1] % 2 == 0
1181
+ assert scale.numel() == 1 or scale.shape[0] == input.shape[0]
1182
+
1183
+ expert_num = input.shape[0]
1184
+ # 3584
1185
+ inner_dim = input.shape[-1] // 2
1186
+
1187
+ BLOCK_N = 256
1188
+ BLOCK_M = 64 if expert_num < 4 else 32
1189
+ NUM_STAGES = 3
1190
+ hidden_dim_split_block_num = triton.cdiv(inner_dim, BLOCK_N)
1191
+
1192
+ grid = (hidden_dim_split_block_num, BLOCK_M, expert_num)
1193
+ finfo = torch.finfo(torch.float8_e4m3fn)
1194
+ fp8_max = finfo.max
1195
+ fp8_min = -fp8_max
1196
+
1197
+ _silu_and_mul_post_per_tensor_quant_kernel[grid](
1198
+ input,
1199
+ *input.stride(),
1200
+ output,
1201
+ *output.stride(),
1202
+ scale,
1203
+ masked_m,
1204
+ inner_dim,
1205
+ fp8_max,
1206
+ fp8_min,
1207
+ BLOCK_N=BLOCK_N,
1208
+ NUM_STAGE=NUM_STAGES,
1209
+ )
1210
+ return output