sglang 0.5.3rc2__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (419) hide show
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +378 -160
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +10 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +105 -10
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +136 -25
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +63 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +83 -80
  69. sglang/srt/entrypoints/grpc_server.py +430 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +195 -102
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +58 -6
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +33 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +20 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/minimax_m2.py +367 -0
  93. sglang/srt/function_call/utils.py +2 -2
  94. sglang/srt/grpc/compile_proto.py +3 -3
  95. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  96. sglang/srt/grpc/health_servicer.py +189 -0
  97. sglang/srt/grpc/scheduler_launcher.py +181 -0
  98. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  99. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  100. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  101. sglang/srt/layers/activation.py +10 -1
  102. sglang/srt/layers/attention/aiter_backend.py +3 -3
  103. sglang/srt/layers/attention/ascend_backend.py +17 -1
  104. sglang/srt/layers/attention/attention_registry.py +43 -23
  105. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  106. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  107. sglang/srt/layers/attention/fla/chunk.py +0 -1
  108. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  109. sglang/srt/layers/attention/fla/index.py +0 -2
  110. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  111. sglang/srt/layers/attention/fla/utils.py +0 -3
  112. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  113. sglang/srt/layers/attention/flashattention_backend.py +24 -10
  114. sglang/srt/layers/attention/flashinfer_backend.py +258 -22
  115. sglang/srt/layers/attention/flashinfer_mla_backend.py +38 -28
  116. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  117. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  118. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  119. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  121. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  122. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  123. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  124. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  125. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  127. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  128. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  129. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  130. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  131. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  132. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  133. sglang/srt/layers/attention/nsa/utils.py +0 -1
  134. sglang/srt/layers/attention/nsa_backend.py +404 -90
  135. sglang/srt/layers/attention/triton_backend.py +208 -34
  136. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  137. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  138. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  139. sglang/srt/layers/attention/trtllm_mla_backend.py +362 -43
  140. sglang/srt/layers/attention/utils.py +89 -7
  141. sglang/srt/layers/attention/vision.py +3 -3
  142. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  143. sglang/srt/layers/communicator.py +12 -7
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +5 -9
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  146. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  147. sglang/srt/layers/dp_attention.py +17 -0
  148. sglang/srt/layers/layernorm.py +64 -19
  149. sglang/srt/layers/linear.py +9 -1
  150. sglang/srt/layers/logits_processor.py +152 -17
  151. sglang/srt/layers/modelopt_utils.py +11 -0
  152. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  153. sglang/srt/layers/moe/cutlass_w4a8_moe.py +351 -21
  154. sglang/srt/layers/moe/ep_moe/kernels.py +229 -457
  155. sglang/srt/layers/moe/ep_moe/layer.py +154 -625
  156. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  160. sglang/srt/layers/moe/fused_moe_triton/layer.py +79 -73
  161. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +25 -46
  162. sglang/srt/layers/moe/moe_runner/deep_gemm.py +569 -0
  163. sglang/srt/layers/moe/moe_runner/runner.py +6 -0
  164. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  165. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  166. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  167. sglang/srt/layers/moe/router.py +51 -15
  168. sglang/srt/layers/moe/token_dispatcher/__init__.py +14 -4
  169. sglang/srt/layers/moe/token_dispatcher/base.py +12 -6
  170. sglang/srt/layers/moe/token_dispatcher/deepep.py +127 -110
  171. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  172. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  173. sglang/srt/layers/moe/topk.py +7 -6
  174. sglang/srt/layers/moe/utils.py +20 -5
  175. sglang/srt/layers/quantization/__init__.py +5 -58
  176. sglang/srt/layers/quantization/awq.py +183 -9
  177. sglang/srt/layers/quantization/awq_triton.py +29 -0
  178. sglang/srt/layers/quantization/base_config.py +27 -1
  179. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  180. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  185. sglang/srt/layers/quantization/fp8.py +152 -81
  186. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  187. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  188. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  189. sglang/srt/layers/quantization/gguf.py +566 -0
  190. sglang/srt/layers/quantization/gptq.py +0 -1
  191. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  192. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  193. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  194. sglang/srt/layers/quantization/mxfp4.py +35 -68
  195. sglang/srt/layers/quantization/petit.py +1 -1
  196. sglang/srt/layers/quantization/quark/quark.py +3 -1
  197. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  199. sglang/srt/layers/quantization/unquant.py +23 -48
  200. sglang/srt/layers/quantization/utils.py +0 -1
  201. sglang/srt/layers/quantization/w4afp8.py +87 -20
  202. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  203. sglang/srt/layers/radix_attention.py +62 -9
  204. sglang/srt/layers/rotary_embedding.py +686 -17
  205. sglang/srt/layers/sampler.py +47 -16
  206. sglang/srt/layers/sparse_pooler.py +98 -0
  207. sglang/srt/layers/utils.py +0 -1
  208. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  209. sglang/srt/lora/backend/triton_backend.py +0 -1
  210. sglang/srt/lora/eviction_policy.py +139 -0
  211. sglang/srt/lora/lora_manager.py +24 -9
  212. sglang/srt/lora/lora_registry.py +1 -1
  213. sglang/srt/lora/mem_pool.py +40 -16
  214. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  215. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  216. sglang/srt/managers/cache_controller.py +48 -17
  217. sglang/srt/managers/data_parallel_controller.py +146 -42
  218. sglang/srt/managers/detokenizer_manager.py +40 -13
  219. sglang/srt/managers/io_struct.py +69 -16
  220. sglang/srt/managers/mm_utils.py +20 -18
  221. sglang/srt/managers/multi_tokenizer_mixin.py +83 -82
  222. sglang/srt/managers/overlap_utils.py +96 -19
  223. sglang/srt/managers/schedule_batch.py +241 -511
  224. sglang/srt/managers/schedule_policy.py +15 -2
  225. sglang/srt/managers/scheduler.py +420 -514
  226. sglang/srt/managers/scheduler_metrics_mixin.py +73 -18
  227. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  228. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  229. sglang/srt/managers/scheduler_profiler_mixin.py +60 -14
  230. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  231. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  232. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  233. sglang/srt/managers/tokenizer_manager.py +375 -95
  234. sglang/srt/managers/tp_worker.py +212 -161
  235. sglang/srt/managers/utils.py +78 -2
  236. sglang/srt/mem_cache/allocator.py +7 -2
  237. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  238. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  239. sglang/srt/mem_cache/chunk_cache.py +13 -2
  240. sglang/srt/mem_cache/common.py +480 -0
  241. sglang/srt/mem_cache/evict_policy.py +16 -1
  242. sglang/srt/mem_cache/hicache_storage.py +11 -2
  243. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  244. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  245. sglang/srt/mem_cache/memory_pool.py +517 -219
  246. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  247. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  248. sglang/srt/mem_cache/radix_cache.py +53 -19
  249. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  250. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  251. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  252. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  253. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  254. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  255. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  256. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  257. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  259. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  260. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  261. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  262. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  263. sglang/srt/metrics/collector.py +31 -0
  264. sglang/srt/metrics/func_timer.py +1 -1
  265. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  266. sglang/srt/model_executor/forward_batch_info.py +71 -25
  267. sglang/srt/model_executor/model_runner.py +362 -270
  268. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  269. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +549 -0
  270. sglang/srt/model_loader/__init__.py +1 -1
  271. sglang/srt/model_loader/loader.py +424 -27
  272. sglang/srt/model_loader/utils.py +0 -1
  273. sglang/srt/model_loader/weight_utils.py +47 -28
  274. sglang/srt/models/apertus.py +2 -3
  275. sglang/srt/models/arcee.py +2 -2
  276. sglang/srt/models/bailing_moe.py +13 -52
  277. sglang/srt/models/bailing_moe_nextn.py +3 -4
  278. sglang/srt/models/bert.py +1 -1
  279. sglang/srt/models/deepseek_nextn.py +19 -3
  280. sglang/srt/models/deepseek_ocr.py +1516 -0
  281. sglang/srt/models/deepseek_v2.py +418 -140
  282. sglang/srt/models/dots_ocr.py +0 -2
  283. sglang/srt/models/dots_vlm.py +0 -1
  284. sglang/srt/models/dots_vlm_vit.py +1 -1
  285. sglang/srt/models/falcon_h1.py +13 -19
  286. sglang/srt/models/gemma3_mm.py +16 -0
  287. sglang/srt/models/gemma3n_mm.py +1 -2
  288. sglang/srt/models/glm4_moe.py +327 -382
  289. sglang/srt/models/glm4_moe_nextn.py +6 -16
  290. sglang/srt/models/glm4v.py +2 -1
  291. sglang/srt/models/glm4v_moe.py +32 -199
  292. sglang/srt/models/gpt_oss.py +5 -5
  293. sglang/srt/models/grok.py +10 -23
  294. sglang/srt/models/hunyuan.py +2 -7
  295. sglang/srt/models/interns1.py +0 -1
  296. sglang/srt/models/kimi_vl.py +1 -7
  297. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  298. sglang/srt/models/llama.py +2 -2
  299. sglang/srt/models/llama_eagle3.py +1 -1
  300. sglang/srt/models/longcat_flash.py +5 -22
  301. sglang/srt/models/longcat_flash_nextn.py +3 -14
  302. sglang/srt/models/mimo.py +2 -13
  303. sglang/srt/models/mimo_mtp.py +1 -2
  304. sglang/srt/models/minicpmo.py +7 -5
  305. sglang/srt/models/minimax_m2.py +922 -0
  306. sglang/srt/models/mixtral.py +1 -4
  307. sglang/srt/models/mllama.py +1 -1
  308. sglang/srt/models/mllama4.py +13 -3
  309. sglang/srt/models/nemotron_h.py +511 -0
  310. sglang/srt/models/nvila.py +355 -0
  311. sglang/srt/models/nvila_lite.py +184 -0
  312. sglang/srt/models/olmo2.py +31 -4
  313. sglang/srt/models/opt.py +5 -5
  314. sglang/srt/models/phi.py +1 -1
  315. sglang/srt/models/phi4mm.py +1 -1
  316. sglang/srt/models/phimoe.py +0 -1
  317. sglang/srt/models/pixtral.py +0 -3
  318. sglang/srt/models/points_v15_chat.py +186 -0
  319. sglang/srt/models/qwen.py +0 -1
  320. sglang/srt/models/qwen2.py +22 -1
  321. sglang/srt/models/qwen2_5_vl.py +3 -3
  322. sglang/srt/models/qwen2_audio.py +2 -15
  323. sglang/srt/models/qwen2_moe.py +15 -12
  324. sglang/srt/models/qwen2_vl.py +5 -2
  325. sglang/srt/models/qwen3.py +34 -4
  326. sglang/srt/models/qwen3_moe.py +19 -37
  327. sglang/srt/models/qwen3_next.py +7 -12
  328. sglang/srt/models/qwen3_next_mtp.py +3 -4
  329. sglang/srt/models/qwen3_omni_moe.py +661 -0
  330. sglang/srt/models/qwen3_vl.py +37 -33
  331. sglang/srt/models/qwen3_vl_moe.py +57 -185
  332. sglang/srt/models/roberta.py +55 -3
  333. sglang/srt/models/sarashina2_vision.py +0 -1
  334. sglang/srt/models/step3_vl.py +3 -5
  335. sglang/srt/models/utils.py +11 -1
  336. sglang/srt/multimodal/processors/base_processor.py +7 -2
  337. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  338. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  339. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  340. sglang/srt/multimodal/processors/glm4v.py +2 -6
  341. sglang/srt/multimodal/processors/internvl.py +0 -2
  342. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  343. sglang/srt/multimodal/processors/mllama4.py +0 -8
  344. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  345. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  346. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  347. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  348. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  349. sglang/srt/parser/conversation.py +41 -0
  350. sglang/srt/parser/reasoning_parser.py +28 -2
  351. sglang/srt/sampling/custom_logit_processor.py +77 -2
  352. sglang/srt/sampling/sampling_batch_info.py +17 -22
  353. sglang/srt/sampling/sampling_params.py +70 -2
  354. sglang/srt/server_args.py +846 -163
  355. sglang/srt/server_args_config_parser.py +1 -1
  356. sglang/srt/single_batch_overlap.py +36 -31
  357. sglang/srt/speculative/base_spec_worker.py +34 -0
  358. sglang/srt/speculative/draft_utils.py +226 -0
  359. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  360. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  361. sglang/srt/speculative/eagle_info.py +57 -18
  362. sglang/srt/speculative/eagle_info_v2.py +458 -0
  363. sglang/srt/speculative/eagle_utils.py +138 -0
  364. sglang/srt/speculative/eagle_worker.py +83 -280
  365. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  366. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  367. sglang/srt/speculative/ngram_worker.py +12 -11
  368. sglang/srt/speculative/spec_info.py +2 -0
  369. sglang/srt/speculative/spec_utils.py +38 -3
  370. sglang/srt/speculative/standalone_worker.py +4 -14
  371. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  372. sglang/srt/two_batch_overlap.py +28 -14
  373. sglang/srt/utils/__init__.py +1 -1
  374. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  375. sglang/srt/utils/common.py +272 -82
  376. sglang/srt/utils/hf_transformers_utils.py +44 -17
  377. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  378. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  379. sglang/srt/utils/profile_merger.py +199 -0
  380. sglang/test/attention/test_flashattn_backend.py +1 -1
  381. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  382. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  383. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  384. sglang/test/few_shot_gsm8k_engine.py +2 -4
  385. sglang/test/kit_matched_stop.py +157 -0
  386. sglang/test/longbench_v2/__init__.py +1 -0
  387. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  388. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  389. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  390. sglang/test/run_eval.py +41 -0
  391. sglang/test/runners.py +2 -0
  392. sglang/test/send_one.py +42 -7
  393. sglang/test/simple_eval_common.py +3 -0
  394. sglang/test/simple_eval_gpqa.py +0 -1
  395. sglang/test/simple_eval_humaneval.py +0 -3
  396. sglang/test/simple_eval_longbench_v2.py +344 -0
  397. sglang/test/test_block_fp8.py +1 -2
  398. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  399. sglang/test/test_cutlass_moe.py +1 -2
  400. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  401. sglang/test/test_deterministic.py +463 -107
  402. sglang/test/test_deterministic_utils.py +74 -0
  403. sglang/test/test_disaggregation_utils.py +81 -0
  404. sglang/test/test_marlin_moe.py +0 -1
  405. sglang/test/test_utils.py +85 -20
  406. sglang/version.py +1 -1
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +48 -35
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +414 -350
  409. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  410. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  411. sglang/srt/models/vila.py +0 -306
  412. sglang/srt/speculative/build_eagle_tree.py +0 -427
  413. sglang/test/test_block_fp8_ep.py +0 -358
  414. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  415. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  416. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  417. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
  418. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
  419. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
