sglang 0.5.3rc2__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (408)
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +330 -156
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +8 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +4 -6
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +134 -23
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +70 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +66 -66
  69. sglang/srt/entrypoints/grpc_server.py +431 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +120 -8
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +42 -4
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +3 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +18 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/utils.py +2 -2
  93. sglang/srt/grpc/compile_proto.py +3 -3
  94. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  95. sglang/srt/grpc/health_servicer.py +189 -0
  96. sglang/srt/grpc/scheduler_launcher.py +181 -0
  97. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  98. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  99. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  100. sglang/srt/layers/activation.py +4 -1
  101. sglang/srt/layers/attention/aiter_backend.py +3 -3
  102. sglang/srt/layers/attention/ascend_backend.py +17 -1
  103. sglang/srt/layers/attention/attention_registry.py +43 -23
  104. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  105. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  106. sglang/srt/layers/attention/fla/chunk.py +0 -1
  107. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  108. sglang/srt/layers/attention/fla/index.py +0 -2
  109. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  110. sglang/srt/layers/attention/fla/utils.py +0 -3
  111. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  112. sglang/srt/layers/attention/flashattention_backend.py +12 -8
  113. sglang/srt/layers/attention/flashinfer_backend.py +248 -21
  114. sglang/srt/layers/attention/flashinfer_mla_backend.py +20 -18
  115. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  116. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  117. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  118. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  119. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  121. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  122. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  123. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  124. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  125. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  127. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  128. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  129. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  130. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  131. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  132. sglang/srt/layers/attention/nsa/utils.py +0 -1
  133. sglang/srt/layers/attention/nsa_backend.py +404 -90
  134. sglang/srt/layers/attention/triton_backend.py +208 -34
  135. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  136. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  137. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  138. sglang/srt/layers/attention/trtllm_mla_backend.py +361 -30
  139. sglang/srt/layers/attention/utils.py +11 -7
  140. sglang/srt/layers/attention/vision.py +3 -3
  141. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  142. sglang/srt/layers/communicator.py +11 -7
  143. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  146. sglang/srt/layers/dp_attention.py +17 -0
  147. sglang/srt/layers/layernorm.py +45 -15
  148. sglang/srt/layers/linear.py +9 -1
  149. sglang/srt/layers/logits_processor.py +147 -17
  150. sglang/srt/layers/modelopt_utils.py +11 -0
  151. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  152. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  153. sglang/srt/layers/moe/ep_moe/kernels.py +35 -457
  154. sglang/srt/layers/moe/ep_moe/layer.py +119 -397
  155. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  159. sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -70
  160. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  161. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  162. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  163. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  164. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  165. sglang/srt/layers/moe/router.py +51 -15
  166. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  167. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  168. sglang/srt/layers/moe/token_dispatcher/deepep.py +110 -97
  169. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  170. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  171. sglang/srt/layers/moe/topk.py +3 -2
  172. sglang/srt/layers/moe/utils.py +17 -1
  173. sglang/srt/layers/quantization/__init__.py +2 -53
  174. sglang/srt/layers/quantization/awq.py +183 -6
  175. sglang/srt/layers/quantization/awq_triton.py +29 -0
  176. sglang/srt/layers/quantization/base_config.py +20 -1
  177. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  178. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  179. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  180. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  181. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  183. sglang/srt/layers/quantization/fp8.py +84 -18
  184. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  185. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  186. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  187. sglang/srt/layers/quantization/gptq.py +0 -1
  188. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  189. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  190. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  191. sglang/srt/layers/quantization/mxfp4.py +5 -30
  192. sglang/srt/layers/quantization/petit.py +1 -1
  193. sglang/srt/layers/quantization/quark/quark.py +3 -1
  194. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  195. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  196. sglang/srt/layers/quantization/unquant.py +1 -4
  197. sglang/srt/layers/quantization/utils.py +0 -1
  198. sglang/srt/layers/quantization/w4afp8.py +51 -20
  199. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  200. sglang/srt/layers/radix_attention.py +59 -9
  201. sglang/srt/layers/rotary_embedding.py +673 -16
  202. sglang/srt/layers/sampler.py +36 -16
  203. sglang/srt/layers/sparse_pooler.py +98 -0
  204. sglang/srt/layers/utils.py +0 -1
  205. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  206. sglang/srt/lora/backend/triton_backend.py +0 -1
  207. sglang/srt/lora/eviction_policy.py +139 -0
  208. sglang/srt/lora/lora_manager.py +24 -9
  209. sglang/srt/lora/lora_registry.py +1 -1
  210. sglang/srt/lora/mem_pool.py +40 -16
  211. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  212. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  213. sglang/srt/managers/cache_controller.py +48 -17
  214. sglang/srt/managers/data_parallel_controller.py +146 -42
  215. sglang/srt/managers/detokenizer_manager.py +40 -13
  216. sglang/srt/managers/io_struct.py +66 -16
  217. sglang/srt/managers/mm_utils.py +20 -18
  218. sglang/srt/managers/multi_tokenizer_mixin.py +66 -81
  219. sglang/srt/managers/overlap_utils.py +96 -19
  220. sglang/srt/managers/schedule_batch.py +241 -511
  221. sglang/srt/managers/schedule_policy.py +15 -2
  222. sglang/srt/managers/scheduler.py +399 -499
  223. sglang/srt/managers/scheduler_metrics_mixin.py +55 -8
  224. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  225. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  226. sglang/srt/managers/scheduler_profiler_mixin.py +57 -10
  227. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  228. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  229. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  230. sglang/srt/managers/tokenizer_manager.py +378 -90
  231. sglang/srt/managers/tp_worker.py +212 -161
  232. sglang/srt/managers/utils.py +78 -2
  233. sglang/srt/mem_cache/allocator.py +7 -2
  234. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  235. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  236. sglang/srt/mem_cache/chunk_cache.py +13 -2
  237. sglang/srt/mem_cache/common.py +480 -0
  238. sglang/srt/mem_cache/evict_policy.py +16 -1
  239. sglang/srt/mem_cache/hicache_storage.py +4 -1
  240. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  241. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  242. sglang/srt/mem_cache/memory_pool.py +435 -219
  243. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  244. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  245. sglang/srt/mem_cache/radix_cache.py +53 -19
  246. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  247. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  249. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  250. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  251. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  252. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  253. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  254. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  255. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  256. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  257. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  258. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  259. sglang/srt/metrics/collector.py +31 -0
  260. sglang/srt/metrics/func_timer.py +1 -1
  261. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  262. sglang/srt/model_executor/forward_batch_info.py +28 -23
  263. sglang/srt/model_executor/model_runner.py +379 -139
  264. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  265. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  266. sglang/srt/model_loader/__init__.py +1 -1
  267. sglang/srt/model_loader/loader.py +424 -27
  268. sglang/srt/model_loader/utils.py +0 -1
  269. sglang/srt/model_loader/weight_utils.py +47 -28
  270. sglang/srt/models/apertus.py +2 -3
  271. sglang/srt/models/arcee.py +2 -2
  272. sglang/srt/models/bailing_moe.py +13 -52
  273. sglang/srt/models/bailing_moe_nextn.py +3 -4
  274. sglang/srt/models/bert.py +1 -1
  275. sglang/srt/models/deepseek_nextn.py +19 -3
  276. sglang/srt/models/deepseek_ocr.py +1516 -0
  277. sglang/srt/models/deepseek_v2.py +273 -98
  278. sglang/srt/models/dots_ocr.py +0 -2
  279. sglang/srt/models/dots_vlm.py +0 -1
  280. sglang/srt/models/dots_vlm_vit.py +1 -1
  281. sglang/srt/models/falcon_h1.py +13 -19
  282. sglang/srt/models/gemma3_mm.py +16 -0
  283. sglang/srt/models/gemma3n_mm.py +1 -2
  284. sglang/srt/models/glm4_moe.py +14 -37
  285. sglang/srt/models/glm4_moe_nextn.py +2 -2
  286. sglang/srt/models/glm4v.py +2 -1
  287. sglang/srt/models/glm4v_moe.py +5 -5
  288. sglang/srt/models/gpt_oss.py +5 -5
  289. sglang/srt/models/grok.py +10 -23
  290. sglang/srt/models/hunyuan.py +2 -7
  291. sglang/srt/models/interns1.py +0 -1
  292. sglang/srt/models/kimi_vl.py +1 -7
  293. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  294. sglang/srt/models/llama.py +2 -2
  295. sglang/srt/models/llama_eagle3.py +1 -1
  296. sglang/srt/models/longcat_flash.py +5 -22
  297. sglang/srt/models/longcat_flash_nextn.py +3 -14
  298. sglang/srt/models/mimo.py +2 -13
  299. sglang/srt/models/mimo_mtp.py +1 -2
  300. sglang/srt/models/minicpmo.py +7 -5
  301. sglang/srt/models/mixtral.py +1 -4
  302. sglang/srt/models/mllama.py +1 -1
  303. sglang/srt/models/mllama4.py +13 -3
  304. sglang/srt/models/nemotron_h.py +511 -0
  305. sglang/srt/models/olmo2.py +31 -4
  306. sglang/srt/models/opt.py +5 -5
  307. sglang/srt/models/phi.py +1 -1
  308. sglang/srt/models/phi4mm.py +1 -1
  309. sglang/srt/models/phimoe.py +0 -1
  310. sglang/srt/models/pixtral.py +0 -3
  311. sglang/srt/models/points_v15_chat.py +186 -0
  312. sglang/srt/models/qwen.py +0 -1
  313. sglang/srt/models/qwen2_5_vl.py +3 -3
  314. sglang/srt/models/qwen2_audio.py +2 -15
  315. sglang/srt/models/qwen2_moe.py +15 -12
  316. sglang/srt/models/qwen2_vl.py +5 -2
  317. sglang/srt/models/qwen3_moe.py +19 -35
  318. sglang/srt/models/qwen3_next.py +7 -12
  319. sglang/srt/models/qwen3_next_mtp.py +3 -4
  320. sglang/srt/models/qwen3_omni_moe.py +661 -0
  321. sglang/srt/models/qwen3_vl.py +37 -33
  322. sglang/srt/models/qwen3_vl_moe.py +57 -185
  323. sglang/srt/models/roberta.py +55 -3
  324. sglang/srt/models/sarashina2_vision.py +0 -1
  325. sglang/srt/models/step3_vl.py +3 -5
  326. sglang/srt/models/utils.py +11 -1
  327. sglang/srt/multimodal/processors/base_processor.py +6 -2
  328. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  329. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  330. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  331. sglang/srt/multimodal/processors/glm4v.py +1 -5
  332. sglang/srt/multimodal/processors/internvl.py +0 -2
  333. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  334. sglang/srt/multimodal/processors/mllama4.py +0 -8
  335. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  336. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  337. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  338. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  339. sglang/srt/parser/conversation.py +41 -0
  340. sglang/srt/parser/reasoning_parser.py +0 -1
  341. sglang/srt/sampling/custom_logit_processor.py +77 -2
  342. sglang/srt/sampling/sampling_batch_info.py +17 -22
  343. sglang/srt/sampling/sampling_params.py +70 -2
  344. sglang/srt/server_args.py +577 -73
  345. sglang/srt/server_args_config_parser.py +1 -1
  346. sglang/srt/single_batch_overlap.py +38 -28
  347. sglang/srt/speculative/base_spec_worker.py +34 -0
  348. sglang/srt/speculative/draft_utils.py +226 -0
  349. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  350. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  351. sglang/srt/speculative/eagle_info.py +57 -18
  352. sglang/srt/speculative/eagle_info_v2.py +458 -0
  353. sglang/srt/speculative/eagle_utils.py +138 -0
  354. sglang/srt/speculative/eagle_worker.py +83 -280
  355. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  356. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  357. sglang/srt/speculative/ngram_worker.py +12 -11
  358. sglang/srt/speculative/spec_info.py +2 -0
  359. sglang/srt/speculative/spec_utils.py +38 -3
  360. sglang/srt/speculative/standalone_worker.py +4 -14
  361. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  362. sglang/srt/two_batch_overlap.py +28 -14
  363. sglang/srt/utils/__init__.py +1 -1
  364. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  365. sglang/srt/utils/common.py +192 -47
  366. sglang/srt/utils/hf_transformers_utils.py +40 -17
  367. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  368. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  369. sglang/srt/utils/profile_merger.py +199 -0
  370. sglang/test/attention/test_flashattn_backend.py +1 -1
  371. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  372. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  373. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  374. sglang/test/few_shot_gsm8k_engine.py +2 -4
  375. sglang/test/kit_matched_stop.py +157 -0
  376. sglang/test/longbench_v2/__init__.py +1 -0
  377. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  378. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  379. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  380. sglang/test/run_eval.py +41 -0
  381. sglang/test/runners.py +2 -0
  382. sglang/test/send_one.py +42 -7
  383. sglang/test/simple_eval_common.py +3 -0
  384. sglang/test/simple_eval_gpqa.py +0 -1
  385. sglang/test/simple_eval_humaneval.py +0 -3
  386. sglang/test/simple_eval_longbench_v2.py +344 -0
  387. sglang/test/test_block_fp8.py +1 -2
  388. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  389. sglang/test/test_cutlass_moe.py +1 -2
  390. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  391. sglang/test/test_deterministic.py +232 -99
  392. sglang/test/test_deterministic_utils.py +73 -0
  393. sglang/test/test_disaggregation_utils.py +81 -0
  394. sglang/test/test_marlin_moe.py +0 -1
  395. sglang/test/test_utils.py +85 -20
  396. sglang/version.py +1 -1
  397. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/METADATA +45 -33
  398. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/RECORD +404 -345
  399. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  400. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  401. sglang/srt/speculative/build_eagle_tree.py +0 -427
  402. sglang/test/test_block_fp8_ep.py +0 -358
  403. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  404. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  405. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  406. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
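
The hunks that follow are from sglang/srt/layers/moe/ep_moe/kernels.py (item 153 above): the expert-parallel MoE preprocessing helpers now take the number of local experts instead of a global start/end expert-id range, and the standalone grouped-GEMM Triton path is removed. As orientation, here is a minimal, hypothetical sketch of calling the renamed run_moe_ep_preproess helper; the shapes, values, and the CUDA-device assumption are illustrative only and not taken from the package:

    # Hypothetical usage sketch; only run_moe_ep_preproess and its
    # (topk_ids, num_local_experts) signature come from the diff below.
    import torch

    from sglang.srt.layers.moe.ep_moe.kernels import run_moe_ep_preproess

    num_local_experts = 8  # experts hosted on this EP rank (assumed value)
    # Router output: top-2 expert ids for 16 tokens (illustrative shape).
    topk_ids = torch.randint(
        0, num_local_experts, (16, 2), device="cuda", dtype=torch.int64
    )

    reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
        topk_ids, num_local_experts
    )
    # seg_indptr has num_local_experts + 1 entries marking per-expert segment
    # boundaries in the sorted token order; src2dst maps each (token, k) slot
    # to its position in the reordered buffer.

The same num_local_experts convention is applied to the pre/post-reorder kernels and to moe_ep_deepgemm_preprocess in the hunks below.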
sglang/srt/layers/moe/ep_moe/kernels.py

@@ -1,12 +1,9 @@
 import logging
-from typing import List, Optional
 
 import torch
 import triton
 
-from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
-from sglang.srt.utils import ceil_div, dispose_tensor, is_cuda
-from sglang.utils import is_in_ci
+from sglang.srt.utils import ceil_div, is_cuda
 
 logger = logging.getLogger(__name__)
 
@@ -130,28 +127,30 @@ def deepep_run_moe_deep_preprocess(topk_ids: torch.Tensor, num_experts: int):
 
 @triton.jit
 def compute_seg_indptr_triton_kernel(reorder_topk_ids, seg_indptr, num_toks):
-    expert = tl.program_id(0)
+    expert_id_minus_1 = tl.program_id(0) - 1
     low = 0
     high = num_toks - 1
     target_location = -1
     while low <= high:
         mid = (low + high) // 2
 
-        if tl.load(reorder_topk_ids + mid) > expert:
+        if tl.load(reorder_topk_ids + mid) > expert_id_minus_1:
             high = mid - 1
         else:
             low = mid + 1
             target_location = mid
-    tl.store(seg_indptr + expert + 1, target_location + 1)
+    tl.store(seg_indptr + expert_id_minus_1 + 1, target_location + 1)
 
 
-def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int):
+def run_moe_ep_preproess(topk_ids: torch.Tensor, num_local_experts: int):
     reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
 
-    seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
+    seg_indptr = torch.zeros(
+        num_local_experts + 1, device=topk_ids.device, dtype=torch.int64
+    )
     src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)
 
-    compute_seg_indptr_triton_kernel[(num_experts,)](
+    compute_seg_indptr_triton_kernel[(num_local_experts,)](
         reorder_topk_ids, seg_indptr, topk_ids.numel()
     )
 
@@ -164,25 +163,6 @@ def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int):
     return reorder_topk_ids, src2dst, seg_indptr
 
 
-def run_cutlass_moe_ep_preproess(local_topk_ids: torch.Tensor, local_num_experts: int):
-    reorder_topk_ids, reorder_ids = torch.sort(local_topk_ids.view(-1), stable=True)
-
-    seg_indptr = torch.zeros(
-        local_num_experts + 1, device=local_topk_ids.device, dtype=torch.int64
-    )
-    src2dst = torch.empty(
-        local_topk_ids.numel(), device=local_topk_ids.device, dtype=torch.int32
-    )
-
-    BLOCK_SIZE = 512
-    grid = (triton.cdiv(local_topk_ids.numel(), BLOCK_SIZE),)
-    compute_src2dst_triton_kernel[grid](
-        reorder_ids, src2dst, local_topk_ids.numel(), BLOCK_SIZE
-    )
-
-    return reorder_topk_ids, src2dst, seg_indptr
-
-
 @triton.jit
 def pre_reorder_triton_kernel_for_cutlass_moe(
     input_ptr,
@@ -190,52 +170,13 @@ def pre_reorder_triton_kernel_for_cutlass_moe(
     src2dst_ptr,
     topk_ids_ptr,
     a1_scales_ptr,
-    num_experts,
+    num_local_experts,
     topk,
     hidden_size,
     BLOCK_SIZE: tl.constexpr,
 ):
     OutDtype = gateup_input_ptr.dtype.element_ty
 
-    src_idx = tl.program_id(0)
-    src2dst_ptr = src2dst_ptr + src_idx * topk
-    topk_ids_ptr = topk_ids_ptr + src_idx * topk
-
-    src_ptr = input_ptr + src_idx * hidden_size
-    for idx in range(topk):
-        expert_id = tl.load(topk_ids_ptr + idx)
-        if expert_id != num_experts:
-            if a1_scales_ptr is not None:
-                scale = 1.0 / tl.load(a1_scales_ptr)
-            else:
-                scale = 1.0
-
-            dst_idx = tl.load(src2dst_ptr + idx)
-            dst_ptr = gateup_input_ptr + dst_idx * hidden_size
-            for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
-                offset = start_offset + tl.arange(0, BLOCK_SIZE)
-                mask = offset < hidden_size
-                in_data = tl.load(src_ptr + offset, mask=mask).to(tl.float32)
-                out_data = (in_data * scale).to(OutDtype)
-                tl.store(dst_ptr + offset, out_data, mask=mask)
-
-
-@triton.jit
-def pre_reorder_triton_kernel(
-    input_ptr,
-    gateup_input_ptr,
-    src2dst_ptr,
-    topk_ids_ptr,
-    a1_scales_ptr,
-    start_expert_id,
-    end_expert_id,
-    topk,
-    hidden_size,
-    BLOCK_SIZE: tl.constexpr,
-    use_per_token_if_dynamic: tl.constexpr,
-):
-    OutDtype = gateup_input_ptr.dtype.element_ty
-
     src_idx_int32 = tl.program_id(0)
     src_idx = src_idx_int32.to(tl.int64)
     src2dst_ptr = src2dst_ptr + src_idx * topk
@@ -244,15 +185,11 @@ def pre_reorder_triton_kernel(
 
     vec = tl.arange(0, BLOCK_SIZE)
 
-    if a1_scales_ptr is not None and use_per_token_if_dynamic:
-        scale = 1.0 / tl.load(a1_scales_ptr + src_idx)
-
     for idx in range(topk):
         expert_id = tl.load(topk_ids_ptr + idx)
-        if expert_id >= start_expert_id and expert_id <= end_expert_id:
+        if expert_id != num_local_experts:
             if a1_scales_ptr is not None:
-                if not use_per_token_if_dynamic:
-                    scale = 1.0 / tl.load(a1_scales_ptr + expert_id - start_expert_id)
+                scale = 1.0 / tl.load(a1_scales_ptr)
             else:
                 scale = 1.0
 
@@ -267,52 +204,6 @@ def pre_reorder_triton_kernel(
                 tl.store(dst_ptr + offset, out_data, mask=mask)
 
 
-@triton.jit
-def silu_and_mul_triton_kernel(
-    gateup_output,
-    down_input,
-    hidden_size,
-    reorder_topk_ids,
-    scales,
-    start_expert_id,
-    end_expert_id,
-    BLOCK_SIZE: tl.constexpr,
-):
-    InDtype = gateup_output.dtype.element_ty
-    OutDtype = down_input.dtype.element_ty
-
-    half_hidden_size = hidden_size // 2
-
-    pid = tl.program_id(0)
-    expert_id = tl.load(reorder_topk_ids + pid)
-    if expert_id >= start_expert_id and expert_id <= end_expert_id:
-        gateup_output_ptr = gateup_output + pid * hidden_size
-        gate_output_ptr = gateup_output_ptr
-        up_output_ptr = gateup_output_ptr + half_hidden_size
-        down_input_ptr = down_input + pid * half_hidden_size
-
-        if scales is not None:
-            scale = tl.load(scales + expert_id - start_expert_id)
-            scale = (1 / scale).to(InDtype)
-        else:
-            scale = 1
-
-        for start_offset in tl.range(0, half_hidden_size, BLOCK_SIZE):
-            offset = start_offset + tl.arange(0, BLOCK_SIZE)
-            mask = offset < half_hidden_size
-
-            gate_output = tl.load(gate_output_ptr + offset, mask=mask).to(tl.float32)
-            up_output = tl.load(up_output_ptr + offset, mask=mask)
-
-            # silu & mul & quantize
-            gate_output = gate_output * tl.sigmoid(gate_output)
-            gate_output = gate_output.to(InDtype)
-
-            silu_mul_output = gate_output * up_output * scale
-            silu_mul_output = silu_mul_output.to(OutDtype)
-            tl.store(down_input_ptr + offset, silu_mul_output, mask=mask)
-
-
 # copy from https://github.com/ModelTC/lightllm/blob/a000ab69098654df4731f5b12587dd4e7f0a4f41/lightllm/common/fused_moe/moe_silu_and_mul_mix_quant_ep.py
 @triton.jit
 def _silu_and_mul_post_quant_kernel(
@@ -461,84 +352,15 @@ def silu_and_mul_masked_post_quant_fwd(
 
 
 @triton.jit
-def tanh(x):
-    return 2 * tl.sigmoid(2 * x) - 1
-
-
-@triton.jit
-def gelu_and_mul_triton_kernel(
-    gateup_output,
-    down_input,
-    hidden_size,
-    reorder_topk_ids,
-    scales,
-    start_expert_id,
-    end_expert_id,
-    BLOCK_SIZE: tl.constexpr,
-):
-    InDtype = gateup_output.dtype.element_ty
-    OutDtype = down_input.dtype.element_ty
-
-    half_hidden_size = hidden_size // 2
-
-    pid = tl.program_id(0)
-    expert_id = tl.load(reorder_topk_ids + pid)
-    if expert_id >= start_expert_id and expert_id <= end_expert_id:
-        gateup_output_ptr = gateup_output + pid * hidden_size
-        gate_output_ptr = gateup_output_ptr
-        up_output_ptr = gateup_output_ptr + half_hidden_size
-        down_input_ptr = down_input + pid * half_hidden_size
-
-        if scales is not None:
-            scale = tl.load(scales + expert_id - start_expert_id)
-            scale = (1 / scale).to(InDtype)
-        else:
-            scale = 1
-
-        for start_offset in tl.range(0, half_hidden_size, BLOCK_SIZE):
-            offset = start_offset + tl.arange(0, BLOCK_SIZE)
-            mask = offset < half_hidden_size
-
-            gate_output = tl.load(gate_output_ptr + offset, mask=mask).to(tl.float32)
-            up_output = tl.load(up_output_ptr + offset, mask=mask)
-
-            # gelu & mul & quantize
-            # https://pytorch.org/docs/stable/generated/torch.nn.GELU.html
-            # sqrt(2/pi)
-            kAlpha = 0.7978845608028654
-            gate_output = (
-                0.5
-                * gate_output
-                * (
-                    1
-                    + tanh(
-                        kAlpha
-                        * (
-                            gate_output
-                            + 0.044715 * gate_output * gate_output * gate_output
-                        )
-                    )
-                )
-            )
-            gate_output = gate_output.to(InDtype)
-
-            gelu_mul_output = gate_output * up_output * scale
-            gelu_mul_output = gelu_mul_output.to(OutDtype)
-            tl.store(down_input_ptr + offset, gelu_mul_output, mask=mask)
-
-
-@triton.jit
-def post_reorder_triton_kernel(
+def post_reorder_triton_kernel_for_cutlass_moe(
     down_output_ptr,
     output_ptr,
     src2dst_ptr,
     topk_ids_ptr,
     topk_weights_ptr,
-    start_expert_id,
-    end_expert_id,
     topk,
+    num_local_experts,
     hidden_size,
-    dst_start,
     BLOCK_SIZE: tl.constexpr,
 ):
     InDtype = down_output_ptr.dtype.element_ty
@@ -549,7 +371,6 @@ def post_reorder_triton_kernel(
     topk_ids_ptr = topk_ids_ptr + src_idx * topk
     topk_weights_ptr = topk_weights_ptr + src_idx * topk
 
-    computed = False
     store_ptr = output_ptr + src_idx * hidden_size
 
     vec = tl.arange(0, BLOCK_SIZE)
@@ -561,37 +382,25 @@ def post_reorder_triton_kernel(
         sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype)
         for idx in range(topk):
             expert_id = tl.load(topk_ids_ptr + idx)
-            if expert_id >= start_expert_id and expert_id <= end_expert_id:
-                computed = True
+            if expert_id != num_local_experts:
                 dst_idx_int32 = tl.load(src2dst_ptr + idx)
                 dst_idx = dst_idx_int32.to(tl.int64)
-                dst_idx = dst_idx - dst_start
                 weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype)
                 load_ptr = down_output_ptr + dst_idx * hidden_size
                 in_data = tl.load(load_ptr + offset, mask=mask)
                 sum_vec += in_data * weigh_scale
         tl.store(store_ptr + offset, sum_vec, mask=mask)
 
-    if computed == False:
-        for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
-            offset = start_offset + vec
-            mask = offset < hidden_size
-            tl.store(
-                store_ptr + offset, tl.zeros([BLOCK_SIZE], dtype=InDtype), mask=mask
-            )
-
 
 @triton.jit
-def post_reorder_triton_kernel_for_cutlass_moe(
+def post_reorder_triton_kernel(
     down_output_ptr,
     output_ptr,
     src2dst_ptr,
     topk_ids_ptr,
     topk_weights_ptr,
-    num_experts,
     topk,
     hidden_size,
-    dst_start,
     BLOCK_SIZE: tl.constexpr,
 ):
     InDtype = down_output_ptr.dtype.element_ty
@@ -613,10 +422,9 @@ def post_reorder_triton_kernel_for_cutlass_moe(
         sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype)
         for idx in range(topk):
            expert_id = tl.load(topk_ids_ptr + idx)
-            if expert_id != num_experts:
+            if expert_id > 0:
                 dst_idx_int32 = tl.load(src2dst_ptr + idx)
                 dst_idx = dst_idx_int32.to(tl.int64)
-                dst_idx = dst_idx - dst_start
                 weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype)
                 load_ptr = down_output_ptr + dst_idx * hidden_size
                 in_data = tl.load(load_ptr + offset, mask=mask)
@@ -624,232 +432,6 @@ def post_reorder_triton_kernel_for_cutlass_moe(
         tl.store(store_ptr + offset, sum_vec, mask=mask)
 
 
-@triton.jit
-def compute_m_range(
-    pid,
-    batch_size,
-    seg_indptr,
-    weight_indices,
-    m_num_tiles_indptr,
-    BLOCK_SIZE_M: tl.constexpr,
-):
-    idx = 0
-    for bs in range(batch_size):
-        tiles = tl.load(m_num_tiles_indptr + bs)
-        if pid >= tiles:
-            idx = bs
-
-    idx_start = tl.load(m_num_tiles_indptr + idx)
-
-    m_range_start = tl.load(seg_indptr + idx) + (pid - idx_start) * BLOCK_SIZE_M
-    m_range_end = min(tl.load(seg_indptr + idx + 1), m_range_start + BLOCK_SIZE_M)
-    expert_id = tl.load(weight_indices + idx)
-    return m_range_start, m_range_end, expert_id
-
-
-@triton.jit
-def grouped_gemm_triton_kernel(
-    a,
-    b,
-    c,
-    batch_size,
-    N,
-    K,
-    seg_indptr,
-    weight_indices,
-    m_num_tiles_indptr,
-    scale_a,
-    scale_b,
-    use_fp8_w8a8: tl.constexpr,
-    group_n: tl.constexpr,
-    group_k: tl.constexpr,
-    a_stride_0: tl.constexpr,
-    b_stride_0: tl.constexpr,
-    b_stride_1: tl.constexpr,
-    as_stride_0: tl.constexpr,
-    as_stride_1: tl.constexpr,
-    bs_stride_0: tl.constexpr,
-    bs_stride_2: tl.constexpr,
-    bs_stride_1: tl.constexpr,
-    use_per_token_if_dynamic: tl.constexpr,
-    BLOCK_SIZE_M: tl.constexpr,
-    BLOCK_SIZE_N: tl.constexpr,
-    BLOCK_SIZE_K: tl.constexpr,
-):
-    c_dtype = c.dtype.element_ty
-
-    pid_m = tl.program_id(0)
-    pid_n = tl.program_id(1)
-    total_m_block = tl.load(m_num_tiles_indptr + batch_size)
-    if pid_m >= total_m_block:
-        return
-
-    m_range_start, m_range_end, expert_id = compute_m_range(
-        pid_m, batch_size, seg_indptr, weight_indices, m_num_tiles_indptr, BLOCK_SIZE_M
-    )
-    if m_range_end - m_range_start == 0:
-        return
-
-    n_range_start = pid_n * BLOCK_SIZE_N
-    n_range_end = min(n_range_start + BLOCK_SIZE_N, N)
-
-    offs_am = tl.arange(0, BLOCK_SIZE_M)
-    offs_bn = tl.arange(0, BLOCK_SIZE_N)
-
-    offs_am = tl.where(offs_am < m_range_end - m_range_start, offs_am, 0)
-    offs_bn = tl.where(offs_bn < n_range_end - n_range_start, offs_bn, 0)
-    offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M)
-    offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)
-    offs_k = tl.arange(0, BLOCK_SIZE_K)
-
-    a_ptr = a + (m_range_start + offs_am[:, None]) * a_stride_0 + offs_k[None, :]
-    b_ptr = b + (
-        (expert_id * b_stride_0)
-        + (n_range_start + offs_bn[:, None]) * b_stride_1
-        + offs_k[None, :]
-    )
-
-    if group_k > 0 and group_n > 0:
-        a_scale_ptrs = scale_a + (m_range_start + offs_am[:, None]) * as_stride_0
-        offs_bsn = (n_range_start + offs_bn) // group_n
-        b_scale_ptrs = scale_b + (expert_id * bs_stride_0) + offs_bsn * bs_stride_1
-
-    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
-    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
-        a_tile = tl.load(
-            a_ptr, mask=offs_k[None, :] < (K - k * BLOCK_SIZE_K), other=0.0
-        )
-        b_tile = tl.load(
-            b_ptr, mask=offs_k[None, :] < (K - k * BLOCK_SIZE_K), other=0.0
-        )
-
-        if group_k > 0 and group_n > 0:
-            k_start = k * BLOCK_SIZE_K
-            offs_ks = k_start // group_k
-            a_scale = tl.load(a_scale_ptrs + offs_ks * as_stride_1)
-            b_scale = tl.load(b_scale_ptrs + offs_ks * bs_stride_2)
-            accumulator += tl.dot(a_tile, b_tile.T) * a_scale * b_scale[None, :]
-        else:
-            accumulator = tl.dot(a_tile, b_tile.T, accumulator)
-        a_ptr += BLOCK_SIZE_K
-        b_ptr += BLOCK_SIZE_K
-
-    if use_fp8_w8a8 and not (group_k > 0 and group_n > 0):
-        if use_per_token_if_dynamic:
-            scale_a_value = tl.load(scale_a + (m_range_start + offs_am[:, None]))
-        else:
-            scale_a_value = tl.load(scale_a + expert_id)
-        scale_b_value = tl.load(scale_b + expert_id)
-        accumulator *= scale_a_value * scale_b_value
-
-    c_tile = accumulator.to(c_dtype)
-
-    offs_cm = m_range_start + tl.arange(0, BLOCK_SIZE_M)
-    offs_cn = n_range_start + tl.arange(0, BLOCK_SIZE_N)
-    c_ptr = c + offs_cm[:, None] * N + offs_cn[None, :]
-    c_mask = (offs_cm[:, None] < m_range_end) & (offs_cn[None, :] < n_range_end)
-    tl.store(c_ptr, c_tile, mask=c_mask)
-
-
-@triton.jit
-def compute_m_num_tiles_indptr(
-    m_num_tiles_indptr, seg_indptr, batch_size: tl.constexpr, BLOCK_SIZE_M: tl.constexpr
-):
-    for bs in range(batch_size):
-        m = tl.load(seg_indptr + bs + 1) - tl.load(seg_indptr + bs)
-        cur_num_tiles = tl.cdiv(m, BLOCK_SIZE_M)
-        pre_num_tiles = tl.load(m_num_tiles_indptr + bs)
-        tl.store(m_num_tiles_indptr + bs + 1, pre_num_tiles + cur_num_tiles)
-
-
-def grouped_gemm_triton(
-    a: torch.Tensor,
-    b: torch.Tensor,
-    c: torch.Tensor,
-    batch_size: int,
-    weight_column_major: bool,
-    seg_indptr: Optional[torch.Tensor] = None,
-    weight_indices: Optional[torch.Tensor] = None,
-    use_fp8_w8a8: bool = False,
-    scale_a: torch.Tensor = None,
-    scale_b: torch.Tensor = None,
-    block_shape: Optional[List[int]] = None,
-    c_dtype=None,
-    use_per_token_if_dynamic: bool = True,
-):
-    assert weight_column_major == True  # TODO: more
-    if use_fp8_w8a8 and block_shape is None:
-        assert scale_a is not None and scale_b is not None
-
-    if block_shape is not None:
-        a_original = a
-
-        assert len(block_shape) == 2
-        block_n, block_k = block_shape[0], block_shape[1]
-        a, scale_a = per_token_group_quant_fp8(a, block_k)
-
-        assert triton.cdiv(a.shape[-1], block_k) == scale_a.shape[-1]
-        assert triton.cdiv(b.shape[-2], block_n) == scale_b.shape[-2]
-        assert triton.cdiv(b.shape[-1], block_k) == scale_b.shape[-1]
-
-        dispose_tensor(a_original)
-
-    # TODO: adjust config or tune kernel
-    # Reduce block size to prevent L40 shared memory overflow.
-    config = {
-        "BLOCK_SIZE_M": 64,
-        "BLOCK_SIZE_N": 32,
-        "BLOCK_SIZE_K": 128,
-    }
-
-    m_num_tiles_indptr = torch.zeros(batch_size + 1, device=a.device, dtype=torch.int64)
-    compute_m_num_tiles_indptr[(1,)](
-        m_num_tiles_indptr, seg_indptr, batch_size, config["BLOCK_SIZE_M"]
-    )
-
-    if c is None:
-        assert c_dtype is not None
-        c = torch.empty(a.shape[0], b.shape[1], device=a.device, dtype=c_dtype)
-
-    grid = lambda META: (
-        triton.cdiv(a.size(0), META["BLOCK_SIZE_M"]) + batch_size,
-        triton.cdiv(b.size(1), META["BLOCK_SIZE_N"]),
-    )
-
-    if use_fp8_w8a8 and block_shape is None and use_per_token_if_dynamic:
-        assert (
-            scale_a.shape[0] == a.shape[0]
-        ), f"scale_a.shape: {scale_a.shape}, a.shape: {a.shape}"
-
-    grouped_gemm_triton_kernel[grid](
-        a,
-        b,
-        c,
-        batch_size,
-        b.size(1),
-        b.size(2),
-        seg_indptr,
-        weight_indices,
-        m_num_tiles_indptr,
-        scale_a,
-        scale_b,
-        use_fp8_w8a8,
-        0 if block_shape is None else block_shape[0],
-        0 if block_shape is None else block_shape[1],
-        a.stride(0),
-        b.stride(0),
-        b.stride(1),
-        scale_a.stride(0) if scale_a is not None and scale_a.ndim == 2 else 0,
-        scale_a.stride(1) if scale_a is not None and scale_a.ndim == 2 else 0,
-        scale_b.stride(0) if scale_b is not None and scale_b.ndim >= 2 else 0,
-        scale_b.stride(2) if scale_b is not None and scale_b.ndim == 3 else 0,
-        scale_b.stride(1) if scale_b is not None and scale_b.ndim >= 2 else 0,
-        use_per_token_if_dynamic,
-        **config,
-    )
-    return c
-
-
 @triton.jit
 def _fwd_kernel_ep_scatter_1(
     num_recv_tokens_per_expert,
@@ -984,7 +566,9 @@ def ep_scatter(
         scale_hidden_size = ceil_div(scale_hidden_size, 4)
 
     assert m_indices.shape[0] % BLOCK_E == 0
-    assert recv_x_scale.dtype == output_tensor_scale.dtype
+    assert (
+        recv_x_scale.dtype == output_tensor_scale.dtype
+    ), f"recv_x_scale.dtype: {recv_x_scale.dtype}, output_tensor_scale.dtype: {output_tensor_scale.dtype}"
     assert recv_x_scale.shape[1] == output_tensor_scale.shape[1] == scale_hidden_size
 
     _fwd_kernel_ep_scatter_1[(grid,)](
@@ -1234,7 +818,7 @@ def deepgemm_compute_src2dst_triton_kernel(
     mask = dst_id < num_toks
     src_id = tl.load(reorder_ids + dst_id, mask=mask)
     expert_id = tl.load(topk_ids + src_id, mask=(src_id < num_toks))
-    expert_dst_start = tl.load(seg_indptr + expert_id)
+    expert_dst_start = tl.load(seg_indptr + expert_id, mask=(expert_id >= 0))
     expert_dst_offset = dst_id - expert_dst_start
     dst_id = expert_id * m_max + expert_dst_offset
     tl.store(src2dst + src_id, dst_id, mask=mask)
@@ -1248,10 +832,7 @@ def fill_gateup_input_triton_kernel(
     gateup_input_scale_ptr,
     src2dst_ptr,
     topk_ids_ptr,
-    start_expert_id,
-    end_expert_id,
     topk,
-    m_max,
     hidden_size,
     scale_size,
     BLOCK_SIZE: tl.constexpr,
@@ -1267,10 +848,9 @@ def fill_gateup_input_triton_kernel(
     vec = tl.arange(0, BLOCK_SIZE)
     for idx in range(topk):
         expert_id = tl.load(topk_ids_ptr + idx)
-        if expert_id >= start_expert_id and expert_id <= end_expert_id:
+        if expert_id >= 0:
            dst_idx_int32 = tl.load(src2dst_ptr + idx)
            dst_idx = dst_idx_int32.to(tl.int64)
-            dst_idx = dst_idx - start_expert_id * m_max
            dst_ptr = gateup_input_ptr + dst_idx * hidden_size
            for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
                offset = start_offset + vec
@@ -1287,31 +867,31 @@ def fill_gateup_input_triton_kernel(
 
 def moe_ep_deepgemm_preprocess(
     topk_ids: torch.Tensor,
-    num_experts: int,
+    num_local_experts: int,
     hidden_states: torch.Tensor,
     top_k: int,
-    start_expert_id,
-    end_expert_id,
     block_shape,
     output_dtype: torch.dtype = torch.float8_e4m3fn,
 ):
     reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
-    seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
+    seg_indptr = torch.zeros(
+        num_local_experts + 1, device=topk_ids.device, dtype=torch.int64
+    )
     src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)
-    masked_m = torch.zeros(num_experts, device=topk_ids.device, dtype=torch.int32)
+    masked_m = torch.empty(num_local_experts, device=topk_ids.device, dtype=torch.int32)
 
-    compute_seg_indptr_triton_kernel[(num_experts,)](
+    compute_seg_indptr_triton_kernel[(num_local_experts + 1,)](
        reorder_topk_ids, seg_indptr, topk_ids.numel()
    )
 
     grid = lambda meta: (triton.cdiv(topk_ids.numel(), meta["BLOCK_SIZE"]),)
-    compute_masked_m_triton_kernel[(num_experts,)](seg_indptr, masked_m)
+    compute_masked_m_triton_kernel[(num_local_experts,)](seg_indptr, masked_m)
 
     # For masked grouped GEMM, shape M should be multiple of the block M (current block M: {block_m}) https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/jit_kernels/m_grouped_gemm.py#L165
-    m_max = (hidden_states.size(0) + 255) // 256 * 256
-    expected_m = (topk_ids.numel() + num_experts - 1) // num_experts
+    m_max = (hidden_states.size(0) // 256 + 1) * 256
+    expected_m = (topk_ids.numel() - 1) // num_local_experts + 1
     gateup_input = torch.empty(
-        (int(end_expert_id - start_expert_id + 1), m_max, hidden_states.size(1)),
+        (num_local_experts, m_max, hidden_states.size(1)),
         device=hidden_states.device,
         dtype=output_dtype,
     )
@@ -1330,6 +910,8 @@ def moe_ep_deepgemm_preprocess(
         block_shape = [128, 128]
     assert len(block_shape) == 2
     block_n, block_k = block_shape[0], block_shape[1]
+
+    # TODO: fuse this with the preprocess
     hidden_states, scale = per_token_group_quant_fp8(hidden_states, block_k)
 
     gateup_input_scale = torch.empty(
@@ -1345,18 +927,14 @@ def moe_ep_deepgemm_preprocess(
         gateup_input_scale,
         src2dst,
         topk_ids,
-        start_expert_id,
-        end_expert_id,
         top_k,
-        m_max,
         hidden_states.size(1),
         scale.size(1),
         BLOCK_SIZE=1024,
     )
 
     return (
-        m_max,
-        masked_m[start_expert_id : (end_expert_id + 1)],
+        masked_m,
         expected_m,
         src2dst,
         gateup_input,