sglang 0.5.3rc2__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (419)
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +378 -160
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +10 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +105 -10
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +136 -25
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +63 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +83 -80
  69. sglang/srt/entrypoints/grpc_server.py +430 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +195 -102
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +58 -6
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +33 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +20 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/minimax_m2.py +367 -0
  93. sglang/srt/function_call/utils.py +2 -2
  94. sglang/srt/grpc/compile_proto.py +3 -3
  95. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  96. sglang/srt/grpc/health_servicer.py +189 -0
  97. sglang/srt/grpc/scheduler_launcher.py +181 -0
  98. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  99. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  100. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  101. sglang/srt/layers/activation.py +10 -1
  102. sglang/srt/layers/attention/aiter_backend.py +3 -3
  103. sglang/srt/layers/attention/ascend_backend.py +17 -1
  104. sglang/srt/layers/attention/attention_registry.py +43 -23
  105. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  106. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  107. sglang/srt/layers/attention/fla/chunk.py +0 -1
  108. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  109. sglang/srt/layers/attention/fla/index.py +0 -2
  110. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  111. sglang/srt/layers/attention/fla/utils.py +0 -3
  112. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  113. sglang/srt/layers/attention/flashattention_backend.py +24 -10
  114. sglang/srt/layers/attention/flashinfer_backend.py +258 -22
  115. sglang/srt/layers/attention/flashinfer_mla_backend.py +38 -28
  116. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  117. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  118. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  119. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  121. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  122. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  123. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  124. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  125. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  127. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  128. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  129. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  130. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  131. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  132. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  133. sglang/srt/layers/attention/nsa/utils.py +0 -1
  134. sglang/srt/layers/attention/nsa_backend.py +404 -90
  135. sglang/srt/layers/attention/triton_backend.py +208 -34
  136. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  137. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  138. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  139. sglang/srt/layers/attention/trtllm_mla_backend.py +362 -43
  140. sglang/srt/layers/attention/utils.py +89 -7
  141. sglang/srt/layers/attention/vision.py +3 -3
  142. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  143. sglang/srt/layers/communicator.py +12 -7
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +5 -9
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  146. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  147. sglang/srt/layers/dp_attention.py +17 -0
  148. sglang/srt/layers/layernorm.py +64 -19
  149. sglang/srt/layers/linear.py +9 -1
  150. sglang/srt/layers/logits_processor.py +152 -17
  151. sglang/srt/layers/modelopt_utils.py +11 -0
  152. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  153. sglang/srt/layers/moe/cutlass_w4a8_moe.py +351 -21
  154. sglang/srt/layers/moe/ep_moe/kernels.py +229 -457
  155. sglang/srt/layers/moe/ep_moe/layer.py +154 -625
  156. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  160. sglang/srt/layers/moe/fused_moe_triton/layer.py +79 -73
  161. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +25 -46
  162. sglang/srt/layers/moe/moe_runner/deep_gemm.py +569 -0
  163. sglang/srt/layers/moe/moe_runner/runner.py +6 -0
  164. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  165. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  166. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  167. sglang/srt/layers/moe/router.py +51 -15
  168. sglang/srt/layers/moe/token_dispatcher/__init__.py +14 -4
  169. sglang/srt/layers/moe/token_dispatcher/base.py +12 -6
  170. sglang/srt/layers/moe/token_dispatcher/deepep.py +127 -110
  171. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  172. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  173. sglang/srt/layers/moe/topk.py +7 -6
  174. sglang/srt/layers/moe/utils.py +20 -5
  175. sglang/srt/layers/quantization/__init__.py +5 -58
  176. sglang/srt/layers/quantization/awq.py +183 -9
  177. sglang/srt/layers/quantization/awq_triton.py +29 -0
  178. sglang/srt/layers/quantization/base_config.py +27 -1
  179. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  180. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  185. sglang/srt/layers/quantization/fp8.py +152 -81
  186. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  187. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  188. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  189. sglang/srt/layers/quantization/gguf.py +566 -0
  190. sglang/srt/layers/quantization/gptq.py +0 -1
  191. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  192. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  193. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  194. sglang/srt/layers/quantization/mxfp4.py +35 -68
  195. sglang/srt/layers/quantization/petit.py +1 -1
  196. sglang/srt/layers/quantization/quark/quark.py +3 -1
  197. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  199. sglang/srt/layers/quantization/unquant.py +23 -48
  200. sglang/srt/layers/quantization/utils.py +0 -1
  201. sglang/srt/layers/quantization/w4afp8.py +87 -20
  202. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  203. sglang/srt/layers/radix_attention.py +62 -9
  204. sglang/srt/layers/rotary_embedding.py +686 -17
  205. sglang/srt/layers/sampler.py +47 -16
  206. sglang/srt/layers/sparse_pooler.py +98 -0
  207. sglang/srt/layers/utils.py +0 -1
  208. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  209. sglang/srt/lora/backend/triton_backend.py +0 -1
  210. sglang/srt/lora/eviction_policy.py +139 -0
  211. sglang/srt/lora/lora_manager.py +24 -9
  212. sglang/srt/lora/lora_registry.py +1 -1
  213. sglang/srt/lora/mem_pool.py +40 -16
  214. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  215. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  216. sglang/srt/managers/cache_controller.py +48 -17
  217. sglang/srt/managers/data_parallel_controller.py +146 -42
  218. sglang/srt/managers/detokenizer_manager.py +40 -13
  219. sglang/srt/managers/io_struct.py +69 -16
  220. sglang/srt/managers/mm_utils.py +20 -18
  221. sglang/srt/managers/multi_tokenizer_mixin.py +83 -82
  222. sglang/srt/managers/overlap_utils.py +96 -19
  223. sglang/srt/managers/schedule_batch.py +241 -511
  224. sglang/srt/managers/schedule_policy.py +15 -2
  225. sglang/srt/managers/scheduler.py +420 -514
  226. sglang/srt/managers/scheduler_metrics_mixin.py +73 -18
  227. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  228. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  229. sglang/srt/managers/scheduler_profiler_mixin.py +60 -14
  230. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  231. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  232. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  233. sglang/srt/managers/tokenizer_manager.py +375 -95
  234. sglang/srt/managers/tp_worker.py +212 -161
  235. sglang/srt/managers/utils.py +78 -2
  236. sglang/srt/mem_cache/allocator.py +7 -2
  237. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  238. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  239. sglang/srt/mem_cache/chunk_cache.py +13 -2
  240. sglang/srt/mem_cache/common.py +480 -0
  241. sglang/srt/mem_cache/evict_policy.py +16 -1
  242. sglang/srt/mem_cache/hicache_storage.py +11 -2
  243. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  244. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  245. sglang/srt/mem_cache/memory_pool.py +517 -219
  246. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  247. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  248. sglang/srt/mem_cache/radix_cache.py +53 -19
  249. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  250. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  251. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  252. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  253. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  254. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  255. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  256. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  257. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  259. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  260. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  261. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  262. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  263. sglang/srt/metrics/collector.py +31 -0
  264. sglang/srt/metrics/func_timer.py +1 -1
  265. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  266. sglang/srt/model_executor/forward_batch_info.py +71 -25
  267. sglang/srt/model_executor/model_runner.py +362 -270
  268. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  269. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +549 -0
  270. sglang/srt/model_loader/__init__.py +1 -1
  271. sglang/srt/model_loader/loader.py +424 -27
  272. sglang/srt/model_loader/utils.py +0 -1
  273. sglang/srt/model_loader/weight_utils.py +47 -28
  274. sglang/srt/models/apertus.py +2 -3
  275. sglang/srt/models/arcee.py +2 -2
  276. sglang/srt/models/bailing_moe.py +13 -52
  277. sglang/srt/models/bailing_moe_nextn.py +3 -4
  278. sglang/srt/models/bert.py +1 -1
  279. sglang/srt/models/deepseek_nextn.py +19 -3
  280. sglang/srt/models/deepseek_ocr.py +1516 -0
  281. sglang/srt/models/deepseek_v2.py +418 -140
  282. sglang/srt/models/dots_ocr.py +0 -2
  283. sglang/srt/models/dots_vlm.py +0 -1
  284. sglang/srt/models/dots_vlm_vit.py +1 -1
  285. sglang/srt/models/falcon_h1.py +13 -19
  286. sglang/srt/models/gemma3_mm.py +16 -0
  287. sglang/srt/models/gemma3n_mm.py +1 -2
  288. sglang/srt/models/glm4_moe.py +327 -382
  289. sglang/srt/models/glm4_moe_nextn.py +6 -16
  290. sglang/srt/models/glm4v.py +2 -1
  291. sglang/srt/models/glm4v_moe.py +32 -199
  292. sglang/srt/models/gpt_oss.py +5 -5
  293. sglang/srt/models/grok.py +10 -23
  294. sglang/srt/models/hunyuan.py +2 -7
  295. sglang/srt/models/interns1.py +0 -1
  296. sglang/srt/models/kimi_vl.py +1 -7
  297. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  298. sglang/srt/models/llama.py +2 -2
  299. sglang/srt/models/llama_eagle3.py +1 -1
  300. sglang/srt/models/longcat_flash.py +5 -22
  301. sglang/srt/models/longcat_flash_nextn.py +3 -14
  302. sglang/srt/models/mimo.py +2 -13
  303. sglang/srt/models/mimo_mtp.py +1 -2
  304. sglang/srt/models/minicpmo.py +7 -5
  305. sglang/srt/models/minimax_m2.py +922 -0
  306. sglang/srt/models/mixtral.py +1 -4
  307. sglang/srt/models/mllama.py +1 -1
  308. sglang/srt/models/mllama4.py +13 -3
  309. sglang/srt/models/nemotron_h.py +511 -0
  310. sglang/srt/models/nvila.py +355 -0
  311. sglang/srt/models/nvila_lite.py +184 -0
  312. sglang/srt/models/olmo2.py +31 -4
  313. sglang/srt/models/opt.py +5 -5
  314. sglang/srt/models/phi.py +1 -1
  315. sglang/srt/models/phi4mm.py +1 -1
  316. sglang/srt/models/phimoe.py +0 -1
  317. sglang/srt/models/pixtral.py +0 -3
  318. sglang/srt/models/points_v15_chat.py +186 -0
  319. sglang/srt/models/qwen.py +0 -1
  320. sglang/srt/models/qwen2.py +22 -1
  321. sglang/srt/models/qwen2_5_vl.py +3 -3
  322. sglang/srt/models/qwen2_audio.py +2 -15
  323. sglang/srt/models/qwen2_moe.py +15 -12
  324. sglang/srt/models/qwen2_vl.py +5 -2
  325. sglang/srt/models/qwen3.py +34 -4
  326. sglang/srt/models/qwen3_moe.py +19 -37
  327. sglang/srt/models/qwen3_next.py +7 -12
  328. sglang/srt/models/qwen3_next_mtp.py +3 -4
  329. sglang/srt/models/qwen3_omni_moe.py +661 -0
  330. sglang/srt/models/qwen3_vl.py +37 -33
  331. sglang/srt/models/qwen3_vl_moe.py +57 -185
  332. sglang/srt/models/roberta.py +55 -3
  333. sglang/srt/models/sarashina2_vision.py +0 -1
  334. sglang/srt/models/step3_vl.py +3 -5
  335. sglang/srt/models/utils.py +11 -1
  336. sglang/srt/multimodal/processors/base_processor.py +7 -2
  337. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  338. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  339. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  340. sglang/srt/multimodal/processors/glm4v.py +2 -6
  341. sglang/srt/multimodal/processors/internvl.py +0 -2
  342. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  343. sglang/srt/multimodal/processors/mllama4.py +0 -8
  344. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  345. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  346. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  347. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  348. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  349. sglang/srt/parser/conversation.py +41 -0
  350. sglang/srt/parser/reasoning_parser.py +28 -2
  351. sglang/srt/sampling/custom_logit_processor.py +77 -2
  352. sglang/srt/sampling/sampling_batch_info.py +17 -22
  353. sglang/srt/sampling/sampling_params.py +70 -2
  354. sglang/srt/server_args.py +846 -163
  355. sglang/srt/server_args_config_parser.py +1 -1
  356. sglang/srt/single_batch_overlap.py +36 -31
  357. sglang/srt/speculative/base_spec_worker.py +34 -0
  358. sglang/srt/speculative/draft_utils.py +226 -0
  359. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  360. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  361. sglang/srt/speculative/eagle_info.py +57 -18
  362. sglang/srt/speculative/eagle_info_v2.py +458 -0
  363. sglang/srt/speculative/eagle_utils.py +138 -0
  364. sglang/srt/speculative/eagle_worker.py +83 -280
  365. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  366. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  367. sglang/srt/speculative/ngram_worker.py +12 -11
  368. sglang/srt/speculative/spec_info.py +2 -0
  369. sglang/srt/speculative/spec_utils.py +38 -3
  370. sglang/srt/speculative/standalone_worker.py +4 -14
  371. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  372. sglang/srt/two_batch_overlap.py +28 -14
  373. sglang/srt/utils/__init__.py +1 -1
  374. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  375. sglang/srt/utils/common.py +272 -82
  376. sglang/srt/utils/hf_transformers_utils.py +44 -17
  377. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  378. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  379. sglang/srt/utils/profile_merger.py +199 -0
  380. sglang/test/attention/test_flashattn_backend.py +1 -1
  381. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  382. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  383. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  384. sglang/test/few_shot_gsm8k_engine.py +2 -4
  385. sglang/test/kit_matched_stop.py +157 -0
  386. sglang/test/longbench_v2/__init__.py +1 -0
  387. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  388. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  389. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  390. sglang/test/run_eval.py +41 -0
  391. sglang/test/runners.py +2 -0
  392. sglang/test/send_one.py +42 -7
  393. sglang/test/simple_eval_common.py +3 -0
  394. sglang/test/simple_eval_gpqa.py +0 -1
  395. sglang/test/simple_eval_humaneval.py +0 -3
  396. sglang/test/simple_eval_longbench_v2.py +344 -0
  397. sglang/test/test_block_fp8.py +1 -2
  398. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  399. sglang/test/test_cutlass_moe.py +1 -2
  400. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  401. sglang/test/test_deterministic.py +463 -107
  402. sglang/test/test_deterministic_utils.py +74 -0
  403. sglang/test/test_disaggregation_utils.py +81 -0
  404. sglang/test/test_marlin_moe.py +0 -1
  405. sglang/test/test_utils.py +85 -20
  406. sglang/version.py +1 -1
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +48 -35
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +414 -350
  409. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  410. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  411. sglang/srt/models/vila.py +0 -306
  412. sglang/srt/speculative/build_eagle_tree.py +0 -427
  413. sglang/test/test_block_fp8_ep.py +0 -358
  414. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  415. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  416. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  417. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
  418. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
  419. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/moe_runner/triton_kernels.py (new file)
@@ -0,0 +1,194 @@
+"""Triton kernels MoE runner backend skeleton."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Optional
+
+import torch
+
+from sglang.srt.layers.moe.moe_runner.base import (
+    MoeQuantInfo,
+    MoeRunnerConfig,
+    MoeRunnerCore,
+    RunnerInput,
+    RunnerOutput,
+    register_post_permute,
+    register_pre_permute,
+)
+from sglang.srt.layers.moe.utils import MoeRunnerBackend
+
+if TYPE_CHECKING:
+    from triton_kernels.matmul_ogs import PrecisionConfig
+    from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx
+
+    from sglang.srt.layers.moe.token_dispatcher.standard import (
+        StandardCombineInput,
+        StandardDispatchOutput,
+    )
+
+
+# ---------------------------------------------------------------------------
+# Runner IO dataclasses
+# ---------------------------------------------------------------------------
+
+
+@dataclass
+class TritonKernelsRunnerInput(RunnerInput):
+    """Input bundle passed to the triton-kernels runner core."""
+
+    hidden_states: torch.Tensor
+    routing_data: "RoutingData"
+    gather_indx: "GatherIndx"
+    scatter_indx: "ScatterIndx"
+
+    @property
+    def runner_backend(self) -> MoeRunnerBackend:
+        return MoeRunnerBackend.TRITON_KERNELS
+
+
+@dataclass
+class TritonKernelsRunnerOutput(RunnerOutput):
+    """Output bundle returned from the triton-kernels runner core."""
+
+    hidden_states: torch.Tensor
+
+    @property
+    def runner_backend(self) -> MoeRunnerBackend:
+        return MoeRunnerBackend.TRITON_KERNELS
+
+
+@dataclass
+class TritonKernelsQuantInfo(MoeQuantInfo):
+    """Quantization payload consumed by the triton-kernels backend."""
+
+    w13_weight: torch.Tensor
+    w2_weight: torch.Tensor
+    w13_bias: Optional[torch.Tensor] = None
+    w2_bias: Optional[torch.Tensor] = None
+    w13_precision_config: Optional[PrecisionConfig] = None
+    w2_precision_config: Optional[PrecisionConfig] = None
+    global_num_experts: int = -1
+
+
+# ---------------------------------------------------------------------------
+# Runner core
+# ---------------------------------------------------------------------------
+
+
+class TritonKernelsRunnerCore(MoeRunnerCore):
+    """Execute MoE experts via the external triton_kernels package."""
+
+    def run(
+        self,
+        runner_input: TritonKernelsRunnerInput,
+        quant_info: TritonKernelsQuantInfo,
+        running_state: dict,
+    ) -> TritonKernelsRunnerOutput:
+        from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
+            triton_kernel_fused_experts,
+            triton_kernel_fused_experts_with_bias,
+        )
+
+        hidden_states = runner_input.hidden_states
+
+        common_kwargs = dict(
+            routing_data=runner_input.routing_data,
+            gather_indx=runner_input.gather_indx,
+            scatter_indx=None if self.config.no_combine else runner_input.scatter_indx,
+            inplace=False,
+            activation=self.config.activation,
+            apply_router_weight_on_input=self.config.apply_router_weight_on_input,
+            global_num_experts=quant_info.global_num_experts,
+        )
+
+        has_bias = quant_info.w13_bias is not None or quant_info.w2_bias is not None
+
+        if has_bias:
+            assert (
+                quant_info.w13_bias is not None and quant_info.w2_bias is not None
+            ), "Bias execution requires both w13_bias and w2_bias"
+            output = triton_kernel_fused_experts_with_bias(
+                hidden_states=hidden_states,
+                w1=quant_info.w13_weight,
+                w1_pcg=quant_info.w13_precision_config,
+                b1=quant_info.w13_bias,
+                w2=quant_info.w2_weight,
+                w2_pcg=quant_info.w2_precision_config,
+                b2=quant_info.w2_bias,
+                gemm1_alpha=self.config.gemm1_alpha,
+                gemm1_clamp_limit=self.config.gemm1_clamp_limit,
+                **common_kwargs,
+            )
+        else:
+            output = triton_kernel_fused_experts(
+                hidden_states=hidden_states,
+                w1=quant_info.w13_weight,
+                w2=quant_info.w2_weight,
+                **common_kwargs,
+            )
+
+        if self.config.no_combine:
+            tokens = runner_input.hidden_states.shape[0]
+            hidden = runner_input.hidden_states.shape[-1]
+            total_rows = output.shape[0]
+            top_k = total_rows // tokens
+            output = output.view(tokens, top_k, hidden)
+
+        return TritonKernelsRunnerOutput(hidden_states=output)
+
+    @property
+    def runner_backend(self) -> MoeRunnerBackend:
+        return MoeRunnerBackend.TRITON_KERNELS
+
+
+# ---------------------------------------------------------------------------
+# Permute / fused hooks
+# ---------------------------------------------------------------------------
+
+
+@register_pre_permute("standard", "triton_kernel")
+def pre_permute_standard_to_triton_kernels(
+    dispatch_output: "StandardDispatchOutput",
+    quant_info: TritonKernelsQuantInfo,
+    runner_config: MoeRunnerConfig,
+    running_state: dict,
+) -> TritonKernelsRunnerInput:
+    from sglang.srt.layers.moe.topk import TopKOutputChecker
+
+    hidden_states = dispatch_output.hidden_states
+    topk_output = dispatch_output.topk_output
+
+    assert TopKOutputChecker.format_is_triton_kernels(
+        topk_output
+    ), "Triton-kernel runner expects TritonKernelTopKOutput"
+
+    routing_data, gather_indx, scatter_indx = topk_output
+
+    return TritonKernelsRunnerInput(
+        hidden_states=hidden_states,
+        routing_data=routing_data,
+        gather_indx=gather_indx,
+        scatter_indx=scatter_indx,
+    )
+
+
+@register_post_permute("triton_kernel", "standard")
+def post_permute_triton_kernels_to_standard(
+    runner_output: TritonKernelsRunnerOutput,
+    quant_info: TritonKernelsQuantInfo,
+    runner_config: MoeRunnerConfig,
+    running_state: dict,
+) -> StandardCombineInput:
+    from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput
+
+    hidden_states = runner_output.hidden_states
+
+    if (
+        runner_config.routed_scaling_factor is not None
+        and runner_config.routed_scaling_factor != 1.0
+        and not runner_config.no_combine
+    ):
+        hidden_states.mul_(runner_config.routed_scaling_factor)
+
+    return StandardCombineInput(hidden_states=hidden_states)
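
The new backend plugs into the moe_runner hook registry: the `register_pre_permute("standard", "triton_kernel")` hook adapts a standard dispatch output into the runner input, the core executes the fused expert GEMMs, and the `register_post_permute("triton_kernel", "standard")` hook converts the result back into a combine input, applying `routed_scaling_factor` on the way out. When `no_combine` is set, the scatter step is skipped and the flat `(tokens * top_k, hidden)` output is reshaped to `(tokens, top_k, hidden)`. A minimal sketch of how the three pieces compose; `drive_moe` is hypothetical and assumes a caller that already holds a constructed core (the actual orchestration lives in sglang/srt/layers/moe/moe_runner/runner.py and may differ):

    # Hypothetical driver, for illustration only; not part of this release.
    def drive_moe(dispatch_output, quant_info, runner_config, core):
        running_state: dict = {}
        # StandardDispatchOutput -> TritonKernelsRunnerInput
        runner_input = pre_permute_standard_to_triton_kernels(
            dispatch_output, quant_info, runner_config, running_state
        )
        # fused expert GEMMs via the external triton_kernels package
        runner_output = core.run(runner_input, quant_info, running_state)
        # TritonKernelsRunnerOutput -> StandardCombineInput
        return post_permute_triton_kernels_to_standard(
            runner_output, quant_info, runner_config, running_state
        )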
sglang/srt/layers/moe/rocm_moe_utils.py
@@ -2,7 +2,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from enum import IntEnum
-from functools import cache
 from typing import Optional
 
 import torch
sglang/srt/layers/moe/router.py
@@ -11,7 +11,7 @@ _is_hip = is_hip()
 
 
 @triton.jit
-def fused_moe_router_kernel(
+def fused_moe_router_cudacore_kernel(
     input_ptr,  # input (bs, hidden_dim)
     moe_router_weight_ptr,  # input (num_experts, hidden_dim)
     topk_weights_ptr,  # output (bs, topk)
@@ -114,7 +114,7 @@ def fused_moe_router_kernel(
     # assert not moe_renormalize, "moe weight renormalization not implemented"
 
 
-def fused_moe_router_impl(
+def fused_moe_router_cudacore(
     x: torch.Tensor,
     router_weight: torch.Tensor,
     topk: int,
@@ -138,7 +138,7 @@ def fused_moe_router_impl(
        ),
    }
 
-    fused_moe_router_kernel[(bs,)](
+    fused_moe_router_cudacore_kernel[(bs,)](
        x,
        router_weight,
        topk_weights,
@@ -157,7 +157,7 @@
 
 
 @triton.jit
-def fused_moe_router_large_bs_kernel(
+def fused_moe_router_tensorcore_kernel(
    a_ptr,  # input (bs, hidden_dim)
    b_ptr,  # input (num_experts, hidden_dim)
    topk_weights_ptr,  # output (bs, topk)
@@ -167,12 +167,15 @@ def fused_moe_router_large_bs_kernel(
    topk: tl.constexpr,  # only support topk <= 2
    moe_softcapping: tl.constexpr,
    moe_renormalize: tl.constexpr,  # not supported
+    correction_bias_ptr,
+    is_correction_bias: tl.constexpr,
    K: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    stride_am: tl.constexpr,
    stride_bn: tl.constexpr,
+    dp_attn_workaround_flag: tl.constexpr,
 ):
 
    # 1. get block id
@@ -217,6 +220,20 @@
    exped = tl.exp(2 * logits_scaled)
    logits_softcapped = (exped - 1) / (exped + 1) * moe_softcapping
 
+    # Add bias after softcapping
+    if is_correction_bias:
+        bias = tl.load(
+            correction_bias_ptr + tl.arange(0, BLOCK_SIZE_N)[None, :],
+            mask=expert_mask.T,
+            other=0.0,
+        )
+        logits_softcapped = logits_softcapped + bias
+
+    if dp_attn_workaround_flag:
+        logits_softcapped = tl.where(
+            logits_softcapped != logits_softcapped, -1e9, logits_softcapped
+        )
+
    # 5. top1
    arange_block_size_n = tl.arange(0, BLOCK_SIZE_N)[None, :]
    cond_top1 = arange_block_size_n < num_experts
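
Two behavioral changes land in the tensor-core router kernel above: an optional per-expert correction bias is now added after soft capping, and, as a stopgap for data-parallel attention, NaN logits produced by padded tokens are clamped to -1e9 so they can never win the top-k comparison. The `logits_softcapped != logits_softcapped` test works because NaN is the only floating-point value that compares unequal to itself. A plain PyTorch restatement of the trick (illustration, not code from the diff):

    import torch

    logits = torch.tensor([0.3, float("nan"), -0.7])
    # NaN != NaN is True, so only the NaN entry is replaced with -1e9.
    masked = torch.where(logits != logits, torch.full_like(logits, -1e9), logits)
    print(masked)  # tensor([ 3.0000e-01, -1.0000e+09, -7.0000e-01])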
sglang/srt/layers/moe/router.py (continued)
@@ -266,7 +283,7 @@
    )
 
 
-def fused_moe_router_large_bs_impl(
+def fused_moe_router_tensorcore(
    x: torch.Tensor,
    router_weight: torch.Tensor,
    topk: int,
@@ -274,6 +291,7 @@ def fused_moe_router_large_bs_impl(
    BLOCK_SIZE_M: int,
    BLOCK_SIZE_N: int,
    BLOCK_SIZE_K: int,
+    correction_bias: Optional[torch.Tensor] = None,
 ):
    assert len(x.shape) == 2 and x.shape[1] == router_weight.shape[1]
    bs, hidden_dim = x.shape
@@ -285,10 +303,17 @@
 
    topk_weights = torch.empty((bs, topk), dtype=torch.float32, device=x.device)
    topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)
+    is_correction_bias = correction_bias is not None
 
    grid = (triton.cdiv(bs, BLOCK_SIZE_M) * triton.cdiv(num_experts, BLOCK_SIZE_N),)
 
-    fused_moe_router_large_bs_kernel[grid](
+    # TODO(ch-wan): temporary workaround for dp attention. We should support masked
+    # router to skip padded tokens.
+    from sglang.srt.layers.dp_attention import is_dp_attention_enabled
+
+    dp_attn_workaround_flag = is_dp_attention_enabled()
+
+    fused_moe_router_tensorcore_kernel[grid](
        a_ptr=x,
        b_ptr=router_weight,
        topk_weights_ptr=topk_weights,
@@ -299,11 +324,14 @@
        moe_softcapping=moe_softcapping,
        moe_renormalize=False,
        K=hidden_dim,
+        correction_bias_ptr=correction_bias,
+        is_correction_bias=is_correction_bias,
        BLOCK_SIZE_M=BLOCK_SIZE_M,
        BLOCK_SIZE_N=BLOCK_SIZE_N,
        BLOCK_SIZE_K=BLOCK_SIZE_K,
        stride_am=hidden_dim,
        stride_bn=hidden_dim,
+        dp_attn_workaround_flag=dp_attn_workaround_flag,
    )
 
    return topk_weights, topk_ids
@@ -316,6 +344,7 @@ def fused_moe_router_shim(
    topk,
    renormalize,
    correction_bias: Optional[torch.Tensor] = None,
+    enable_deterministic_inference: bool = False,
 ):
    assert not renormalize
    assert (
@@ -324,16 +353,22 @@
    )
    bs, hidden_dim = hidden_states.shape
    num_experts = gating_output.shape[0]
+
    BLOCK_SIZE_M = 32
-    BLOCK_SIZE_N = 16
-    BLOCK_SIZE_K = 256
+
+    BLOCK_SIZE_N = max(num_experts, 16)
+    BLOCK_SIZE_K = (
+        256 if num_experts < 256 else 64
+    )  # if experts are large, need to use smaller k block or shared memory OOM
+
    if (
-        bs >= 512
-        and topk <= 2
-        and num_experts <= BLOCK_SIZE_N
+        (bs >= 512 or num_experts > 8)
        and hidden_dim % BLOCK_SIZE_K == 0
+        # we keep using single kernel to avoid non-deterministic behavior
+        and not enable_deterministic_inference
    ):
-        return fused_moe_router_large_bs_impl(
+        # if large batch size or large expert, use kernel that uses tensorcore in matmul
+        return fused_moe_router_tensorcore(
            x=hidden_states,
            router_weight=gating_output,
            topk=topk,
@@ -341,9 +376,11 @@
            BLOCK_SIZE_M=BLOCK_SIZE_M,
            BLOCK_SIZE_N=BLOCK_SIZE_N,
            BLOCK_SIZE_K=BLOCK_SIZE_K,
+            correction_bias=correction_bias,
        )
    else:
-        return fused_moe_router_impl(
+        # if smaller, use kernel that does not use tensorcore in matmul
+        return fused_moe_router_cudacore(
            x=hidden_states,
            router_weight=gating_output,
            topk=topk,
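
Taken together, the shim changes above widen the fast path: the old gate (`bs >= 512 and topk <= 2 and num_experts <= 16`) becomes `bs >= 512 or num_experts > 8`, block sizes now adapt to the expert count (BLOCK_SIZE_K drops from 256 to 64 once there are 256 or more experts, to stay within shared memory), and the new `enable_deterministic_inference` flag pins execution to the single cuda-core kernel. A standalone restatement of the selection logic, paraphrased from the diff for readability:

    def pick_router_kernel(bs, num_experts, hidden_dim, deterministic=False):
        BLOCK_SIZE_N = max(num_experts, 16)
        # large expert counts need a smaller K block or shared memory overflows
        BLOCK_SIZE_K = 256 if num_experts < 256 else 64
        use_tensorcore = (
            (bs >= 512 or num_experts > 8)
            and hidden_dim % BLOCK_SIZE_K == 0
            and not deterministic  # one fixed kernel keeps results reproducible
        )
        kernel = "tensorcore" if use_tensorcore else "cudacore"
        return kernel, BLOCK_SIZE_N, BLOCK_SIZE_K

    # e.g. pick_router_kernel(1024, 64, 4096) -> ("tensorcore", 64, 256)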
sglang/srt/layers/moe/router.py (continued)
@@ -380,11 +417,10 @@ class FusedMoeRouter:
            renormalize=False,
        )
 
-    def forward_vllm(
+    def forward_torch(
        self,
        x: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        # g, _ = self.router_linear.forward(x)
        g = x.float() @ self.router_linear.weight.T.float()
 
        g = torch.tanh(g.float() / self.moe_softcapping) * self.moe_softcapping
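
The renamed `forward_torch` is the pure-PyTorch reference path: fp32 router logits followed by soft capping, `g = tanh(g / c) * c` with `c = moe_softcapping`. The fused kernels compute the same function through the identity tanh(x) = (exp(2x) - 1) / (exp(2x) + 1); a quick numerical check of the equivalence (illustration, not code from the diff):

    import torch

    g = torch.randn(4, 8)
    c = 30.0  # moe_softcapping
    reference = torch.tanh(g / c) * c  # forward_torch path
    e = torch.exp(2 * g / c)           # fused-kernel path: tanh via exp
    kernel_style = (e - 1) / (e + 1) * c
    assert torch.allclose(reference, kernel_style, atol=1e-6)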
sglang/srt/layers/moe/token_dispatcher/__init__.py
@@ -12,12 +12,18 @@ from sglang.srt.layers.moe.token_dispatcher.deepep import (
    DeepEPConfig,
    DeepEPDispatcher,
    DeepEPLLCombineInput,
-    DeepEPLLOutput,
+    DeepEPLLDispatchOutput,
    DeepEPNormalCombineInput,
-    DeepEPNormalOutput,
+    DeepEPNormalDispatchOutput,
+)
+from sglang.srt.layers.moe.token_dispatcher.mooncake import (
+    MooncakeCombineInput,
+    MooncakeDispatchOutput,
+    MooncakeEPDispatcher,
 )
 from sglang.srt.layers.moe.token_dispatcher.standard import (
    StandardCombineInput,
+    StandardDispatcher,
    StandardDispatchOutput,
 )
 
@@ -30,12 +36,16 @@ __all__ = [
    "DispatchOutput",
    "DispatchOutputFormat",
    "DispatchOutputChecker",
+    "MooncakeCombineInput",
+    "MooncakeDispatchOutput",
+    "MooncakeEPDispatcher",
+    "StandardDispatcher",
    "StandardDispatchOutput",
    "StandardCombineInput",
    "DeepEPConfig",
    "DeepEPDispatcher",
-    "DeepEPNormalOutput",
-    "DeepEPLLOutput",
+    "DeepEPNormalDispatchOutput",
+    "DeepEPLLDispatchOutput",
    "DeepEPLLCombineInput",
    "DeepEPNormalCombineInput",
 ]
sglang/srt/layers/moe/token_dispatcher/base.py
@@ -9,9 +9,9 @@ import torch
 if TYPE_CHECKING:
    from sglang.srt.layers.moe.token_dispatcher import (
        DeepEPLLCombineInput,
-        DeepEPLLOutput,
+        DeepEPLLDispatchOutput,
        DeepEPNormalCombineInput,
-        DeepEPNormalOutput,
+        DeepEPNormalDispatchOutput,
        StandardCombineInput,
        StandardDispatchOutput,
    )
@@ -28,22 +28,28 @@ class DispatchOutputChecker:
    ) -> TypeGuard[StandardDispatchOutput]:
        return dispatch_output.format.is_standard()
 
+    @staticmethod
+    def format_is_triton_kernels(
+        dispatch_output: DispatchOutput,
+    ) -> TypeGuard[StandardDispatchOutput]:
+        return dispatch_output.format.is_standard()
+
    @staticmethod
    def format_is_deepep_normal(
        dispatch_output: DispatchOutput,
-    ) -> TypeGuard[DeepEPNormalOutput]:
+    ) -> TypeGuard[DeepEPNormalDispatchOutput]:
        return dispatch_output.format.is_deepep_normal()
 
    @staticmethod
    def format_is_deepep_ll(
        dispatch_output: DispatchOutput,
-    ) -> TypeGuard[DeepEPLLOutput]:
+    ) -> TypeGuard[DeepEPLLDispatchOutput]:
        return dispatch_output.format.is_deepep_ll()
 
    @staticmethod
    def format_is_deepep(
        dispatch_output: DispatchOutput,
-    ) -> TypeGuard[Union[DeepEPNormalOutput, DeepEPLLOutput]]:
+    ) -> TypeGuard[Union[DeepEPNormalDispatchOutput, DeepEPLLDispatchOutput]]:
        return dispatch_output.format.is_deepep()
 
 
@@ -73,7 +79,7 @@ class DispatchOutputFormat(Enum):
 class DispatchOutput(Protocol):
    """Protocol for dispatch outputs in different formats."""
 
-    # TODO: add hidden_states to the protocol
+    hidden_states: torch.Tensor
 
    @property
    def format(self) -> DispatchOutputFormat: ...
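
The checker methods return `typing.TypeGuard`, so a truthy call narrows the static type of a `DispatchOutput` inside the guarded branch, and the protocol change turns the old TODO into a structural requirement: every dispatch output must now expose `hidden_states`. A sketch of a consumer under those assumptions (`route` is hypothetical; the imports assume the names re-exported by `sglang.srt.layers.moe.token_dispatcher`):

    from sglang.srt.layers.moe.token_dispatcher import (
        DispatchOutput,
        DispatchOutputChecker,
    )

    def route(out: DispatchOutput) -> None:
        hidden = out.hidden_states  # OK: hidden_states is part of the Protocol now
        if DispatchOutputChecker.format_is_deepep_ll(out):
            ...  # type checkers narrow `out` to DeepEPLLDispatchOutput here
        elif DispatchOutputChecker.format_is_standard(out):
            ...  # narrowed to StandardDispatchOutput here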