@@ -1,7 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from abc import ABC, abstractmethod
4
- from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
4
+ from typing import TYPE_CHECKING, Any, Dict, Optional
5
5
 
6
6
  import torch
7
7
  import torch.nn.functional as F
@@ -12,17 +12,20 @@ from sglang.srt.custom_op import CustomOp
12
12
  from sglang.srt.utils import add_prefix, align, is_cuda, is_hip, is_npu
13
13
 
14
14
  if is_cuda():
15
- import deep_gemm
15
+ try:
16
+ import deep_gemm
17
+ except ImportError as e:
18
+ deep_gemm = e
16
19
 
17
- from sglang.srt.layers.attention.nsa.utils import NSA_DUAL_STREAM, NSA_USE_REAL_INDEXER
20
+ from sglang.srt.layers import deep_gemm_wrapper
21
+ from sglang.srt.layers.attention.nsa.utils import NSA_DUAL_STREAM
18
22
  from sglang.srt.layers.dp_attention import get_attention_tp_group
19
23
  from sglang.srt.layers.linear import ReplicatedLinear
20
- from sglang.srt.layers.quantization import deep_gemm_wrapper
21
24
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
22
25
  from sglang.srt.layers.rotary_embedding import get_rope_wrapper
23
- from sglang.srt.managers.schedule_batch import global_server_args_dict
24
26
  from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
