sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (482)
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/nsa/nsa_indexer.py (new file)
@@ -0,0 +1,718 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Any, Dict, Optional
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+from torch import nn
+
+from sglang.srt.custom_op import CustomOp
+from sglang.srt.utils import add_prefix, align, is_cuda, is_hip, is_npu
+
+if is_cuda():
+    try:
+        import deep_gemm
+    except ImportError as e:
+        deep_gemm = e
+
+from sglang.srt.layers import deep_gemm_wrapper
+from sglang.srt.layers.attention.nsa.utils import NSA_DUAL_STREAM
+from sglang.srt.layers.dp_attention import get_attention_tp_group
+from sglang.srt.layers.linear import ReplicatedLinear
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.rotary_embedding import get_rope_wrapper
+from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.server_args import get_global_server_args
+
+if TYPE_CHECKING:
+    from sglang.srt.mem_cache.memory_pool import NSATokenToKVPool
+
+DUAL_STREAM_TOKEN_THRESHOLD = 1024 if is_cuda() else 0
+
+
+class BaseIndexerMetadata(ABC):
+    @abstractmethod
+    def get_seqlens_int32(self) -> torch.Tensor:
+        """
+        Return: (batch_size,) int32 tensor
+        """
+
+    @abstractmethod
+    def get_page_table_64(self) -> torch.Tensor:
+        """
+        Return: (batch_size, num_blocks) int32, page table.
+        The page size of the table is 64.
+        """
+
+    @abstractmethod
+    def get_seqlens_expanded(self) -> torch.Tensor:
+        """
+        Return: (sum_extend_seq_len,) int32 tensor
+        """
+
+    @abstractmethod
+    def topk_transform(
+        self,
+        logits: torch.Tensor,
+        topk: int,
+    ) -> torch.Tensor:
+        """
+        Perform topk selection on the logits and possibly transform the result.
+
+        NOTE that attention backend may override this function to do some
+        transformation, which means the result of this topk_transform may not
+        be the topk indices of the input logits.
+
+        Return: Anything, since it will be passed to the attention backend
+            for further processing on sparse attention computation.
+            Don't assume it is the topk indices of the input logits.
+        """
+
+
+def rotate_activation(x: torch.Tensor) -> torch.Tensor:
+    assert x.dtype == torch.bfloat16
+    from sgl_kernel import hadamard_transform
+
+    hidden_size = x.size(-1)
+    assert (
+        hidden_size & (hidden_size - 1)
+    ) == 0, "Hidden size must be a power of 2 for Hadamard transform."
+    return hadamard_transform(x, scale=hidden_size**-0.5)
+
+
+class V32LayerNorm(nn.Module):
+    """
+    Layer Normalization.
+    """
+
+    def __init__(self, dim: int, eps: float = 1e-6):
+        super().__init__()
+        self.dim = dim
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
+        self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32))
+
+    def forward(self, x: torch.Tensor):
+        return F.layer_norm(
+            x.float(), (self.dim,), self.weight, self.bias, self.eps
+        ).type_as(x)
+
+
+class Indexer(CustomOp):
+    def __init__(
+        self,
+        hidden_size: int,
+        index_n_heads: int,
+        index_head_dim: int,
+        rope_head_dim: int,
+        index_topk: int,
+        q_lora_rank: int,
+        max_position_embeddings: int,
+        rope_theta: float,
+        layer_id: int,
+        scale_fmt: Optional[str],
+        block_size: int = 128,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        prefix: str = "",
+        quant_config: Optional[QuantizationConfig] = None,
+        alt_stream: Optional[torch.cuda.Stream] = None,
+    ):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.n_heads = index_n_heads
+        self.head_dim = index_head_dim
+        self.rope_head_dim = rope_head_dim
+        self.index_topk = index_topk
+        self.q_lora_rank = q_lora_rank
+        self.layer_id = layer_id
+        self.alt_stream = alt_stream
+        if is_cuda():
+            self.sm_count = deep_gemm.get_num_sms()
+            self.half_device_sm_count = align(self.sm_count // 2, 8)
+
+        self.wq_b = ReplicatedLinear(
+            self.q_lora_rank,
+            self.n_heads * self.head_dim,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("wq_b", prefix),
+        )
+        self.wk = ReplicatedLinear(
+            self.hidden_size,
+            self.head_dim,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("wk", prefix),
+        )
+        self.k_norm = V32LayerNorm(self.head_dim)
+        # NOTE: weight_proj is not quantized
+        self.weights_proj = ReplicatedLinear(
+            self.hidden_size,
+            self.n_heads,
+            bias=False,
+            prefix=add_prefix("weights_proj", prefix),
+        )
+        self.rotary_emb = get_rope_wrapper(
+            rope_head_dim,
+            rotary_dim=rope_head_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,  # type: ignore
+            rope_scaling=rope_scaling,
+            is_neox_style=False,
+            device=get_global_server_args().device,
+        )
+        self.block_size = block_size
+        self.scale_fmt = scale_fmt
+        self.softmax_scale = self.head_dim**-0.5
+
+    @torch.compile(dynamic=True)
+    def _get_logits_head_gate(self, x: torch.Tensor, q_scale: torch.Tensor):
+        weights, _ = self.weights_proj(x)
+        weights = weights * self.n_heads**-0.5
+        weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale
+        return weights
+
+    def _get_q_k_bf16(
+        self,
+        q_lora: torch.Tensor,
+        x: torch.Tensor,
+        positions: torch.Tensor,
+        enable_dual_stream: bool,
+    ):
+
+        if enable_dual_stream:
+            current_stream = torch.cuda.current_stream()
+            self.alt_stream.wait_stream(current_stream)
+
+            with deep_gemm_wrapper.configure_deep_gemm_num_sms(
+                self.half_device_sm_count
+            ):
+                query, _ = self.wq_b(q_lora)
+                query = rearrange(query, "l (h d) -> l h d", d=self.head_dim)
+                q_rope, _ = torch.split(
+                    query,
+                    [self.rope_head_dim, self.head_dim - self.rope_head_dim],
+                    dim=-1,
+                )
+            with torch.cuda.stream(self.alt_stream):
+                # TODO we should also put DeepGEMM half SM here?
+                key, _ = self.wk(x)
+                key = self.k_norm(key)
+
+                k_rope, _ = torch.split(
+                    key,
+                    [self.rope_head_dim, self.head_dim - self.rope_head_dim],
+                    dim=-1,
+                )
+
+            current_stream.wait_stream(self.alt_stream)
+        else:
+            query, _ = self.wq_b(q_lora)
+            query = rearrange(query, "l (h d) -> l h d", d=self.head_dim)
+
+            q_rope, _ = torch.split(
+                query, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1
+            )
+
+            key, _ = self.wk(x)
+            key = self.k_norm(key)
+            k_rope, _ = torch.split(
+                key, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1
+            )
+
+        q_rope, k_rope = self.rotary_emb(positions, q_rope, k_rope)
+
+        query[..., : self.rope_head_dim] = q_rope
+        key[..., : self.rope_head_dim] = k_rope
+
+        if enable_dual_stream:
+            current_stream = torch.cuda.current_stream()
+            self.alt_stream.wait_stream(current_stream)
+            query = rotate_activation(query)
+
+            with torch.cuda.stream(self.alt_stream):
+                key = rotate_activation(key)
+            current_stream.wait_stream(self.alt_stream)
+        else:
+            query = rotate_activation(query)
+            key = rotate_activation(key)
+
+        return query, key
+
+    def _get_topk_paged(
+        self,
+        forward_batch: ForwardBatch,
+        layer_id: int,
+        q_fp8: torch.Tensor,
+        weights: torch.Tensor,
+        metadata: BaseIndexerMetadata,
+    ) -> torch.Tensor:
+        if TYPE_CHECKING:
+            assert isinstance(forward_batch.token_to_kv_pool, NSATokenToKVPool)
+
+        page_size = forward_batch.token_to_kv_pool.page_size
+        # NOTE(dark): blocksize = 64 is hardcoded in deep_gemm
+        assert page_size == 64, "only support page size 64"
+
+        # NOTE(dark): this support extend/decode/decode+graph
+        block_tables = metadata.get_page_table_64()
+
+        max_seq_len = block_tables.shape[1] * page_size
+        kv_cache_fp8 = forward_batch.token_to_kv_pool.get_index_k_with_scale_buffer(
+            layer_id=layer_id
+        )
+
+        blocksize = page_size
+        if forward_batch.forward_mode.is_target_verify():
+            seqlens_32 = metadata.get_seqlens_expanded()
+        else:
+            seqlens_32 = metadata.get_seqlens_int32()
+        # NOTE(dark): 132 is SM count on H200/B200, not magic number
+        schedule_metadata = deep_gemm.get_paged_mqa_logits_metadata(
+            seqlens_32, blocksize, self.sm_count
+        )
+
+        assert len(q_fp8.shape) == 3
+        q_fp8 = q_fp8.unsqueeze(1)  # the next_n dim is 1 now
+        assert len(kv_cache_fp8.shape) == 2
+        block_kv = 64
+        num_heads_kv = 1
+        head_dim_with_sf = 132
+        kv_cache_fp8 = kv_cache_fp8.view(
+            kv_cache_fp8.shape[0], block_kv, num_heads_kv, head_dim_with_sf
+        )
+        assert len(weights.shape) == 3
+        weights = weights.squeeze(2)
+
+        logits = deep_gemm.fp8_paged_mqa_logits(
+            q_fp8,
+            kv_cache_fp8,
+            weights,
+            seqlens_32,
+            block_tables,
+            schedule_metadata,
+            max_seq_len,
+            clean_logits=False,
+        )
+
+        # NOTE(dark): logits should be cleaned in topk_transform
+        topk_result = metadata.topk_transform(logits, self.index_topk)
+        return topk_result
+
+    def _get_topk_ragged(
+        self,
+        forward_batch: ForwardBatch,
+        layer_id: int,
+        q_fp8: torch.Tensor,
+        weights: torch.Tensor,
+        metadata: BaseIndexerMetadata,
+    ) -> torch.Tensor:
+        if TYPE_CHECKING:
+            assert isinstance(forward_batch.token_to_kv_pool, NSATokenToKVPool)
+
+        page_size = forward_batch.token_to_kv_pool.page_size
+        assert page_size == 64, "only support page size 64"
+        assert len(weights.shape) == 3
+        weights = weights.squeeze(-1)
+        k_fp8_list = []
+        k_scale_list = []
+        ks_list = []
+        ke_list = []
+        offset = 0
+        seq_lens_expanded = metadata.get_seqlens_expanded()
+        block_tables = metadata.get_page_table_64()
+
+        assert (
+            forward_batch.seq_lens_cpu is not None
+            and forward_batch.extend_seq_lens_cpu is not None
+        )
+
+        for i in range(forward_batch.batch_size):
+            seq_len = forward_batch.seq_lens_cpu[i].item()
+            assert isinstance(seq_len, int)
+            k_fp8 = forward_batch.token_to_kv_pool.get_index_k_continuous(
+                layer_id,
+                seq_len,
+                block_tables[i],
+            )
+            k_scale = forward_batch.token_to_kv_pool.get_index_k_scale_continuous(
+                layer_id,
+                seq_len,
+                block_tables[i],
+            )
+            extend_seq_len = forward_batch.extend_seq_lens_cpu[i]
+            ks = torch.full((extend_seq_len,), offset, dtype=torch.int32, device="cuda")
+            ke = ks + seq_lens_expanded[offset : offset + extend_seq_len]
+            k_fp8_list.append(k_fp8)
+            k_scale_list.append(k_scale)
+            ks_list.append(ks)
+            ke_list.append(ke)
+            offset += extend_seq_len
+
+        k_fp8 = torch.cat(k_fp8_list, dim=0).view(torch.float8_e4m3fn)
+        k_scale = torch.cat(k_scale_list, dim=0).view(torch.float32).squeeze(-1)
+        kv_fp8 = (k_fp8, k_scale)
+        ks = torch.cat(ks_list, dim=0)
+        ke = torch.cat(ke_list, dim=0)
+
+        logits = deep_gemm.fp8_mqa_logits(
+            q_fp8[:offset],
+            kv_fp8,
+            weights[:offset],
+            ks,
+            ke,
+            clean_logits=False,
+        )
+        token_nums, _, _ = q_fp8.shape
+        assert logits.shape[0] == len(seq_lens_expanded)
+        raw_topk_result = metadata.topk_transform(logits, self.index_topk)
+        topk_result = torch.full(
+            (token_nums, self.index_topk), -1, device=q_fp8.device, dtype=torch.int32
+        )
+        topk_result[:offset] = raw_topk_result
+        return topk_result
+
+    def forward_indexer(
+        self,
+        q_fp8: torch.Tensor,
+        weights: torch.Tensor,
+        forward_batch: ForwardBatch,
+        topk: int,
+        layer_id: int,
+    ) -> Optional[torch.Tensor]:
+        if not is_npu():
+            from sglang.srt.layers.attention.nsa.tilelang_kernel import fp8_index
+
+        page_size = forward_batch.token_to_kv_pool.page_size
+        assert page_size == 64, "only support page size 64"
+
+        assert len(weights.shape) == 3
+        weights = weights.squeeze(-1)
+
+        # logits = deep_gemm.fp8_mqa_logits(q_fp8, kv_fp8, weights, ks, ke)
+        k_fp8_list = []
+        k_scale_list = []
+
+        topk_indices_list = []
+
+        block_tables = forward_batch.req_to_token_pool.req_to_token[
+            forward_batch.req_pool_indices, :
+        ]
+        strided_indices = torch.arange(
+            0, block_tables.shape[-1], page_size, device="cuda"
+        )
+        block_tables = block_tables[:, strided_indices] // page_size
+
+        q_len_start = 0
+
+        for i in range(forward_batch.batch_size):
+            seq_len = forward_batch.seq_lens[i].item()
+            q_len = (
+                forward_batch.extend_seq_lens_cpu[i]
+                if forward_batch.forward_mode.is_extend()
+                else 1
+            )
+            q_len_end = q_len_start + q_len
+
+            q_fp8_partial = q_fp8[q_len_start:q_len_end]
+            q_fp8_partial = q_fp8_partial.unsqueeze(0).contiguous()
+
+            weights_partial = weights[q_len_start:q_len_end]
+            weights_partial = weights_partial.squeeze(-1).unsqueeze(0).contiguous()
+
+            k_fp8 = forward_batch.token_to_kv_pool.get_index_k_continuous(
+                layer_id,
+                seq_len,
+                block_tables[i],
+            )
+            k_scale = forward_batch.token_to_kv_pool.get_index_k_scale_continuous(
+                layer_id,
+                seq_len,
+                block_tables[i],
+            )
+
+            k_fp8 = k_fp8.view(torch.float8_e4m3fn).unsqueeze(0).contiguous()
+            k_scale = k_scale.view(torch.float32).squeeze(-1).unsqueeze(0).contiguous()
+
+            index_score = fp8_index(
+                q_fp8_partial,
+                weights_partial,
+                k_fp8,
+                k_scale,
+            )
+            end_pos = seq_len
+            topk_indices = index_score.topk(min(topk, end_pos), dim=-1)[1].squeeze(0)
+
+            pad_len = align(topk_indices.shape[-1], 2048) - topk_indices.shape[-1]
+            topk_indices = torch.nn.functional.pad(
+                topk_indices, (0, pad_len), "constant", -1
+            )
+
+            topk_indices_list.append(topk_indices)
+
+            q_len_start = q_len_end
+
+        topk_indices = torch.cat(topk_indices_list, dim=0)
+        return topk_indices
+
+    def forward_cuda(
+        self,
+        x: torch.Tensor,
+        q_lora: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        layer_id: int,
+    ) -> Optional[torch.Tensor]:
+        if is_hip():
+            from sglang.srt.layers.attention.nsa.tilelang_kernel import act_quant
+        elif not is_npu():
+            from sglang.srt.layers.attention.nsa.triton_kernel import act_quant
+
+        if TYPE_CHECKING:
+            assert isinstance(forward_batch.token_to_kv_pool, NSATokenToKVPool)
+
+        metadata = forward_batch.attn_backend.get_indexer_metadata(
+            layer_id, forward_batch
+        )
+
+        enable_dual_stream = (
+            NSA_DUAL_STREAM
+            and self.alt_stream is not None
+            and get_is_capture_mode()
+            and q_lora.shape[0] > 0
+            and q_lora.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
+        )
+
+        # skip NSA if attention backend choose to skip this batch
+        if metadata is None:
+            return None
+
+        query, key = self._get_q_k_bf16(q_lora, x, positions, enable_dual_stream)
+
+        if enable_dual_stream:
+            current_stream = torch.cuda.current_stream()
+            self.alt_stream.wait_stream(current_stream)
+
+            q_fp8, q_scale = act_quant(query, self.block_size, self.scale_fmt)
+            with torch.cuda.stream(self.alt_stream):
+                k_fp8, k_scale = act_quant(key, self.block_size, self.scale_fmt)
+            current_stream.wait_stream(self.alt_stream)
+        else:
+            q_fp8, q_scale = act_quant(query, self.block_size, self.scale_fmt)
+            k_fp8, k_scale = act_quant(key, self.block_size, self.scale_fmt)
+
+        # k_fp8: (seq_len, head_dim) fp8_e4m3fn
+        # k_buffer: (num_total_tokens + page_size, head_dim) fp8_e4m3fn
+        # k_scale: (seq_len, head_dim // block_size = 1) fp8_e4m3fn
+        # k_scale_cache: (num_total_tokens + page_size, head_dim // block_size = 1) fp8_e4m3fn
+        if not forward_batch.out_cache_loc.is_contiguous():
+            forward_batch.out_cache_loc = forward_batch.out_cache_loc.contiguous()
+        forward_batch.token_to_kv_pool.set_index_k_and_scale_buffer(
+            layer_id=layer_id,
+            loc=forward_batch.out_cache_loc,
+            index_k=k_fp8,
+            index_k_scale=k_scale,
+        )
+
+        weights = self._get_logits_head_gate(x, q_scale)
+
+        if is_cuda():
+            assert forward_batch.seq_lens_cpu is not None
+            if len(forward_batch.seq_lens_cpu) == 0:
+                # this seems b/c max-pad, no worries?
+                # if x.shape[0] != 0:
+                #     print(
+                #         "HACK: seq_lens empty but x not empty, hackily return all-invalid topk_result"
+                #     )
+                return torch.full(
+                    (x.shape[0], self.index_topk), -1, dtype=torch.int, device="cuda"
+                )
+
+            if (
+                forward_batch.forward_mode.is_decode_or_idle()
+                or forward_batch.forward_mode.is_target_verify()
+            ):
+                topk_result = self._get_topk_paged(
+                    forward_batch, layer_id, q_fp8, weights, metadata
+                )
+            else:
+                topk_result = self._get_topk_ragged(
+                    forward_batch, layer_id, q_fp8, weights, metadata
+                )
+        else:
+            topk_result = self.forward_indexer(
+                q_fp8.contiguous(),
+                weights,
+                forward_batch,
+                topk=self.index_topk,
+                layer_id=layer_id,
+            )
+        return topk_result
+
+    def forward_npu(
+        self,
+        x: torch.Tensor,
+        q_lora: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        layer_id: int,
+    ) -> torch.Tensor:
+        import custom_ops  # noqa: F401
+        import torch_npu
+
+        from sglang.srt.layers.dp_attention import (
+            get_attention_tp_rank,
+            get_attention_tp_size,
+        )
+        from sglang.srt.utils import get_bool_env_var
+
+        if forward_batch.attn_backend.forward_metadata.seq_lens_cpu_int is None:
+            actual_seq_lengths_kv = forward_batch.attn_backend.forward_metadata.seq_lens
+        else:
+            actual_seq_lengths_kv = (
+                forward_batch.attn_backend.forward_metadata.seq_lens_cpu_int
+            )
+        enable_index_cp = (
+            get_bool_env_var("SGLANG_USE_AG_AFTER_QLORA") and layer_id >= 4
+        )
+        is_prefill = forward_batch.forward_mode.is_extend()
+
+        attention_tp_rank = get_attention_tp_rank()
+        attention_tp_size = get_attention_tp_size()
+
+        cos_sin = self.rotary_emb.cos_sin_cache[positions]
+        cos, sin = cos_sin.chunk(2, dim=-1)
+        cos = cos.repeat(1, 2).view(-1, 1, 1, self.rope_head_dim)
+        sin = sin.repeat(1, 2).view(-1, 1, 1, self.rope_head_dim)
+        if is_prefill and enable_index_cp:
+            slice_length = cos.shape[0] // attention_tp_size
+            cos = cos[
+                slice_length
+                * attention_tp_rank : slice_length
+                * (attention_tp_rank + 1)
+            ]
+            sin = sin[
+                slice_length
+                * attention_tp_rank : slice_length
+                * (attention_tp_rank + 1)
+            ]
+
+        slot_mapping = forward_batch.out_cache_loc
+        block_table = forward_batch.attn_backend.forward_metadata.block_tables
+
+        bs = x.shape[0]
+
+        q = self.wq_b(q_lora)[0]  # [bs, 1536] @ [1536, 64 * 128] = [bs, 64 * 128]
+        q = q.view(bs, self.n_heads, self.head_dim)  # [bs, 64, 128]
+        q_pe, q_nope = torch.split(
+            q,
+            [self.rope_head_dim, self.head_dim - self.rope_head_dim],
+            dim=-1,
+        )  # [bs, 64, 64 + 64]
+
+        q_pe = q_pe.view(bs, self.n_heads, 1, self.rope_head_dim)
+        q_pe = torch_npu.npu_interleave_rope(q_pe, cos, sin).view(
+            bs, self.n_heads, self.rope_head_dim
+        )  # [bs, n, d]
+        q = torch.cat([q_pe, q_nope], dim=-1)
+
+        k_proj = self.wk(x)[0]  # [b, s, 7168] @ [7168, 128] = [b, s, 128]
+        k = self.k_norm(k_proj)
+        k_pe, k_nope = torch.split(
+            k,
+            [self.rope_head_dim, self.head_dim - self.rope_head_dim],
+            dim=-1,
+        )  # [bs, 64 + 64]
+
+        k_pe = k_pe.view(-1, 1, 1, self.rope_head_dim)
+        k_pe = torch_npu.npu_interleave_rope(k_pe, cos, sin).view(
+            bs, 1, self.rope_head_dim
+        )  # [bs, 1, d]
+        k = torch.cat([k_pe, k_nope.unsqueeze(1)], dim=-1)  # [bs, 1, 128]
+
+        if is_prefill and enable_index_cp:
+            k, local_k = (
+                torch.empty(
+                    (k.shape[0] * attention_tp_size, k.shape[1], k.shape[2]),
+                    dtype=k.dtype,
+                    device=k.device,
+                ),
+                k,
+            )
+            get_attention_tp_group().all_gather_into_tensor(k, local_k)
+
+        forward_batch.token_to_kv_pool.set_index_k_buffer(layer_id, slot_mapping, k)
+
+        indexer_input = {}
+        if is_prefill:
+            actual_seq_lengths_kv = forward_batch.seq_lens.to(device=q.device)
+            actual_seq_lengths_q = forward_batch.seq_lens.cumsum(dim=0).to(
+                device=q.device
+            )
+            if enable_index_cp:
+                actual_seq_lengths_q -= bs * attention_tp_rank
+                actual_seq_lengths_q = torch.max(
+                    actual_seq_lengths_q,
+                    torch.zeros_like(actual_seq_lengths_q).to(
+                        device=actual_seq_lengths_q.device
+                    ),
+                )
+                actual_seq_lengths_q = torch.min(
+                    actual_seq_lengths_q,
+                    torch.full(actual_seq_lengths_q.shape, bs).to(
+                        device=actual_seq_lengths_q.device
+                    ),
+                )
+
+        else:
+            if forward_batch.attn_backend.forward_metadata.actual_seq_lengths_q is None:
+                actual_seq_lengths_q = torch.tensor(
+                    [1 + i * 1 for i in range(bs)], dtype=torch.int32, device=k.device
+                )
+            else:
+                actual_seq_lengths_q = (
+                    forward_batch.attn_backend.forward_metadata.actual_seq_lengths_q
+                )
+
+        past_key_states = forward_batch.token_to_kv_pool.get_index_k_buffer(layer_id)
+
+        x = x.view(-1, self.hidden_size)
+        weights = self.weights_proj(x)[0]
+        block_table = (
+            block_table[: actual_seq_lengths_q.size()[0]] if is_prefill else block_table
+        )
+
+        topk_indices = torch.ops.custom.npu_lightning_indexer(
+            query=q.view(-1, self.n_heads, self.head_dim),
+            key=past_key_states,
+            weights=weights,
+            actual_seq_lengths_query=actual_seq_lengths_q.to(torch.int32),
+            actual_seq_lengths_key=actual_seq_lengths_kv.to(k.device).to(torch.int32),
+            block_table=block_table,
+            layout_query="TND",
+            layout_key="PA_BSND",
+            sparse_count=self.index_topk,
+            sparse_mode=3,
+        )
+
+        if is_prefill and enable_index_cp:
+            topk_indices, local_topk_indices = (
+                torch.empty(
+                    (
+                        topk_indices.shape[0] * attention_tp_size,
+                        topk_indices.shape[1],
+                        topk_indices.shape[2],
+                    ),
+                    dtype=topk_indices.dtype,
+                    device=topk_indices.device,
+                ),
+                topk_indices,
+            )
+            get_attention_tp_group().all_gather_into_tensor(
+                topk_indices, local_topk_indices
+            )
+
+        return topk_indices