sglang 0.5.2rc2__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (396)
  1. sglang/bench_one_batch.py +7 -11
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +474 -142
  4. sglang/compile_deep_gemm.py +3 -0
  5. sglang/global_config.py +2 -2
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +10 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +314 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/mamba_utils.py +117 -0
  17. sglang/srt/configs/model_config.py +228 -92
  18. sglang/srt/configs/nemotron_h.py +286 -0
  19. sglang/srt/configs/qwen3_next.py +294 -0
  20. sglang/srt/configs/qwen3_vl.py +586 -0
  21. sglang/srt/connector/__init__.py +8 -1
  22. sglang/srt/connector/remote_instance.py +82 -0
  23. sglang/srt/constrained/base_grammar_backend.py +49 -12
  24. sglang/srt/constrained/llguidance_backend.py +0 -1
  25. sglang/srt/constrained/outlines_backend.py +0 -1
  26. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  27. sglang/srt/constrained/xgrammar_backend.py +30 -9
  28. sglang/srt/custom_op.py +11 -1
  29. sglang/srt/debug_utils/dump_comparator.py +81 -44
  30. sglang/srt/debug_utils/dump_loader.py +97 -0
  31. sglang/srt/debug_utils/dumper.py +21 -6
  32. sglang/srt/debug_utils/text_comparator.py +73 -11
  33. sglang/srt/disaggregation/ascend/conn.py +2 -2
  34. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  35. sglang/srt/disaggregation/base/conn.py +1 -1
  36. sglang/srt/disaggregation/common/conn.py +279 -108
  37. sglang/srt/disaggregation/decode.py +78 -37
  38. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  39. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  40. sglang/srt/disaggregation/fake/conn.py +1 -1
  41. sglang/srt/disaggregation/mini_lb.py +6 -445
  42. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  43. sglang/srt/disaggregation/nixl/conn.py +373 -68
  44. sglang/srt/disaggregation/prefill.py +53 -49
  45. sglang/srt/disaggregation/utils.py +40 -54
  46. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  47. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  48. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  49. sglang/srt/distributed/parallel_state.py +156 -80
  50. sglang/srt/entrypoints/engine.py +59 -18
  51. sglang/srt/entrypoints/grpc_request_manager.py +842 -0
  52. sglang/srt/entrypoints/grpc_server.py +950 -0
  53. sglang/srt/entrypoints/http_server.py +179 -60
  54. sglang/srt/entrypoints/openai/protocol.py +265 -29
  55. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  56. sglang/srt/entrypoints/openai/serving_chat.py +213 -122
  57. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  58. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  59. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  60. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  61. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  62. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  63. sglang/srt/environ.py +289 -0
  64. sglang/srt/eplb/eplb_manager.py +2 -2
  65. sglang/srt/eplb/expert_distribution.py +26 -13
  66. sglang/srt/eplb/expert_location.py +38 -8
  67. sglang/srt/eplb/expert_location_updater.py +1 -1
  68. sglang/srt/function_call/base_format_detector.py +3 -6
  69. sglang/srt/function_call/ebnf_composer.py +11 -9
  70. sglang/srt/function_call/function_call_parser.py +17 -8
  71. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  72. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  73. sglang/srt/function_call/json_array_parser.py +63 -0
  74. sglang/srt/function_call/kimik2_detector.py +17 -4
  75. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  76. sglang/srt/function_call/utils.py +96 -5
  77. sglang/srt/grpc/__init__.py +1 -0
  78. sglang/srt/grpc/compile_proto.py +245 -0
  79. sglang/srt/grpc/sglang_scheduler_pb2.py +119 -0
  80. sglang/srt/grpc/sglang_scheduler_pb2.pyi +492 -0
  81. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +327 -0
  82. sglang/srt/layers/activation.py +143 -9
  83. sglang/srt/layers/attention/aiter_backend.py +14 -15
  84. sglang/srt/layers/attention/ascend_backend.py +115 -9
  85. sglang/srt/layers/attention/attention_registry.py +215 -0
  86. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  87. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  88. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  89. sglang/srt/layers/attention/fla/chunk.py +242 -0
  90. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  91. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  92. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  93. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  94. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  95. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  96. sglang/srt/layers/attention/fla/index.py +37 -0
  97. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  98. sglang/srt/layers/attention/fla/layernorm_gated.py +343 -0
  99. sglang/srt/layers/attention/fla/op.py +66 -0
  100. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  101. sglang/srt/layers/attention/fla/utils.py +331 -0
  102. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  103. sglang/srt/layers/attention/flashattention_backend.py +40 -8
  104. sglang/srt/layers/attention/flashinfer_backend.py +341 -204
  105. sglang/srt/layers/attention/flashinfer_mla_backend.py +28 -28
  106. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  107. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  108. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +708 -0
  109. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  110. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  111. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +974 -0
  112. sglang/srt/layers/attention/mamba/mamba.py +577 -0
  113. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  114. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  115. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  116. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  117. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  119. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  120. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  121. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  122. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  123. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  124. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  125. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  126. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  127. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  128. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  129. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  130. sglang/srt/layers/attention/nsa/utils.py +24 -0
  131. sglang/srt/layers/attention/nsa_backend.py +887 -0
  132. sglang/srt/layers/attention/tbo_backend.py +6 -6
  133. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  134. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  135. sglang/srt/layers/attention/triton_backend.py +57 -7
  136. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  137. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  138. sglang/srt/layers/attention/vision.py +58 -0
  139. sglang/srt/layers/attention/wave_backend.py +4 -4
  140. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  141. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  142. sglang/srt/layers/communicator.py +8 -0
  143. sglang/srt/layers/dp_attention.py +41 -2
  144. sglang/srt/layers/elementwise.py +3 -1
  145. sglang/srt/layers/layernorm.py +34 -15
  146. sglang/srt/layers/linear.py +55 -7
  147. sglang/srt/layers/logits_processor.py +180 -18
  148. sglang/srt/layers/modelopt_utils.py +11 -0
  149. sglang/srt/layers/moe/__init__.py +2 -1
  150. sglang/srt/layers/moe/cutlass_w4a8_moe.py +21 -24
  151. sglang/srt/layers/moe/ep_moe/kernels.py +33 -454
  152. sglang/srt/layers/moe/ep_moe/layer.py +248 -333
  153. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  154. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  155. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  164. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  165. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  166. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  167. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  168. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  169. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  170. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  171. sglang/srt/layers/moe/fused_moe_triton/layer.py +68 -72
  172. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  173. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  174. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  175. sglang/srt/layers/moe/moe_runner/runner.py +83 -0
  176. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  177. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  178. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  179. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  180. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  181. sglang/srt/layers/moe/topk.py +30 -9
  182. sglang/srt/layers/moe/utils.py +29 -7
  183. sglang/srt/layers/parameter.py +23 -6
  184. sglang/srt/layers/quantization/__init__.py +1 -1
  185. sglang/srt/layers/quantization/awq.py +19 -7
  186. sglang/srt/layers/quantization/base_config.py +11 -6
  187. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  188. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  189. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  190. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  191. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  192. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  193. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  194. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  195. sglang/srt/layers/quantization/fp8.py +155 -60
  196. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  197. sglang/srt/layers/quantization/gptq.py +25 -17
  198. sglang/srt/layers/quantization/modelopt_quant.py +191 -56
  199. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  200. sglang/srt/layers/quantization/mxfp4.py +74 -42
  201. sglang/srt/layers/quantization/quark/quark.py +3 -1
  202. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  203. sglang/srt/layers/quantization/unquant.py +135 -47
  204. sglang/srt/layers/quantization/w4afp8.py +28 -33
  205. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  206. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  207. sglang/srt/layers/rotary_embedding.py +78 -31
  208. sglang/srt/layers/sampler.py +213 -21
  209. sglang/srt/layers/utils.py +23 -0
  210. sglang/srt/lora/backend/base_backend.py +50 -8
  211. sglang/srt/lora/backend/chunked_backend.py +348 -0
  212. sglang/srt/lora/backend/triton_backend.py +99 -5
  213. sglang/srt/lora/layers.py +32 -0
  214. sglang/srt/lora/lora.py +8 -3
  215. sglang/srt/lora/lora_manager.py +44 -118
  216. sglang/srt/lora/mem_pool.py +25 -11
  217. sglang/srt/lora/triton_ops/__init__.py +4 -0
  218. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  219. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  220. sglang/srt/lora/utils.py +22 -11
  221. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  222. sglang/srt/managers/cache_controller.py +199 -301
  223. sglang/srt/managers/data_parallel_controller.py +115 -80
  224. sglang/srt/managers/detokenizer_manager.py +19 -15
  225. sglang/srt/managers/disagg_service.py +46 -0
  226. sglang/srt/managers/io_struct.py +340 -109
  227. sglang/srt/managers/mm_utils.py +44 -6
  228. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  229. sglang/srt/managers/multimodal_processor.py +1 -2
  230. sglang/srt/managers/overlap_utils.py +55 -0
  231. sglang/srt/managers/schedule_batch.py +343 -212
  232. sglang/srt/managers/schedule_policy.py +145 -18
  233. sglang/srt/managers/scheduler.py +653 -273
  234. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  235. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  236. sglang/srt/managers/scheduler_output_processor_mixin.py +255 -108
  237. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  238. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  239. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  240. sglang/srt/managers/tokenizer_manager.py +579 -674
  241. sglang/srt/managers/tp_worker.py +96 -26
  242. sglang/srt/managers/utils.py +1 -45
  243. sglang/srt/mem_cache/allocator.py +21 -22
  244. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  245. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  246. sglang/srt/mem_cache/chunk_cache.py +9 -2
  247. sglang/srt/mem_cache/evict_policy.py +23 -0
  248. sglang/srt/mem_cache/hicache_storage.py +43 -24
  249. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  250. sglang/srt/mem_cache/memory_pool.py +651 -80
  251. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  252. sglang/srt/mem_cache/radix_cache.py +227 -73
  253. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  254. sglang/srt/mem_cache/storage/__init__.py +10 -0
  255. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  256. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  257. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  258. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  259. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  260. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  261. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  262. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  263. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  264. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  265. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  266. sglang/srt/mem_cache/swa_radix_cache.py +93 -48
  267. sglang/srt/metrics/collector.py +511 -132
  268. sglang/srt/metrics/func_timer.py +2 -7
  269. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  270. sglang/srt/metrics/utils.py +8 -1
  271. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  272. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  273. sglang/srt/model_executor/forward_batch_info.py +74 -46
  274. sglang/srt/model_executor/model_runner.py +455 -176
  275. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  276. sglang/srt/model_loader/__init__.py +10 -4
  277. sglang/srt/model_loader/loader.py +319 -10
  278. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  279. sglang/srt/model_loader/weight_utils.py +161 -3
  280. sglang/srt/models/apertus.py +686 -0
  281. sglang/srt/models/bailing_moe.py +820 -217
  282. sglang/srt/models/bailing_moe_nextn.py +168 -0
  283. sglang/srt/models/deepseek_nextn.py +6 -1
  284. sglang/srt/models/deepseek_v2.py +607 -130
  285. sglang/srt/models/dots_ocr.py +173 -0
  286. sglang/srt/models/dots_vlm.py +174 -0
  287. sglang/srt/models/dots_vlm_vit.py +337 -0
  288. sglang/srt/models/ernie4.py +1 -1
  289. sglang/srt/models/falcon_h1.py +578 -0
  290. sglang/srt/models/gemma3_causal.py +0 -2
  291. sglang/srt/models/gemma3_mm.py +17 -1
  292. sglang/srt/models/gemma3n_mm.py +2 -2
  293. sglang/srt/models/glm4_moe.py +4 -4
  294. sglang/srt/models/glm4_moe_nextn.py +2 -2
  295. sglang/srt/models/glm4v.py +5 -3
  296. sglang/srt/models/glm4v_moe.py +4 -1
  297. sglang/srt/models/gpt_oss.py +8 -31
  298. sglang/srt/models/grok.py +5 -13
  299. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  300. sglang/srt/models/llama.py +4 -0
  301. sglang/srt/models/llama4.py +9 -0
  302. sglang/srt/models/llama_eagle3.py +13 -0
  303. sglang/srt/models/longcat_flash.py +3 -3
  304. sglang/srt/models/longcat_flash_nextn.py +1 -1
  305. sglang/srt/models/mixtral.py +1 -3
  306. sglang/srt/models/mllama4.py +50 -4
  307. sglang/srt/models/nemotron_h.py +514 -0
  308. sglang/srt/models/opt.py +637 -0
  309. sglang/srt/models/qwen2_5_vl.py +29 -5
  310. sglang/srt/models/qwen2_audio.py +1 -1
  311. sglang/srt/models/qwen2_moe.py +120 -13
  312. sglang/srt/models/qwen2_vl.py +1 -1
  313. sglang/srt/models/qwen3.py +18 -3
  314. sglang/srt/models/qwen3_moe.py +32 -4
  315. sglang/srt/models/qwen3_next.py +1069 -0
  316. sglang/srt/models/qwen3_next_mtp.py +112 -0
  317. sglang/srt/models/qwen3_vl.py +787 -0
  318. sglang/srt/models/qwen3_vl_moe.py +471 -0
  319. sglang/srt/models/registry.py +15 -3
  320. sglang/srt/models/sarashina2_vision.py +269 -0
  321. sglang/srt/models/solar.py +505 -0
  322. sglang/srt/models/starcoder2.py +357 -0
  323. sglang/srt/models/step3_vl.py +1 -1
  324. sglang/srt/models/torch_native_llama.py +9 -2
  325. sglang/srt/models/utils.py +55 -0
  326. sglang/srt/multimodal/processors/base_processor.py +15 -7
  327. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  328. sglang/srt/multimodal/processors/glm4v.py +9 -9
  329. sglang/srt/multimodal/processors/internvl.py +153 -129
  330. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  331. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  332. sglang/srt/offloader.py +27 -3
  333. sglang/srt/parser/jinja_template_utils.py +6 -0
  334. sglang/srt/sampling/sampling_batch_info.py +49 -26
  335. sglang/srt/sampling/sampling_params.py +7 -0
  336. sglang/srt/server_args.py +1051 -285
  337. sglang/srt/server_args_config_parser.py +146 -0
  338. sglang/srt/single_batch_overlap.py +151 -0
  339. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  340. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  341. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  342. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  343. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  344. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  345. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  346. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  347. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  348. sglang/srt/speculative/eagle_worker.py +98 -29
  349. sglang/srt/speculative/ngram_info.py +428 -0
  350. sglang/srt/speculative/ngram_worker.py +246 -0
  351. sglang/srt/speculative/spec_info.py +52 -0
  352. sglang/srt/speculative/spec_utils.py +605 -0
  353. sglang/srt/speculative/standalone_worker.py +109 -0
  354. sglang/srt/torch_memory_saver_adapter.py +5 -7
  355. sglang/srt/tracing/trace.py +578 -0
  356. sglang/srt/two_batch_overlap.py +9 -5
  357. sglang/srt/utils/__init__.py +2 -0
  358. sglang/srt/{utils.py → utils/common.py} +451 -77
  359. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +55 -5
  360. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  361. sglang/srt/utils/rpd_utils.py +452 -0
  362. sglang/srt/utils/slow_rank_detector.py +71 -0
  363. sglang/srt/warmup.py +8 -4
  364. sglang/srt/weight_sync/utils.py +2 -2
  365. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  366. sglang/test/get_logits_ut.py +57 -0
  367. sglang/test/longbench_v2/__init__.py +1 -0
  368. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  369. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  370. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  371. sglang/test/run_eval.py +119 -11
  372. sglang/test/runners.py +5 -1
  373. sglang/test/simple_eval_common.py +5 -2
  374. sglang/test/simple_eval_longbench_v2.py +332 -0
  375. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  376. sglang/test/test_block_fp8.py +2 -2
  377. sglang/test/test_cutlass_moe.py +24 -6
  378. sglang/test/test_cutlass_w4a8_moe.py +9 -19
  379. sglang/test/test_deterministic.py +313 -0
  380. sglang/test/test_deterministic_utils.py +81 -0
  381. sglang/test/test_disaggregation_utils.py +140 -0
  382. sglang/test/test_fp4_moe.py +370 -1
  383. sglang/test/test_programs.py +1 -1
  384. sglang/test/test_utils.py +407 -8
  385. sglang/utils.py +21 -1
  386. sglang/version.py +1 -1
  387. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +69 -124
  388. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +392 -251
  389. sglang/srt/disaggregation/launch_lb.py +0 -118
  390. sglang/srt/managers/tp_worker_overlap_thread.py +0 -296
  391. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  392. sglang/test/test_block_fp8_ep.py +0 -358
  393. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  394. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
  395. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
  396. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/fla/layernorm_gated.py
@@ -0,0 +1,343 @@
+ # Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/modules/layernorm_gated.py
+ # Copyright (c) 2024, Tri Dao.
+ # Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
+ # For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
+ # This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
+ # The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
+
+ import math
+
+ import torch
+ import torch.nn.functional as F
+ import triton
+ import triton.language as tl
+ from einops import rearrange
+
+
+ def rms_norm_ref(
+     x,
+     weight,
+     bias,
+     z=None,
+     eps=1e-6,
+     group_size=None,
+     norm_before_gate=True,
+     upcast=True,
+ ):
+     dtype = x.dtype
+     N = x.shape[-1]
+     weight = weight.float()
+     bias = bias.float() if bias is not None else None
+     if upcast:
+         x = x.float()
+         z = z.float() if z is not None else z
+     if z is not None and not norm_before_gate:
+         x = x * F.silu(z)
+     if group_size is None:
+         rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
+         out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight)
+     else:
+         x_group = rearrange(x, "... (g d) -> ... g d", d=group_size)
+         rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) + eps)
+         out = rearrange(x_group * rstd, "... g d -> ... (g d)") * weight
+         if bias is not None:
+             out = out + bias
+     if z is not None and norm_before_gate:
+         out *= F.silu(z)
+     return out.to(dtype)
+
+
+ @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
+ @triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None})
+ @triton.jit
+ def _layer_norm_fwd_1pass_kernel(
+     X,  # pointer to the input
+     Y,  # pointer to the output
+     W,  # pointer to the weights
+     B,  # pointer to the biases
+     Z,  # pointer to the other branch
+     Mean,  # pointer to the mean
+     Rstd,  # pointer to the 1/std
+     stride_x_row,  # how much to increase the pointer when moving by 1 row
+     stride_y_row,
+     stride_z_row,
+     M,  # number of rows in X
+     N,  # number of columns in X
+     eps,  # epsilon to avoid division by zero
+     BLOCK_N: tl.constexpr,
+     HAS_BIAS: tl.constexpr,
+     HAS_Z: tl.constexpr,
+     NORM_BEFORE_GATE: tl.constexpr,
+     IS_RMS_NORM: tl.constexpr,
+ ):
+     # Map the program id to the row of X and Y it should compute.
+     row = tl.program_id(0)
+     group = tl.program_id(1)
+     X += row * stride_x_row + group * N
+     Y += row * stride_y_row + group * N
+     if HAS_Z:
+         Z += row * stride_z_row + group * N
+     if not IS_RMS_NORM:
+         Mean += group * M
+     Rstd += group * M
+     W += group * N
+     if HAS_BIAS:
+         B += group * N
+     # Compute mean and variance
+     cols = tl.arange(0, BLOCK_N)
+     x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
+     if HAS_Z and not NORM_BEFORE_GATE:
+         z = tl.load(Z + cols, mask=cols < N).to(tl.float32)
+         x *= z * tl.sigmoid(z)
+     if not IS_RMS_NORM:
+         mean = tl.sum(x, axis=0) / N
+         tl.store(Mean + row, mean)
+         xbar = tl.where(cols < N, x - mean, 0.0)
+         var = tl.sum(xbar * xbar, axis=0) / N
+     else:
+         xbar = tl.where(cols < N, x, 0.0)
+         var = tl.sum(xbar * xbar, axis=0) / N
+     rstd = 1 / tl.sqrt(var + eps)
+     tl.store(Rstd + row, rstd)
+     # Normalize and apply linear transformation
+     mask = cols < N
+     w = tl.load(W + cols, mask=mask).to(tl.float32)
+     if HAS_BIAS:
+         b = tl.load(B + cols, mask=mask).to(tl.float32)
+     x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
+     y = x_hat * w + b if HAS_BIAS else x_hat * w
+     if HAS_Z and NORM_BEFORE_GATE:
+         z = tl.load(Z + cols, mask=mask).to(tl.float32)
+         y *= z * tl.sigmoid(z)
+     # Write output
+     tl.store(Y + cols, y, mask=mask)
+
+
+ def _layer_norm_fwd(
+     x,
+     weight,
+     bias,
+     eps,
+     z=None,
+     out=None,
+     group_size=None,
+     norm_before_gate=True,
+     is_rms_norm=False,
+ ):
+     M, N = x.shape
+     if group_size is None:
+         group_size = N
+     assert N % group_size == 0
+     ngroups = N // group_size
+     assert x.stride(-1) == 1
+     if z is not None:
+         assert z.stride(-1) == 1
+         assert z.shape == (M, N)
+     assert weight.shape == (N,)
+     assert weight.stride(-1) == 1
+     if bias is not None:
+         assert bias.stride(-1) == 1
+         assert bias.shape == (N,)
+     # allocate output
+     if out is not None:
+         assert out.shape == x.shape
+     else:
+         out = torch.empty_like(x)
+     assert out.stride(-1) == 1
+     mean = (
+         torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)
+         if not is_rms_norm
+         else None
+     )
+     rstd = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)
+     # Less than 64KB per feature: enqueue fused kernel
+     MAX_FUSED_SIZE = 65536 // x.element_size()
+     BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
+     if group_size > BLOCK_N:
+         raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
+     # heuristics for number of warps
+     num_warps = min(max(BLOCK_N // 256, 1), 8)
+     grid = (M, ngroups)
+     with torch.get_device_module(x.device).device(x.device.index):
+         _layer_norm_fwd_1pass_kernel[grid](
+             x,
+             out,
+             weight,
+             bias,
+             z,
+             mean,
+             rstd,
+             x.stride(0),
+             out.stride(0),
+             z.stride(0) if z is not None else 0,
+             M,
+             group_size,
+             eps,
+             BLOCK_N=BLOCK_N,
+             NORM_BEFORE_GATE=norm_before_gate,
+             IS_RMS_NORM=is_rms_norm,
+             num_warps=num_warps,
+         )
+     return out, mean, rstd
+
+
+ def rms_norm_gated(
+     *,
+     x,
+     weight,
+     bias,
+     z=None,
+     eps=1e-6,
+     group_size=None,
+     norm_before_gate=True,
+     is_rms_norm=False,
+ ):
+     """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
+
+     x_shape_og = x.shape
+     # reshape input data into 2D tensor
+     x = x.reshape(-1, x.shape[-1])
+     if x.stride(-1) != 1:
+         x = x.contiguous()
+     if z is not None:
+         assert z.shape == x_shape_og
+         z = z.reshape(-1, z.shape[-1])
+         if z.stride(-1) != 1:
+             z = z.contiguous()
+     weight = weight.contiguous()
+     if bias is not None:
+         bias = bias.contiguous()
+     y, mean, rstd = _layer_norm_fwd(
+         x,
+         weight,
+         bias,
+         eps,
+         z=z,
+         group_size=group_size,
+         norm_before_gate=norm_before_gate,
+         is_rms_norm=is_rms_norm,
+     )
+     return y.reshape(x_shape_og)
+
+
+ class LayerNormFn(torch.autograd.Function):
+
+     @staticmethod
+     def forward(
+         ctx,
+         x,
+         weight,
+         bias,
+         z=None,
+         eps=1e-6,
+         group_size=None,
+         norm_before_gate=True,
+         is_rms_norm=False,
+     ):
+         return rms_norm_gated(
+             x=x,
+             weight=weight,
+             bias=bias,
+             eps=eps,
+             z=z,
+             group_size=group_size,
+             norm_before_gate=norm_before_gate,
+             is_rms_norm=is_rms_norm,
+         )
+
+
+ def layernorm_fn(
+     x,
+     weight,
+     bias,
+     z=None,
+     eps=1e-6,
+     group_size=None,
+     norm_before_gate=True,
+     is_rms_norm=False,
+ ):
+     return LayerNormFn.apply(
+         x, weight, bias, z, eps, group_size, norm_before_gate, is_rms_norm
+     )
+
+
+ class LayerNorm(torch.nn.Module):
+
+     def __init__(
+         self,
+         hidden_size,
+         eps=1e-5,
+         group_size=None,
+         norm_before_gate=True,
+         device=None,
+         dtype=None,
+     ):
+         """If group_size is not None, we do GroupNorm with each group having group_size elements.
+         group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
+         """
+
+         factory_kwargs = {"device": device, "dtype": dtype}
+         super().__init__()
+         self.eps = eps
+         self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
+         self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
+         self.group_size = group_size
+         self.norm_before_gate = norm_before_gate
+         self.reset_parameters()
+
+     def reset_parameters(self):
+         torch.nn.init.ones_(self.weight)
+         torch.nn.init.zeros_(self.bias)
+
+     def forward(self, x, z=None):
+         """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
+         return layernorm_fn(
+             x,
+             self.weight,
+             self.bias,
+             z=z,
+             group_size=self.group_size,
+             eps=self.eps,
+             norm_before_gate=self.norm_before_gate,
+             is_rms_norm=False,
+         )
+
+
+ class RMSNorm(torch.nn.Module):
+
+     def __init__(
+         self,
+         hidden_size,
+         eps=1e-5,
+         group_size=None,
+         norm_before_gate=True,
+         device=None,
+         dtype=None,
+     ):
+         """If group_size is not None, we do GroupNorm with each group having group_size elements.
+         group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
+         """
+         factory_kwargs = {"device": device, "dtype": dtype}
+         super().__init__()
+         self.eps = eps
+         self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
+         self.register_parameter("bias", None)
+         self.group_size = group_size
+         self.norm_before_gate = norm_before_gate
+         self.reset_parameters()
+
+     def reset_parameters(self):
+         torch.nn.init.ones_(self.weight)
+
+     def forward(self, x, z=None):
+         """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
+         return layernorm_fn(
+             x,
+             self.weight,
+             self.bias,
+             z=z,
+             eps=self.eps,
+             group_size=self.group_size,
+             norm_before_gate=self.norm_before_gate,
+             is_rms_norm=True,
+         )
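
For orientation, a minimal usage sketch of the gated RMSNorm module added above (not part of the diff; assumes a CUDA device and the module path shown in the file list, and the tensor shapes are purely illustrative):

import torch

from sglang.srt.layers.attention.fla.layernorm_gated import RMSNorm

hidden_size = 2048
norm = RMSNorm(hidden_size, eps=1e-5, device="cuda", dtype=torch.bfloat16)
x = torch.randn(4, hidden_size, device="cuda", dtype=torch.bfloat16)  # input activations
z = torch.randn_like(x)                                               # gate branch
# With the default norm_before_gate=True: out = rms_norm(x) * silu(z)
out = norm(x, z=z)
print(out.shape)  # torch.Size([4, 2048])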
sglang/srt/layers/attention/fla/op.py
@@ -0,0 +1,66 @@
+ # Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/utils/op.py
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+ import os
+
+ import triton
+ import triton.language as tl
+ import triton.language.extra.libdevice as tldevice
+
+ from sglang.srt.layers.attention.fla.utils import is_gather_supported
+
+ if os.environ.get("FLA_USE_FAST_OPS", "0") == "1":
+     exp = tldevice.fast_expf
+     exp2 = tldevice.exp2
+     log = tldevice.fast_logf
+     log2 = tldevice.fast_log2f
+ else:
+     exp = tl.exp
+     exp2 = tl.math.exp2
+     log = tl.log
+     log2 = tl.log2
+
+
+ @triton.jit
+ def safe_exp(x):
+     return exp(tl.where(x <= 0, x, float("-inf")))
+
+
+ if not is_gather_supported:
+
+     @triton.jit
+     def gather(src, index, axis, _builder=None):
+         """
+         Gather operation that works when tl.gather is not supported.
+         This is a fallback implementation that returns None.
+         Just to make triton compiler happy.
+         """
+         return None
+
+ else:
+     gather = tl.gather
+
+
+ if hasattr(triton.language, "_experimental_make_tensor_descriptor"):
+     # For Triton 3.3.x
+     make_tensor_descriptor = triton.language._experimental_make_tensor_descriptor
+ elif hasattr(triton.language, "make_tensor_descriptor"):
+     # For Triton 3.4.x and later
+     make_tensor_descriptor = triton.language.make_tensor_descriptor
+ else:
+     """
+     Fallback implementation when TMA is not supported.
+     Returns None to indicate TMA descriptors are unavailable.
+     Just make triton compiler happy.
+     """
+
+     @triton.jit
+     def make_tensor_descriptor(
+         base,
+         shape,
+         strides,
+         block_shape,
+         _builder=None,
+     ):
+         return None
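
And a small sketch of how the safe_exp helper added above can be called from another Triton kernel (not part of the diff; the kernel, shapes, and block size are hypothetical, a CUDA device is assumed, and FLA_USE_FAST_OPS is left unset so the exact tl.exp path is selected at import time):

import torch
import triton
import triton.language as tl

from sglang.srt.layers.attention.fla.op import safe_exp


@triton.jit
def _safe_exp_kernel(X, Y, N, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    mask = offs < N
    x = tl.load(X + offs, mask=mask, other=0.0)
    # safe_exp maps positive inputs to -inf before exponentiating, so exp never
    # overflows: non-positive x -> exp(x), positive x -> 0.0
    tl.store(Y + offs, safe_exp(x), mask=mask)


x = torch.tensor([-1.0, 0.0, 100.0], device="cuda")
y = torch.empty_like(x)
_safe_exp_kernel[(1,)](x, y, x.numel(), BLOCK=4)
# y ~= [exp(-1), 1.0, 0.0]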