25
27
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
28
+ from sglang.srt.server_args import get_global_server_args
26
29
 
27
30
  if TYPE_CHECKING:
28
31
  from sglang.srt.mem_cache.memory_pool import NSATokenToKVPool
@@ -71,7 +74,7 @@ class BaseIndexerMetadata(ABC):
71
74
 
72
75
  def rotate_activation(x: torch.Tensor) -> torch.Tensor:
73
76
  assert x.dtype == torch.bfloat16
74
- from fast_hadamard_transform import hadamard_transform
77
+ from sgl_kernel import hadamard_transform
75
78
 
76
79
  hidden_size = x.size(-1)
77
80
  assert (
@@ -159,49 +162,13 @@ class Indexer(CustomOp):
159
162
  base=rope_theta, # type: ignore
160
163
  rope_scaling=rope_scaling,
161
164
  is_neox_style=False,
162
- device=global_server_args_dict["device"],
165
+ device=get_global_server_args().device,
163
166
  )
164
167
  self.block_size = block_size
165
168
  self.scale_fmt = scale_fmt
166
169
  self.softmax_scale = self.head_dim**-0.5
167
170
 
168
- def _forward_fake(
169
- self,
170
- x: torch.Tensor,
171
- q_lora: torch.Tensor,
172
- positions: torch.Tensor,
173
- forward_batch: ForwardBatch,
174
- layer_id: int,
175
- ):
176
- bs = x.shape[0]
177
- assert self.index_topk == 2048
178
- ans = torch.arange(0, self.index_topk, dtype=torch.int32, device=x.device)[
179
- None, ...
180
- ].repeat(bs, 1)
181
- if forward_batch.forward_mode.is_extend():
182
- assert (
183
- forward_batch.extend_seq_lens_cpu is not None
184
- and forward_batch.seq_lens_cpu is not None
185
- )
186
- which = 0
187
- for i, (kv_len, qo_len) in enumerate(
188
- zip(
189
- forward_batch.seq_lens_cpu.tolist(),
190
- forward_batch.extend_seq_lens_cpu,
191
- strict=True,
192
- )
193
- ):
194
- for j in range(kv_len - qo_len, kv_len):
195
- ans[which, j + 1 :] = -1
196
- which += 1
197
- assert which == ans.shape[0]
198
- else:
199
- assert forward_batch.seq_lens_cpu is not None
200
- for i, seq_len in enumerate(forward_batch.seq_lens_cpu.tolist()):
201
- ans[i, seq_len:] = -1
202
-
203
- return ans
204
-
171
+ @torch.compile(dynamic=True)
205
172
  def _get_logits_head_gate(self, x: torch.Tensor, q_scale: torch.Tensor):
