sglang 0.5.2rc2__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (396)
  1. sglang/bench_one_batch.py +7 -11
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +474 -142
  4. sglang/compile_deep_gemm.py +3 -0
  5. sglang/global_config.py +2 -2
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +10 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +314 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/mamba_utils.py +117 -0
  17. sglang/srt/configs/model_config.py +228 -92
  18. sglang/srt/configs/nemotron_h.py +286 -0
  19. sglang/srt/configs/qwen3_next.py +294 -0
  20. sglang/srt/configs/qwen3_vl.py +586 -0
  21. sglang/srt/connector/__init__.py +8 -1
  22. sglang/srt/connector/remote_instance.py +82 -0
  23. sglang/srt/constrained/base_grammar_backend.py +49 -12
  24. sglang/srt/constrained/llguidance_backend.py +0 -1
  25. sglang/srt/constrained/outlines_backend.py +0 -1
  26. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  27. sglang/srt/constrained/xgrammar_backend.py +30 -9
  28. sglang/srt/custom_op.py +11 -1
  29. sglang/srt/debug_utils/dump_comparator.py +81 -44
  30. sglang/srt/debug_utils/dump_loader.py +97 -0
  31. sglang/srt/debug_utils/dumper.py +21 -6
  32. sglang/srt/debug_utils/text_comparator.py +73 -11
  33. sglang/srt/disaggregation/ascend/conn.py +2 -2
  34. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  35. sglang/srt/disaggregation/base/conn.py +1 -1
  36. sglang/srt/disaggregation/common/conn.py +279 -108
  37. sglang/srt/disaggregation/decode.py +78 -37
  38. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  39. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  40. sglang/srt/disaggregation/fake/conn.py +1 -1
  41. sglang/srt/disaggregation/mini_lb.py +6 -445
  42. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  43. sglang/srt/disaggregation/nixl/conn.py +373 -68
  44. sglang/srt/disaggregation/prefill.py +53 -49
  45. sglang/srt/disaggregation/utils.py +40 -54
  46. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  47. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  48. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  49. sglang/srt/distributed/parallel_state.py +156 -80
  50. sglang/srt/entrypoints/engine.py +59 -18
  51. sglang/srt/entrypoints/grpc_request_manager.py +842 -0
  52. sglang/srt/entrypoints/grpc_server.py +950 -0
  53. sglang/srt/entrypoints/http_server.py +179 -60
  54. sglang/srt/entrypoints/openai/protocol.py +265 -29
  55. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  56. sglang/srt/entrypoints/openai/serving_chat.py +213 -122
  57. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  58. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  59. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  60. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  61. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  62. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  63. sglang/srt/environ.py +289 -0
  64. sglang/srt/eplb/eplb_manager.py +2 -2
  65. sglang/srt/eplb/expert_distribution.py +26 -13
  66. sglang/srt/eplb/expert_location.py +38 -8
  67. sglang/srt/eplb/expert_location_updater.py +1 -1
  68. sglang/srt/function_call/base_format_detector.py +3 -6
  69. sglang/srt/function_call/ebnf_composer.py +11 -9
  70. sglang/srt/function_call/function_call_parser.py +17 -8
  71. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  72. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  73. sglang/srt/function_call/json_array_parser.py +63 -0
  74. sglang/srt/function_call/kimik2_detector.py +17 -4
  75. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  76. sglang/srt/function_call/utils.py +96 -5
  77. sglang/srt/grpc/__init__.py +1 -0
  78. sglang/srt/grpc/compile_proto.py +245 -0
  79. sglang/srt/grpc/sglang_scheduler_pb2.py +119 -0
  80. sglang/srt/grpc/sglang_scheduler_pb2.pyi +492 -0
  81. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +327 -0
  82. sglang/srt/layers/activation.py +143 -9
  83. sglang/srt/layers/attention/aiter_backend.py +14 -15
  84. sglang/srt/layers/attention/ascend_backend.py +115 -9
  85. sglang/srt/layers/attention/attention_registry.py +215 -0
  86. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  87. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  88. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  89. sglang/srt/layers/attention/fla/chunk.py +242 -0
  90. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  91. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  92. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  93. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  94. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  95. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  96. sglang/srt/layers/attention/fla/index.py +37 -0
  97. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  98. sglang/srt/layers/attention/fla/layernorm_gated.py +343 -0
  99. sglang/srt/layers/attention/fla/op.py +66 -0
  100. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  101. sglang/srt/layers/attention/fla/utils.py +331 -0
  102. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  103. sglang/srt/layers/attention/flashattention_backend.py +40 -8
  104. sglang/srt/layers/attention/flashinfer_backend.py +341 -204
  105. sglang/srt/layers/attention/flashinfer_mla_backend.py +28 -28
  106. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  107. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  108. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +708 -0
  109. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  110. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  111. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +974 -0
  112. sglang/srt/layers/attention/mamba/mamba.py +577 -0
  113. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  114. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  115. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  116. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  117. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  119. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  120. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  121. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  122. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  123. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  124. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  125. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  126. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  127. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  128. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  129. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  130. sglang/srt/layers/attention/nsa/utils.py +24 -0
  131. sglang/srt/layers/attention/nsa_backend.py +887 -0
  132. sglang/srt/layers/attention/tbo_backend.py +6 -6
  133. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  134. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  135. sglang/srt/layers/attention/triton_backend.py +57 -7
  136. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  137. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  138. sglang/srt/layers/attention/vision.py +58 -0
  139. sglang/srt/layers/attention/wave_backend.py +4 -4
  140. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  141. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  142. sglang/srt/layers/communicator.py +8 -0
  143. sglang/srt/layers/dp_attention.py +41 -2
  144. sglang/srt/layers/elementwise.py +3 -1
  145. sglang/srt/layers/layernorm.py +34 -15
  146. sglang/srt/layers/linear.py +55 -7
  147. sglang/srt/layers/logits_processor.py +180 -18
  148. sglang/srt/layers/modelopt_utils.py +11 -0
  149. sglang/srt/layers/moe/__init__.py +2 -1
  150. sglang/srt/layers/moe/cutlass_w4a8_moe.py +21 -24
  151. sglang/srt/layers/moe/ep_moe/kernels.py +33 -454
  152. sglang/srt/layers/moe/ep_moe/layer.py +248 -333
  153. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  154. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  155. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  164. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  165. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  166. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  167. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  168. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  169. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  170. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  171. sglang/srt/layers/moe/fused_moe_triton/layer.py +68 -72
  172. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  173. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  174. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  175. sglang/srt/layers/moe/moe_runner/runner.py +83 -0
  176. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  177. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  178. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  179. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  180. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  181. sglang/srt/layers/moe/topk.py +30 -9
  182. sglang/srt/layers/moe/utils.py +29 -7
  183. sglang/srt/layers/parameter.py +23 -6
  184. sglang/srt/layers/quantization/__init__.py +1 -1
  185. sglang/srt/layers/quantization/awq.py +19 -7
  186. sglang/srt/layers/quantization/base_config.py +11 -6
  187. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  188. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  189. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  190. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  191. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  192. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  193. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  194. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  195. sglang/srt/layers/quantization/fp8.py +155 -60
  196. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  197. sglang/srt/layers/quantization/gptq.py +25 -17
  198. sglang/srt/layers/quantization/modelopt_quant.py +191 -56
  199. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  200. sglang/srt/layers/quantization/mxfp4.py +74 -42
  201. sglang/srt/layers/quantization/quark/quark.py +3 -1
  202. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  203. sglang/srt/layers/quantization/unquant.py +135 -47
  204. sglang/srt/layers/quantization/w4afp8.py +28 -33
  205. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  206. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  207. sglang/srt/layers/rotary_embedding.py +78 -31
  208. sglang/srt/layers/sampler.py +213 -21
  209. sglang/srt/layers/utils.py +23 -0
  210. sglang/srt/lora/backend/base_backend.py +50 -8
  211. sglang/srt/lora/backend/chunked_backend.py +348 -0
  212. sglang/srt/lora/backend/triton_backend.py +99 -5
  213. sglang/srt/lora/layers.py +32 -0
  214. sglang/srt/lora/lora.py +8 -3
  215. sglang/srt/lora/lora_manager.py +44 -118
  216. sglang/srt/lora/mem_pool.py +25 -11
  217. sglang/srt/lora/triton_ops/__init__.py +4 -0
  218. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  219. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  220. sglang/srt/lora/utils.py +22 -11
  221. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  222. sglang/srt/managers/cache_controller.py +199 -301
  223. sglang/srt/managers/data_parallel_controller.py +115 -80
  224. sglang/srt/managers/detokenizer_manager.py +19 -15
  225. sglang/srt/managers/disagg_service.py +46 -0
  226. sglang/srt/managers/io_struct.py +340 -109
  227. sglang/srt/managers/mm_utils.py +44 -6
  228. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  229. sglang/srt/managers/multimodal_processor.py +1 -2
  230. sglang/srt/managers/overlap_utils.py +55 -0
  231. sglang/srt/managers/schedule_batch.py +343 -212
  232. sglang/srt/managers/schedule_policy.py +145 -18
  233. sglang/srt/managers/scheduler.py +653 -273
  234. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  235. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  236. sglang/srt/managers/scheduler_output_processor_mixin.py +255 -108
  237. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  238. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  239. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  240. sglang/srt/managers/tokenizer_manager.py +579 -674
  241. sglang/srt/managers/tp_worker.py +96 -26
  242. sglang/srt/managers/utils.py +1 -45
  243. sglang/srt/mem_cache/allocator.py +21 -22
  244. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  245. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  246. sglang/srt/mem_cache/chunk_cache.py +9 -2
  247. sglang/srt/mem_cache/evict_policy.py +23 -0
  248. sglang/srt/mem_cache/hicache_storage.py +43 -24
  249. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  250. sglang/srt/mem_cache/memory_pool.py +651 -80
  251. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  252. sglang/srt/mem_cache/radix_cache.py +227 -73
  253. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  254. sglang/srt/mem_cache/storage/__init__.py +10 -0
  255. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  256. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  257. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  258. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  259. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  260. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  261. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  262. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  263. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  264. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  265. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  266. sglang/srt/mem_cache/swa_radix_cache.py +93 -48
  267. sglang/srt/metrics/collector.py +511 -132
  268. sglang/srt/metrics/func_timer.py +2 -7
  269. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  270. sglang/srt/metrics/utils.py +8 -1
  271. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  272. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  273. sglang/srt/model_executor/forward_batch_info.py +74 -46
  274. sglang/srt/model_executor/model_runner.py +455 -176
  275. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  276. sglang/srt/model_loader/__init__.py +10 -4
  277. sglang/srt/model_loader/loader.py +319 -10
  278. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  279. sglang/srt/model_loader/weight_utils.py +161 -3
  280. sglang/srt/models/apertus.py +686 -0
  281. sglang/srt/models/bailing_moe.py +820 -217
  282. sglang/srt/models/bailing_moe_nextn.py +168 -0
  283. sglang/srt/models/deepseek_nextn.py +6 -1
  284. sglang/srt/models/deepseek_v2.py +607 -130
  285. sglang/srt/models/dots_ocr.py +173 -0
  286. sglang/srt/models/dots_vlm.py +174 -0
  287. sglang/srt/models/dots_vlm_vit.py +337 -0
  288. sglang/srt/models/ernie4.py +1 -1
  289. sglang/srt/models/falcon_h1.py +578 -0
  290. sglang/srt/models/gemma3_causal.py +0 -2
  291. sglang/srt/models/gemma3_mm.py +17 -1
  292. sglang/srt/models/gemma3n_mm.py +2 -2
  293. sglang/srt/models/glm4_moe.py +4 -4
  294. sglang/srt/models/glm4_moe_nextn.py +2 -2
  295. sglang/srt/models/glm4v.py +5 -3
  296. sglang/srt/models/glm4v_moe.py +4 -1
  297. sglang/srt/models/gpt_oss.py +8 -31
  298. sglang/srt/models/grok.py +5 -13
  299. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  300. sglang/srt/models/llama.py +4 -0
  301. sglang/srt/models/llama4.py +9 -0
  302. sglang/srt/models/llama_eagle3.py +13 -0
  303. sglang/srt/models/longcat_flash.py +3 -3
  304. sglang/srt/models/longcat_flash_nextn.py +1 -1
  305. sglang/srt/models/mixtral.py +1 -3
  306. sglang/srt/models/mllama4.py +50 -4
  307. sglang/srt/models/nemotron_h.py +514 -0
  308. sglang/srt/models/opt.py +637 -0
  309. sglang/srt/models/qwen2_5_vl.py +29 -5
  310. sglang/srt/models/qwen2_audio.py +1 -1
  311. sglang/srt/models/qwen2_moe.py +120 -13
  312. sglang/srt/models/qwen2_vl.py +1 -1
  313. sglang/srt/models/qwen3.py +18 -3
  314. sglang/srt/models/qwen3_moe.py +32 -4
  315. sglang/srt/models/qwen3_next.py +1069 -0
  316. sglang/srt/models/qwen3_next_mtp.py +112 -0
  317. sglang/srt/models/qwen3_vl.py +787 -0
  318. sglang/srt/models/qwen3_vl_moe.py +471 -0
  319. sglang/srt/models/registry.py +15 -3
  320. sglang/srt/models/sarashina2_vision.py +269 -0
  321. sglang/srt/models/solar.py +505 -0
  322. sglang/srt/models/starcoder2.py +357 -0
  323. sglang/srt/models/step3_vl.py +1 -1
  324. sglang/srt/models/torch_native_llama.py +9 -2
  325. sglang/srt/models/utils.py +55 -0
  326. sglang/srt/multimodal/processors/base_processor.py +15 -7
  327. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  328. sglang/srt/multimodal/processors/glm4v.py +9 -9
  329. sglang/srt/multimodal/processors/internvl.py +153 -129
  330. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  331. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  332. sglang/srt/offloader.py +27 -3
  333. sglang/srt/parser/jinja_template_utils.py +6 -0
  334. sglang/srt/sampling/sampling_batch_info.py +49 -26
  335. sglang/srt/sampling/sampling_params.py +7 -0
  336. sglang/srt/server_args.py +1051 -285
  337. sglang/srt/server_args_config_parser.py +146 -0
  338. sglang/srt/single_batch_overlap.py +151 -0
  339. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  340. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  341. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  342. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  343. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  344. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  345. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  346. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  347. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  348. sglang/srt/speculative/eagle_worker.py +98 -29
  349. sglang/srt/speculative/ngram_info.py +428 -0
  350. sglang/srt/speculative/ngram_worker.py +246 -0
  351. sglang/srt/speculative/spec_info.py +52 -0
  352. sglang/srt/speculative/spec_utils.py +605 -0
  353. sglang/srt/speculative/standalone_worker.py +109 -0
  354. sglang/srt/torch_memory_saver_adapter.py +5 -7
  355. sglang/srt/tracing/trace.py +578 -0
  356. sglang/srt/two_batch_overlap.py +9 -5
  357. sglang/srt/utils/__init__.py +2 -0
  358. sglang/srt/{utils.py → utils/common.py} +451 -77
  359. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +55 -5
  360. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  361. sglang/srt/utils/rpd_utils.py +452 -0
  362. sglang/srt/utils/slow_rank_detector.py +71 -0
  363. sglang/srt/warmup.py +8 -4
  364. sglang/srt/weight_sync/utils.py +2 -2
  365. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  366. sglang/test/get_logits_ut.py +57 -0
  367. sglang/test/longbench_v2/__init__.py +1 -0
  368. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  369. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  370. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  371. sglang/test/run_eval.py +119 -11
  372. sglang/test/runners.py +5 -1
  373. sglang/test/simple_eval_common.py +5 -2
  374. sglang/test/simple_eval_longbench_v2.py +332 -0
  375. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  376. sglang/test/test_block_fp8.py +2 -2
  377. sglang/test/test_cutlass_moe.py +24 -6
  378. sglang/test/test_cutlass_w4a8_moe.py +9 -19
  379. sglang/test/test_deterministic.py +313 -0
  380. sglang/test/test_deterministic_utils.py +81 -0
  381. sglang/test/test_disaggregation_utils.py +140 -0
  382. sglang/test/test_fp4_moe.py +370 -1
  383. sglang/test/test_programs.py +1 -1
  384. sglang/test/test_utils.py +407 -8
  385. sglang/utils.py +21 -1
  386. sglang/version.py +1 -1
  387. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +69 -124
  388. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +392 -251
  389. sglang/srt/disaggregation/launch_lb.py +0 -118
  390. sglang/srt/managers/tp_worker_overlap_thread.py +0 -296
  391. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  392. sglang/test/test_block_fp8_ep.py +0 -358
  393. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  394. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
  395. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
  396. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
sglang/test/simple_eval_mmmu_vlm.py
@@ -0,0 +1,441 @@
+"""
+MMMU evaluation for VLMs using the run_eval simple-evals interface.
+
+"""
+
+from __future__ import annotations
+
+import base64
+import io
+from typing import List, Optional, Tuple
+
+from datasets import concatenate_datasets, load_dataset
+from PIL import Image
+
+from sglang.test import simple_eval_common as common
+from sglang.test.simple_eval_common import (
+    HTML_JINJA,
+    Eval,
+    EvalResult,
+    SamplerBase,
+    SingleEvalResult,
+    map_with_progress,
+)
+
+
+class MMMUVLMEval(Eval):
+    DOMAIN_CAT2SUB_CAT = {
+        "Art and Design": ["Art", "Art_Theory", "Design", "Music"],
+        "Business": ["Accounting", "Economics", "Finance", "Manage", "Marketing"],
+        "Science": ["Biology", "Chemistry", "Geography", "Math", "Physics"],
+        "Health and Medicine": [
+            "Basic_Medical_Science",
+            "Clinical_Medicine",
+            "Diagnostics_and_Laboratory_Medicine",
+            "Pharmacy",
+            "Public_Health",
+        ],
+        "Humanities and Social Science": [
+            "History",
+            "Literature",
+            "Sociology",
+            "Psychology",
+        ],
+        "Tech and Engineering": [
+            "Agriculture",
+            "Architecture_and_Engineering",
+            "Computer_Science",
+            "Electronics",
+            "Energy_and_Power",
+            "Materials",
+            "Mechanical_Engineering",
+        ],
+    }
+
+    def __init__(
+        self, num_examples: Optional[int] = 100, num_threads: int = 32, seed: int = 42
+    ):
+        """Create MMMU VLM eval (100 fixed validation samples across all subjects by default)."""
+        self.num_examples = num_examples
+        self.num_threads = num_threads
+        self.seed = seed
+        # Prepare samples deterministically across all MMMU subjects (validation split)
+        self.samples = self._prepare_mmmu_samples(self.num_examples)
+
+    @staticmethod
+    def _to_data_uri(image: Image.Image) -> str:
+        if image.mode == "RGBA":
+            image = image.convert("RGB")
+        buf = io.BytesIO()
+        image.save(buf, format="PNG")
+        b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
+        return f"data:image/png;base64,{b64}"
+
+    @staticmethod
+    def _build_mc_mapping(options: List[str]) -> Tuple[dict, List[str]]:
+        index2ans = {}
+        all_choices = []
+        ch = ord("A")
+        for opt in options:
+            letter = chr(ch)
+            index2ans[letter] = opt
+            all_choices.append(letter)
+            ch += 1
+        return index2ans, all_choices
+
+    def _prepare_mmmu_samples(self, k: int) -> List[dict]:
+        # Subjects and domains copied from MMMU data_utils to categorize results
+        subjects: List[str] = []
+        for subs in self.DOMAIN_CAT2SUB_CAT.values():
+            subjects.extend(subs)
+
+        # Load validation split of each subject
+        datasets = []
+        for subj in subjects:
+            try:
+                d = load_dataset("MMMU/MMMU", subj, split="validation")
+                # attach subject info as an extra column
+                d = d.add_column("__subject__", [subj] * len(d))
+                datasets.append(d)
+            except Exception:
+                continue
+        if not datasets:
+            raise RuntimeError("Failed to load MMMU datasets")
+
+        merged = concatenate_datasets(datasets)
+
+        # Deterministic selection: sort by id (fallback to subject+index)
+        def _key(idx):
+            ex = merged[idx]
+            return str(ex.get("id", f"{ex['__subject__']}:{idx}"))
+
+        order = sorted(range(len(merged)), key=_key)
+        picked_indices = order[:k]
+
+        samples: List[dict] = []
+        for idx in picked_indices:
+            ex = merged[idx]
+            subject = ex["__subject__"]
+            image = ex.get("image_1")
+            if image is None or not hasattr(image, "convert"):
+                continue
+            data_uri = self._to_data_uri(image)
+            question = ex.get("question", "")
+            answer = ex.get("answer")
+            raw_options = ex.get("options")
+            question_type = "open"
+            index2ans = None
+            all_choices = None
+            options = None
+            if raw_options:
+                try:
+                    options = (
+                        raw_options
+                        if isinstance(raw_options, list)
+                        else list(eval(raw_options))
+                    )
+                    if isinstance(options, list) and len(options) > 0:
+                        index2ans, all_choices = self._build_mc_mapping(options)
+                        question_type = "multiple-choice"
+                except Exception:
+                    options = None
+
+            # Build final textual prompt; include choices if MC
+            prompt_text = f"Question: {question}\n\n"
+            if options:
+                letters = [chr(ord("A") + i) for i in range(len(options))]
+                for letter, opt in zip(letters, options):
+                    prompt_text += f"{letter}) {opt}\n"
+            prompt_text += "\nAnswer: "
+
+            samples.append(
+                {
+                    "id": ex.get("id", f"{subject}:{idx}"),
+                    "final_input_prompt": prompt_text,
+                    "image_data": data_uri,
+                    "answer": answer,
+                    "question_type": question_type,
+                    "index2ans": index2ans,
+                    "all_choices": all_choices,
+                    "category": subject,
+                }
+            )
+
+        return samples
+
+    @staticmethod
+    def _split_prompt_for_image(prompt: str) -> tuple[str, str]:
+        """Split a prompt containing an inline image tag into prefix and suffix.
+
+        If no tag is present, treat the whole prompt as prefix and empty suffix.
+        """
+        if "<" in prompt and ">" in prompt:
+            prefix = prompt.split("<")[0]
+            suffix = prompt.split(">", 1)[1]
+            return prefix, suffix
+        return prompt, ""
+
+    @staticmethod
+    def build_chat_messages_from_prompt(prompt: str, image_data) -> List:
+        """Build OpenAI-compatible chat messages for a prompt with one inline image.
+
+        The image is inserted between the text before and after the inline image tag.
+        """
+        # Build a vision+text message for OpenAI-compatible API
+        prefix, suffix = MMMUVLMEval._split_prompt_for_image(prompt)
+
+        content: List[dict] = []
+        if prefix:
+            content.append({"type": "text", "text": prefix})
+        content.append({"type": "image_url", "image_url": {"url": image_data}})
+        if suffix:
+            content.append({"type": "text", "text": suffix})
+        prompt_messages = [{"role": "user", "content": content}]
+
+        return prompt_messages
+
+    def __call__(self, sampler: SamplerBase) -> EvalResult:
+        def fn(sample: dict):
+            prompt = sample["final_input_prompt"]
+            image_data = sample["image_data"]
+            prompt_messages = MMMUVLMEval.build_chat_messages_from_prompt(
+                prompt, image_data
+            )
+
+            # Sample
+            response_text = sampler(prompt_messages)
+
+            # Parse and score
+            gold = sample["answer"]
+            if (
+                sample["question_type"] == "multiple-choice"
+                and sample["all_choices"]
+                and sample["index2ans"]
+            ):
+                pred = _parse_multi_choice_response(
+                    response_text, sample["all_choices"], sample["index2ans"]
+                )
+                score = 1.0 if (gold is not None and pred == gold) else 0.0
+                extracted_answer = pred
+            else:
+                parsed_list = _parse_open_response(response_text)
+                score = (
+                    1.0 if (gold is not None and _eval_open(gold, parsed_list)) else 0.0
+                )
+                extracted_answer = ", ".join(map(str, parsed_list))
+
+            html_rendered = common.jinja_env.from_string(HTML_JINJA).render(
+                prompt_messages=prompt_messages,
+                next_message=dict(content=response_text, role="assistant"),
+                score=score,
+                correct_answer=gold,
+                extracted_answer=extracted_answer,
+            )
+
+            convo = prompt_messages + [dict(content=response_text, role="assistant")]
+            return SingleEvalResult(
+                html=html_rendered,
+                score=score,
+                metrics={"__category__": sample["category"]},
+                convo=convo,
+            )
+
+        results = map_with_progress(fn, self.samples, self.num_threads)
+
+        # Build category table and overall accuracy
+        # Gather per-sample correctness and category
+        per_cat_total: dict[str, int] = {}
+        per_cat_correct: dict[str, int] = {}
+        htmls = []
+        convos = []
+        scores: List[float] = []
+        for r in results:
+            # __category__ stored under metrics
+            cat = r.metrics.get("__category__") if r.metrics else None
+            if cat is None:
+                cat = "Unknown"
+            per_cat_total[cat] = per_cat_total.get(cat, 0) + 1
+            if r.score:
+                per_cat_correct[cat] = per_cat_correct.get(cat, 0) + 1
+            htmls.append(r.html)
+            convos.append(r.convo)
+            if r.score is not None:
+                scores.append(r.score)
+
+        evaluation_result = {}
+        for cat, tot in per_cat_total.items():
+            corr = per_cat_correct.get(cat, 0)
+            acc = (corr / tot) if tot > 0 else 0.0
+            evaluation_result[cat] = {"acc": round(acc, 3), "num_example": tot}
+
+        printable_results = {}
+        # Domains first
+        for domain, cats in self.DOMAIN_CAT2SUB_CAT.items():
+            acc_sum = 0.0
+            num_sum = 0
+            for cat in cats:
+                if cat in evaluation_result:
+                    acc_sum += (
+                        evaluation_result[cat]["acc"]
+                        * evaluation_result[cat]["num_example"]
+                    )
+                    num_sum += evaluation_result[cat]["num_example"]
+            if num_sum > 0:
+                printable_results[f"Overall-{domain}"] = {
+                    "num": num_sum,
+                    "acc": round(acc_sum / num_sum, 3),
+                }
+            # add each sub-category row if present
+            for cat in cats:
+                if cat in evaluation_result:
+                    printable_results[cat] = {
+                        "num": evaluation_result[cat]["num_example"],
+                        "acc": evaluation_result[cat]["acc"],
+                    }
+
+        # Overall
+        total_num = sum(v["num_example"] for v in evaluation_result.values())
+        overall_acc = (
+            sum(v["acc"] * v["num_example"] for v in evaluation_result.values())
+            / total_num
+            if total_num > 0
+            else 0.0
+        )
+        printable_results["Overall"] = {"num": total_num, "acc": round(overall_acc, 3)}
+
+        # Build EvalResult
+        return EvalResult(
+            score=overall_acc, metrics=printable_results, htmls=htmls, convos=convos
+        )
+
+
+def _parse_multi_choice_response(
+    response: str, all_choices: List[str], index2ans: dict
+) -> str:
+    # loosely adapted from benchmark mmmu eval
+    for char in [",", ".", "!", "?", ";", ":", "'"]:
+        response = response.strip(char)
+    response = " " + response + " "
+
+    # Prefer explicit letter with bracket e.g. (A)
+    candidates: List[str] = []
+    for choice in all_choices:
+        if f"({choice})" in response:
+            candidates.append(choice)
+    if not candidates:
+        for choice in all_choices:
+            if f" {choice} " in response:
+                candidates.append(choice)
+    if not candidates and len(response.split()) > 5:
+        # try match by option text
+        for idx, ans in index2ans.items():
+            if ans and ans.lower() in response.lower():
+                candidates.append(idx)
+    if not candidates:
+        # fallback to first choice
+        return all_choices[0]
+    if len(candidates) == 1:
+        return candidates[0]
+    # choose the last occurrence
+    starts = []
+    for can in candidates:
+        pos = response.rfind(f"({can})")
+        if pos == -1:
+            pos = response.rfind(f" {can} ")
+        if pos == -1 and index2ans.get(can):
+            pos = response.lower().rfind(index2ans[can].lower())
+        starts.append(pos)
+    return candidates[int(max(range(len(starts)), key=lambda i: starts[i]))]
+
+
+def _check_is_number(s: str) -> bool:
+    try:
+        float(s.replace(",", ""))
+        return True
+    except Exception:
+        return False
+
+
+def _normalize_str(s: str):
+    s = s.strip()
+    if _check_is_number(s):
+        s = s.replace(",", "")
+        try:
+            v = round(float(s), 2)
+            return [v]
+        except Exception:
+            return [s.lower()]
+    return [s.lower()] if len(s) > 1 else [" " + s, s + " "]
+
+
+def _extract_numbers(s: str) -> List[str]:
+    import re as _re
+
+    pattern_commas = r"-?\b\d{1,3}(?:,\d{3})+\b"
+    pattern_scientific = r"-?\d+(?:\.\d+)?[eE][+-]?\d+"
+    pattern_simple = r"-?(?:\d+\.\d+|\.\d+|\d+\b)(?![eE][+-]?\d+)(?![,\d])"
+    return (
+        _re.findall(pattern_commas, s)
+        + _re.findall(pattern_scientific, s)
+        + _re.findall(pattern_simple, s)
+    )
+
+
+def _parse_open_response(response: str) -> List[str]:
+    import re as _re
+
+    def get_key_subresponses(resp: str) -> List[str]:
+        resp = resp.strip().strip(".").lower()
+        subs = _re.split(r"\.\s(?=[A-Z])|\n", resp)
+        indicators = [
+            "could be ",
+            "so ",
+            "is ",
+            "thus ",
+            "therefore ",
+            "final ",
+            "answer ",
+            "result ",
+        ]
+        keys = []
+        for i, s in enumerate(subs):
+            cands = [*indicators]
+            if i == len(subs) - 1:
+                cands.append("=")
+            shortest = None
+            for ind in cands:
+                if ind in s:
+                    part = s.split(ind)[-1].strip()
+                    if not shortest or len(part) < len(shortest):
+                        shortest = part
+            if shortest and shortest not in [":", ",", ".", "!", "?", ";", ":", "'"]:
+                keys.append(shortest)
+        return keys or [resp]
+
+    key_resps = get_key_subresponses(response)
+    pred_list = key_resps.copy()
+    for r in key_resps:
+        pred_list.extend(_extract_numbers(r))
+    out = []
+    for x in pred_list:
+        out.extend(_normalize_str(x))
+    # dedup
+    return list(dict.fromkeys(out))
+
+
+def _eval_open(gold, preds: List[str]) -> bool:
+    if isinstance(gold, list):
+        norm_answers = []
+        for ans in gold:
+            norm_answers.extend(_normalize_str(ans))
+    else:
+        norm_answers = _normalize_str(gold)
+    for p in preds:
+        if isinstance(p, str):
+            for na in norm_answers:
+                if isinstance(na, str) and na in p:
+                    return True
+        else:
+            if p in norm_answers:
+                return True
+    return False
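Note: the hunk above adds a self-contained eval module. As a minimal, hedged sketch of how it plugs into the simple-evals interface (the EchoSampler class and the sample values below are illustrative assumptions, not part of the diff; any callable that accepts an OpenAI-style message list and returns a string can serve as the sampler):

    from sglang.test.simple_eval_mmmu_vlm import MMMUVLMEval, _parse_multi_choice_response

    # The pure parsing helpers can be exercised without a model or dataset download.
    choices = ["A", "B", "C", "D"]
    index2ans = {"A": "Paris", "B": "Rome", "C": "Berlin", "D": "Madrid"}
    assert _parse_multi_choice_response("The answer is (B) Rome.", choices, index2ans) == "B"

    class EchoSampler:
        """Stand-in sampler (hypothetical): always answers '(A)'. Replace with a real VLM sampler."""

        def __call__(self, messages):
            return "(A)"

    # Constructing MMMUVLMEval downloads the MMMU validation split via `datasets`.
    result = MMMUVLMEval(num_examples=4, num_threads=1)(EchoSampler())
    print(result.score, result.metrics.get("Overall"))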
sglang/test/test_block_fp8.py
@@ -621,11 +621,11 @@ class TestW8A8BlockFP8BatchedDeepGemm(CustomTestCase):
             w_s,
         )
 
-        from deep_gemm import m_grouped_gemm_fp8_fp8_bf16_nt_masked
+        from deep_gemm import fp8_m_grouped_gemm_nt_masked
 
         with torch.inference_mode():
             ref_out = torch_w8a8_block_fp8_bmm(a, a_s, w, w_s, block_size, dtype)
-            m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs, rhs, oe, masked_m, expected_m)
+            fp8_m_grouped_gemm_nt_masked(lhs, rhs, oe, masked_m, expected_m)
             out = oe[:, :M, :]
 
             self.assertTrue(
sglang/test/test_cutlass_moe.py
@@ -9,6 +9,7 @@ from transformers import AutoConfig
 from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8
 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
 from sglang.srt.layers.moe.moe_runner.base import MoeRunnerConfig
+from sglang.srt.layers.moe.topk import StandardTopKOutput
 
 
 # Copy from: https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/utils.py
@@ -21,7 +22,7 @@ def calc_diff(x, y):
 
 def get_model_config(tp_size: int):
     config = AutoConfig.from_pretrained(
-        "deepseek-ai/deepseek-R1", trust_remote_code=True
+        "deepseek-ai/Deepseek-R1", trust_remote_code=True
     )
     E = config.n_routed_experts
     topk = config.num_experts_per_tok
@@ -152,14 +153,31 @@ def run_test(tp_size, batch_size, model_config, check=False):
         problem_sizes2,
     )
 
+    topk_output = StandardTopKOutput(
+        topk_weights=topk_weights,
+        topk_ids=topk_ids,
+        router_logits=torch.randn(
+            (batch_size, topk), device=topk_weights.device, dtype=dtype
+        ),
+    )
+
+    moe_runner_config = MoeRunnerConfig(
+        num_experts=E,
+        top_k=topk,
+        hidden_size=H,
+        intermediate_size_per_partition=I,
+        params_dtype=dtype,
+        activation="silu",
+        inplace=False,
+    )
+
     # Note: Triton expects non-transposed weights
-    moe_config = MoeRunnerConfig(inplace=False)
     triton_lambda = lambda: fused_experts(
         x,
         w1,
         w2,
-        (topk_weights, topk_ids, "dummy"),
-        moe_config,
+        topk_output,
+        moe_runner_config,
         use_fp8_w8a8=True,
         w1_scale=w1_scale,
         w2_scale=w2_scale,
@@ -224,8 +242,8 @@ def run_test(tp_size, batch_size, model_config, check=False):
         x,
         w1,  # Original shape
         w2,  # Original shape
-        (topk_weights, topk_ids, "dummy"),
-        moe_config,
+        topk_output,
+        moe_runner_config,
         use_fp8_w8a8=True,
         w1_scale=w1_scale,
         w2_scale=w2_scale,
sglang/test/test_cutlass_w4a8_moe.py
@@ -120,7 +120,7 @@ def test_cutlass_w4a8_moe(M, N, K, E, tp_size, use_ep_moe, topk, group_size, dty
     )
     topk_weights, topk_ids, _ = topk_output
     expert_map = torch.arange(E, dtype=torch.int32, device=device)
-    expert_map[local_e:] = E
+    expert_map[local_e:] = -1
 
     output = cutlass_moe(
         a,
@@ -138,9 +138,7 @@ def test_cutlass_w4a8_moe(M, N, K, E, tp_size, use_ep_moe, topk, group_size, dty
         c_strides2,
         s_strides13,
         s_strides2,
-        0,
-        local_e - 1,
-        E,
+        local_e,
         a1_scale,
         a2_scale,
         expert_map,
@@ -178,7 +176,7 @@ def cutlass_moe(
     w1_scale: torch.Tensor,
     w2_scale: torch.Tensor,
     topk_weights: torch.Tensor,
-    topk_ids_: torch.Tensor,
+    topk_ids: torch.Tensor,
     a_strides1: torch.Tensor,
     b_strides1: torch.Tensor,
     c_strides1: torch.Tensor,
@@ -187,40 +185,32 @@ def cutlass_moe(
     c_strides2: torch.Tensor,
     s_strides13: torch.Tensor,
     s_strides2: torch.Tensor,
-    start_expert_id: int,
-    end_expert_id: int,
-    E: int,
+    num_local_experts: int,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
     expert_map: Optional[torch.Tensor] = None,
     apply_router_weight_on_input: bool = False,
 ):
-    local_topk_ids = topk_ids_
-    local_topk_ids = torch.where(expert_map[topk_ids_] != E, expert_map[topk_ids_], E)
+    topk_ids = expert_map[topk_ids]
     device = a.device
 
-    local_num_experts = end_expert_id - start_expert_id + 1
     expert_offsets = torch.empty(
-        (local_num_experts + 1), dtype=torch.int32, device=device
+        (num_local_experts + 1), dtype=torch.int32, device=device
    )
     problem_sizes1 = torch.empty(
-        (local_num_experts, 3), dtype=torch.int32, device=device
+        (num_local_experts, 3), dtype=torch.int32, device=device
     )
     problem_sizes2 = torch.empty(
-        (local_num_experts, 3), dtype=torch.int32, device=device
+        (num_local_experts, 3), dtype=torch.int32, device=device
     )
     return cutlass_w4a8_moe(
-        start_expert_id,
-        end_expert_id,
-        E,
         a,
         w1_q,
         w2_q,
         w1_scale,
         w2_scale,
         topk_weights,
-        topk_ids_,
-        local_topk_ids,
+        topk_ids,
         a_strides1,
         b_strides1,
         c_strides1,
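For clarity, a minimal sketch (values assumed, not taken from the package) of the expert-map convention the updated w4a8 test now relies on: non-local experts map to -1 and top-k ids are remapped with a single gather instead of a torch.where against an E sentinel.

    import torch

    E, local_e = 8, 4                      # total experts / experts owned by this rank (assumed values)
    expert_map = torch.arange(E, dtype=torch.int32)
    expert_map[local_e:] = -1              # non-local experts marked -1 (previously the sentinel was E)

    topk_ids = torch.tensor([[0, 5], [3, 7]])
    local_topk_ids = expert_map[topk_ids]  # gather: local ids kept, non-local become -1
    print(local_topk_ids)                  # [[0, -1], [3, -1]]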