206
173
  weights, _ = self.weights_proj(x)
207
174
  weights = weights * self.n_heads**-0.5
@@ -299,7 +266,10 @@ class Indexer(CustomOp):
299
266
  )
300
267
 
301
268
  blocksize = page_size
302
- seqlens_32 = metadata.get_seqlens_int32()
269
+ if forward_batch.forward_mode.is_target_verify():
270
+ seqlens_32 = metadata.get_seqlens_expanded()
271
+ else:
272
+ seqlens_32 = metadata.get_seqlens_int32()
303
273
  # NOTE(dark): 132 is SM count on H200/B200, not magic number
304
274
  schedule_metadata = deep_gemm.get_paged_mqa_logits_metadata(
305
275
  seqlens_32, blocksize, self.sm_count
@@ -350,8 +320,9 @@ class Indexer(CustomOp):
350
320
  k_fp8_list = []
351
321
  k_scale_list = []
352
322
  ks_list = []
323
+ ke_list = []
353
324
  offset = 0
354
-
325
+ seq_lens_expanded = metadata.get_seqlens_expanded()
355
326
  block_tables = metadata.get_page_table_64()
356
327
 
357
328
  assert (
@@ -374,33 +345,37 @@ class Indexer(CustomOp):
374
345
  )
375
346
  extend_seq_len = forward_batch.extend_seq_lens_cpu[i]
376
347
  ks = torch.full((extend_seq_len,), offset, dtype=torch.int32, device="cuda")
348
+ ke = ks + seq_lens_expanded[offset : offset + extend_seq_len]
377
349
  k_fp8_list.append(k_fp8)
378
350
  k_scale_list.append(k_scale)
379
351
  ks_list.append(ks)
352
+ ke_list.append(ke)
380
353
  offset += extend_seq_len
381
354
 
382
355
  k_fp8 = torch.cat(k_fp8_list, dim=0).view(torch.float8_e4m3fn)
383
356
  k_scale = torch.cat(k_scale_list, dim=0).view(torch.float32).squeeze(-1)
384
357
  kv_fp8 = (k_fp8, k_scale)
385
358
  ks = torch.cat(ks_list, dim=0)
386
- seq_lens_expanded = metadata.get_seqlens_expanded()
387
- ke = ks + seq_lens_expanded
359
+ ke = torch.cat(ke_list, dim=0)
388
360
 
389
361
  logits = deep_gemm.fp8_mqa_logits(
390
- q_fp8,
362
+ q_fp8[:offset],
391
363
  kv_fp8,
392
- weights,
364
+ weights[:offset],
393
365
  ks,
394
366
  ke,
395
367
  clean_logits=False,
396
368
  )
397
-
369
+ token_nums, _, _ = q_fp8.shape
398
370
  assert logits.shape[0] == len(seq_lens_expanded)
399
- topk_result = metadata.topk_transform(logits, self.index_topk)
400
-
371
+ raw_topk_result = metadata.topk_transform(logits, self.index_topk)
372
+ topk_result = torch.full(
373
+ (token_nums, self.index_topk), -1, device=q_fp8.device, dtype=torch.int32
374
+ )
375
+ topk_result[:offset] = raw_topk_result
401
376
  return topk_result
402
377
 
403
- def forward_indexer_bs_1(
378
+ def forward_indexer(
404
379
  self,
405
380
  q_fp8: torch.Tensor,
406
381
  weights: torch.Tensor,
@@ -481,20 +456,9 @@ class Indexer(CustomOp):
481
456
  q_len_start = q_len_end
482
457
 
483
458
  topk_indices = torch.cat(topk_indices_list, dim=0)
484
-
485
459
  return topk_indices
486
460
 
487
- def forward_indexer(
488
- self,
489
- q_fp8: torch.Tensor,
490
- weights: torch.Tensor,
491
- forward_batch: ForwardBatch,
492
- topk: int,
493
- layer_id: int,
494
- ) -> Optional[torch.Tensor]:
495
- return self.forward_indexer_bs_1(q_fp8, weights, forward_batch, topk, layer_id)
496
-
497
- def _forward(
461
+ def forward_cuda(
498
462
  self,
499
463
  x: torch.Tensor,
500
464
  q_lora: torch.Tensor,
@@ -502,8 +466,10 @@ class Indexer(CustomOp):
502
466
  forward_batch: ForwardBatch,
503
467
  layer_id: int,
504
468
  ) -> Optional[torch.Tensor]:
505
- if not is_npu():
469
+ if is_hip():
506
470
  from sglang.srt.layers.attention.nsa.tilelang_kernel import act_quant
471
+ elif not is_npu():
472
+ from sglang.srt.layers.attention.nsa.triton_kernel import act_quant
507
473
 
508
474
  if TYPE_CHECKING:
509
475
  assert isinstance(forward_batch.token_to_kv_pool, NSATokenToKVPool)
@@ -524,9 +490,6 @@ class Indexer(CustomOp):
524
490
  if metadata is None:
525
491
  return None
526
492
 
527
- if not NSA_USE_REAL_INDEXER: # temporary
528
- return self._forward_fake(x, q_lora, positions, forward_batch, layer_id)
529
-
530
493
  query, key = self._get_q_k_bf16(q_lora, x, positions, enable_dual_stream)
531
494
 
532
495
  if enable_dual_stream:
@@ -545,6 +508,8 @@ class Indexer(CustomOp):
545
508
  # k_buffer: (num_total_tokens + page_size, head_dim) fp8_e4m3fn
546
509
  # k_scale: (seq_len, head_dim // block_size = 1) fp8_e4m3fn
547
510
  # k_scale_cache: (num_total_tokens + page_size, head_dim // block_size = 1) fp8_e4m3fn
511
+ if not forward_batch.out_cache_loc.is_contiguous():
512
+ forward_batch.out_cache_loc = forward_batch.out_cache_loc.contiguous()
548
513
  forward_batch.token_to_kv_pool.set_index_k_and_scale_buffer(
549
514
  layer_id=layer_id,
550
515
  loc=forward_batch.out_cache_loc,
@@ -566,7 +531,10 @@ class Indexer(CustomOp):
566
531
  (x.shape[0], self.index_topk), -1, dtype=torch.int, device="cuda"
567
532
  )
568
533
 
569
- if forward_batch.forward_mode.is_decode_or_idle():
534
+ if (
535
+ forward_batch.forward_mode.is_decode_or_idle()
536
+ or forward_batch.forward_mode.is_target_verify()
537
+ ):
570
538
  topk_result = self._get_topk_paged(
571
539
  forward_batch, layer_id, q_fp8, weights, metadata
572
540
  )
@@ -582,19 +550,8 @@ class Indexer(CustomOp):
582
550
  topk=self.index_topk,
583
551
  layer_id=layer_id,
584
552
  )
585
-
586
553
  return topk_result
587
554
 
588
- def forward_cuda(
589
- self,
590
- x: torch.Tensor,
591
- q_lora: torch.Tensor,
592
- positions: torch.Tensor,
593
- forward_batch: ForwardBatch,
594
- layer_id: int,
595
- ) -> Optional[torch.Tensor]:
596
- return self._forward(x, q_lora, positions, forward_batch, layer_id)
597
-
598
555
  def forward_npu(
599
556
  self,
600
557
  x: torch.Tensor,
@@ -603,7 +560,7 @@ class Indexer(CustomOp):
603
560
  forward_batch: ForwardBatch,
604
561
  layer_id: int,
605
562
  ) -> torch.Tensor:
606
- import custom_ops
563
+ import custom_ops # noqa: F401
607
564
  import torch_npu
608
565
 
609
566
  from sglang.srt.layers.dp_attention import (
@@ -0,0 +1,136 @@
1
+ from typing import Optional, Tuple
2
+
3
+ import torch
4
+ import triton
5
+ import triton.language as tl
6
+
7
+
8
+ # Triton implementation
9
@triton.jit
def _act_quant_kernel(
    X_ptr,
    Y_ptr,
    S_ptr,
    M,
    N,
    group_size: tl.constexpr,
    round_scale: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """
    Triton kernel for group-wise activation quantization.

    Each program handles a tile of BLOCK_M rows by BLOCK_N columns that starts
    at column ``pid_n * group_size``. One scale is produced per row of the
    tile, so the caller must launch with ``BLOCK_N == group_size``; a wider
    tile would mix neighbouring groups into a single scale.

    Args:
        X_ptr: Input, addressed as a row-major (M, N) matrix.
        Y_ptr: Output with the same (M, N) addressing as the input. Values are
            clamped to [-448, 448] and converted to the pointer's element
            type by ``tl.store``.
        S_ptr: Scales, addressed as a row-major (M, N // group_size) matrix.
        M: Number of rows.
        N: Number of columns.
        group_size: Number of consecutive columns sharing one scale.
        round_scale: If true, round each scale up to a power of two.
        BLOCK_M: Rows per program.
        BLOCK_N: Columns per program (must be a power of two for tl.arange).
    """
    # Program IDs: axis 0 walks row tiles, axis 1 walks column groups.
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    # Clamp range applied to the scaled values (+/-448 is the largest finite
    # magnitude of float8_e4m3fn).
    fp8_min = -448.0
    fp8_max = 448.0
    fp8_max_inv = 1.0 / fp8_max

    # Row / column indices covered by this program.
    rows = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    cols = pid_n * group_size + tl.arange(0, BLOCK_N)

    # Guard the ragged last row tile (and any column overhang).
    row_mask = rows < M
    col_mask = cols < N
    mask = row_mask[:, None] & col_mask[None, :]

    # Promote row indices to 64 bits before multiplying by the row stride so
    # that element offsets do not wrap for tensors with more than 2**31 - 1
    # elements.
    rows_i64 = rows.to(tl.int64)
    offsets = rows_i64[:, None] * N + cols[None, :]

    # Load the tile and do all arithmetic in fp32.
    x = tl.load(X_ptr + offsets, mask=mask, other=0.0).to(tl.float32)

    # Per-row absolute maximum over the group.
    amax = tl.max(tl.abs(x), axis=1)  # Shape: (BLOCK_M,)

    # Floor amax so the scale is never zero (avoids division by zero below).
    amax = tl.maximum(amax, 1e-4)

    # Compute scale
    if round_scale:
        # Round the scale up to the next power of two via
        # exp2(ceil(log2(.))). This stands in for the bit-manipulation
        # version, which is harder to express in Triton.
        scale = tl.exp2(tl.ceil(tl.log2(amax * fp8_max_inv)))
    else:
        scale = amax * fp8_max_inv

    # Quantize: y = clamp(x / scale, fp8_min, fp8_max)
    y = x / scale[:, None]
    y = tl.minimum(tl.maximum(y, fp8_min), fp8_max)

    # Store quantized output (tl.store casts to Y_ptr's element type).
    tl.store(Y_ptr + offsets, y, mask=mask)

    # Store one scale per (row, group); column index of the group is pid_n.
    s_ptrs = S_ptr + rows_i64 * (N // group_size) + pid_n
    tl.store(s_ptrs, scale, mask=row_mask)
84
+
85
+
86
+ def act_quant(
87
+ x: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] = None
88
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
89
+ """
90
+ Quantizes the input tensor `x` using block-wise quantization with Triton.
91
+
92
+ Args:
93
+ x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`.
94
+ block_size (int, optional): The size of the blocks to be used for quantization. Default is 128.
95
+ scale_fmt (Optional[str], optional): The format of the scale. Default is None.
96
+ Returns:
97
+ Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
98
+ - The quantized tensor with dtype `torch.float8_e4m3fn`.
99
+ - A tensor of scaling factors with dtype `torch.float32`.
100
+ """
101
+ assert x.is_contiguous(), "Input tensor must be contiguous"
102
+ assert (
103
+ x.size(-1) % block_size == 0
104
+ ), f"Last dimension size must be divisible by block_size (block_size={block_size})"
105
+
106
+ # Flatten all dims except last
107
+ N = x.size(-1)
108
+ x_flat = x.view(-1, N)
109
+ M = x_flat.size(0)
110
+
111
+ # Allocate output tensors
112
+ y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
113
+ y_flat = y.view(-1, N)
114
+ s = x.new_empty(*x.size()[:-1], N // block_size, dtype=torch.float32)
115
+ s_flat = s.view(-1, N // block_size)
116
+
117
+ # Launch kernel
118
+ BLOCK_M = 32
119
+ BLOCK_N = block_size
120
+ grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, block_size))
121
+ round_scale = scale_fmt is not None
122
+
123
+ _act_quant_kernel[grid](
124
+ x_flat,
125
+ y_flat,
126
+ s_flat,
127
+ M,
128
+ N,
129
+ group_size=block_size,
130
+ round_scale=round_scale,
131
+ BLOCK_M=BLOCK_M,
132
+ BLOCK_N=BLOCK_N,
133
+ num_stages=0 if round_scale else 2,
134
+ )
135
+
136
+ return y, s
@@ -1,7 +1,6 @@
1
1
  # temp NSA debugging environ
2
2
  from sglang.srt.utils import get_bool_env_var
3
3
 
4
- NSA_USE_REAL_INDEXER = get_bool_env_var("SGLANG_NSA_USE_REAL_INDEXER", "true")
5
4
  NSA_DUAL_STREAM = get_bool_env_var("SGLANG_NSA_DUAL_STREAM", "true")
6
5
  NSA_FUSE_TOPK = get_bool_env_var("SGLANG_NSA_FUSE_TOPK", "true")
7
6