sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (482)
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
sglang/srt/lora/mem_pool.py CHANGED
@@ -4,7 +4,7 @@ from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple, Union
  import torch

  from sglang.srt.distributed import divide
- from sglang.srt.hf_transformers_utils import AutoConfig
+ from sglang.srt.lora.eviction_policy import get_eviction_policy
  from sglang.srt.lora.layers import BaseLayerWithLoRA
  from sglang.srt.lora.lora import LoRAAdapter
  from sglang.srt.lora.lora_config import LoRAConfig
@@ -17,6 +17,7 @@ from sglang.srt.lora.utils import (
      get_stacked_multiply,
      get_target_module_name,
  )
+ from sglang.srt.utils.hf_transformers_utils import AutoConfig

  logger = logging.getLogger(__name__)

@@ -54,6 +55,7 @@ class LoRAMemoryPool:
          max_lora_rank: int,
          target_modules: Set[str],
          base_model: torch.nn.Module,
+         eviction_policy: str,
      ):
          self.base_hf_config: AutoConfig = base_hf_config
          self.num_layer: int = base_hf_config.num_hidden_layers
@@ -64,6 +66,9 @@ class LoRAMemoryPool:
          self.max_lora_rank: int = max_lora_rank
          self.target_modules: Set[str] = target_modules

+         # Initialize eviction policy
+         self.eviction_policy = get_eviction_policy(eviction_policy)
+
          # Both A_buffer and B_buffer maps lora weight names to its buffer space.
          # A_buffer contains num_layer number of row-major tensors with shape
          # (max_loras_per_batch, stacked_num * max_lora_dim, input_dim)
@@ -189,31 +194,50 @@ class LoRAMemoryPool:
          lora_refs: Dict[str, LoRARef],
      ):
          def get_available_buffer_slot():
+             # 1. Prioritize empty slots
              for buffer_id in range(self.max_loras_per_batch):
-                 # Prioritize empty slots
                  if self.buffer_id_to_uid[buffer_id] == EMPTY_SLOT:
                      return buffer_id

+             # 2. Memory pool is full, need to evict using policy
+             candidates = set()
+
              for buffer_id in range(self.max_loras_per_batch):
                  uid = self.buffer_id_to_uid[buffer_id]

-                 # Evict unneeded lora
-                 if uid not in cur_uids:
-                     # Skip pinned LoRAs
-                     # TODO (lifuhuang): we might consider supporting pinning base model (uid == None) in the future.
-                     if uid is not None:
-                         lora_ref = lora_refs.get(uid)
-                         if lora_ref is not None and lora_ref.pinned:
-                             continue
-
-                     self.uid_to_buffer_id.pop(uid)
-                     logger.debug(f"Evicting LoRA {uid} from buffer slot {buffer_id}.")
-                     self.buffer_id_to_uid[buffer_id] = EMPTY_SLOT
-                     return buffer_id
+                 # Skip if this adapter is needed by current batch
+                 # TODO (lifuhuang): we might consider supporting pinning base model (uid == None) in the future.
+                 if uid in cur_uids:
+                     continue
+
+                 # Skip if this adapter is pinned (base model cannot be pinned, so can be evicted)
+                 if uid is not None:
+                     lora_ref = lora_refs.get(uid)
+                     if lora_ref and lora_ref.pinned:
+                         continue
+                 candidates.add(uid)

-             raise ValueError(
-                 "No available buffer slots found. Please ensure the number of active loras is less than max_loras_per_batch."
+             if not candidates:
+                 raise ValueError(
+                     "No available buffer slots found. Please ensure the number of active (pinned) loras is less than max_loras_per_batch."
+                 )
+
+             # Select victim using eviction policy
+             victim_uid = self.eviction_policy.select_victim(candidates)
+
+             # Evict the selected victim
+             victim_buffer_id = self.uid_to_buffer_id[victim_uid]
+             self.uid_to_buffer_id.pop(victim_uid)
+             self.eviction_policy.remove(victim_uid)
+             self.buffer_id_to_uid[victim_buffer_id] = EMPTY_SLOT
+             logger.debug(
+                 f"Evicting LoRA {victim_uid} from buffer slot {victim_buffer_id}."
              )
+             return victim_buffer_id
+
+         # Mark all adapters in current batch as used (for LRU tracking)
+         for uid in cur_uids:
+             self.eviction_policy.mark_used(uid)

          for uid in cur_uids:
              if uid not in self.uid_to_buffer_id:
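
Note: the rewritten get_available_buffer_slot above no longer evicts the first unneeded slot it finds; it collects evictable adapters into `candidates` and defers the choice to a policy object returned by `get_eviction_policy`, calling only `select_victim`, `mark_used`, and `remove` on it. The policies themselves live in the new sglang/srt/lora/eviction_policy.py (+139 lines, not shown in this excerpt), so the following is only a minimal LRU-style sketch of an object satisfying that call surface; the class name and internals are illustrative assumptions, not the shipped implementation.

# Hypothetical sketch of an eviction-policy object compatible with the calls made
# in the diff above (select_victim / mark_used / remove). The real implementation
# is in sglang/srt/lora/eviction_policy.py and may differ.
from collections import OrderedDict
from typing import Optional, Set


class SketchLRUEvictionPolicy:
    """Least-recently-used victim selection over LoRA adapter uids."""

    def __init__(self) -> None:
        # Maps uid -> None; insertion order encodes recency (oldest first).
        self._recency: "OrderedDict[Optional[str], None]" = OrderedDict()

    def mark_used(self, uid: Optional[str]) -> None:
        # Move the adapter to the most-recently-used position.
        self._recency.pop(uid, None)
        self._recency[uid] = None

    def remove(self, uid: Optional[str]) -> None:
        # Forget an evicted adapter so it cannot be selected again.
        self._recency.pop(uid, None)

    def select_victim(self, candidates: Set[Optional[str]]) -> Optional[str]:
        # Pick the least recently used uid among the evictable candidates;
        # uids never marked as used are treated as oldest.
        for uid in self._recency:
            if uid in candidates:
                return uid
        return next(iter(candidates))


if __name__ == "__main__":
    policy = SketchLRUEvictionPolicy()
    for uid in ("lora-a", "lora-b", "lora-c"):
        policy.mark_used(uid)
    policy.mark_used("lora-a")  # refresh lora-a
    print(policy.select_victim({"lora-a", "lora-b"}))  # -> lora-b

Since the pool marks every adapter in the incoming batch as used and only asks for a victim when no empty slot remains, a policy shaped like this sketch degenerates to least-recently-batched eviction.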
sglang/srt/lora/triton_ops/__init__.py CHANGED
@@ -1,3 +1,5 @@
+ from .chunked_sgmv_expand import chunked_sgmv_lora_expand_forward
+ from .chunked_sgmv_shrink import chunked_sgmv_lora_shrink_forward
  from .gate_up_lora_b import gate_up_lora_b_fwd
  from .qkv_lora_b import qkv_lora_b_fwd
  from .sgemm_lora_a import sgemm_lora_a_fwd
@@ -8,4 +10,6 @@ __all__ = [
      "qkv_lora_b_fwd",
      "sgemm_lora_a_fwd",
      "sgemm_lora_b_fwd",
+     "chunked_sgmv_lora_shrink_forward",
+     "chunked_sgmv_lora_expand_forward",
  ]
sglang/srt/lora/triton_ops/chunked_sgmv_expand.py ADDED
@@ -0,0 +1,214 @@
+ from typing import Optional
+
+ import torch
+ import triton
+ import triton.language as tl
+
+ from sglang.srt.lora.utils import LoRABatchInfo
+ from sglang.srt.utils import cached_triton_kernel
+
+
+ @cached_triton_kernel(lambda _, kwargs: (kwargs["NUM_SLICES"], kwargs["BLOCK_M"]))
+ @triton.jit(do_not_specialize=["num_segs"])
+ def _chunked_lora_expand_kernel(
+     # Pointers to matrices
+     x,
+     weights,
+     output,
+     # Information on sequence lengths and weight id
+     seg_indptr,
+     weight_indices,
+     lora_ranks,
+     permutation,
+     num_segs,
+     # For fused output scaling
+     scalings,
+     # Offsets of q/k/v slice on output dimension
+     slice_offsets,
+     # Meta parameters
+     NUM_SLICES: tl.constexpr,
+     OUTPUT_DIM: tl.constexpr,
+     MAX_RANK: tl.constexpr,  # K = R
+     BLOCK_M: tl.constexpr,
+     BLOCK_N: tl.constexpr,
+     BLOCK_K: tl.constexpr,
+ ):
+     """
+     Computes a chunked SGMV for LoRA expand operations.
+
+     When a sequence's rank is 0, the kernel is essentially a no-op, following
+     the convention in pytorch where the product of two matrices of shape (m, 0)
+     and (0, n) is an all-zero matrix of shape (m, n).
+
+     Args:
+         x (Tensor): The input tensor, which is the result of the LoRA A projection.
+             Shape: (s, num_slices * K), where s is the sum of all sequence lengths in the
+             batch and K is the maximum LoRA rank.
+         weights (Tensor): The LoRA B weights for all adapters.
+             Shape: (num_lora, output_dim, K).
+         output (Tensor): The output tensor where the result is stored.
+             Shape: (s, output_dim).
+     """
+     tl.static_assert(NUM_SLICES <= 3)
+
+     x_stride_0: tl.constexpr = NUM_SLICES * MAX_RANK
+     x_stride_1: tl.constexpr = 1
+
+     w_stride_0: tl.constexpr = OUTPUT_DIM * MAX_RANK
+     w_stride_1: tl.constexpr = MAX_RANK
+     w_stride_2: tl.constexpr = 1
+
+     output_stride_0: tl.constexpr = OUTPUT_DIM
+     output_stride_1: tl.constexpr = 1
+
+     pid_s = tl.program_id(axis=2)
+     if pid_s >= num_segs:
+         return
+
+     # Current block computes sequence with batch_id,
+     # which starts from row seg_start of x with length seg_len.
+     # qkv_id decides which of q,k,v to compute (0: q, 1: k, 2: v)
+     w_index = tl.load(weight_indices + pid_s)
+     cur_rank = tl.load(lora_ranks + w_index)
+
+     # If rank is 0, this kernel is a no-op.
+     if cur_rank == 0:
+         return
+
+     seg_start = tl.load(seg_indptr + pid_s)
+     seg_end = tl.load(seg_indptr + pid_s + 1)
+
+     slice_id = tl.program_id(axis=1)
+     slice_start = tl.load(slice_offsets + slice_id)
+     slice_end = tl.load(slice_offsets + slice_id + 1)
+
+     scaling = tl.load(scalings + w_index)
+     # Adjust K (rank) according to the specific LoRA adapter
+     cur_rank = tl.minimum(MAX_RANK, cur_rank)
+
+     # Map logical sequence index to physical index
+     s_offset_logical = tl.arange(0, BLOCK_M) + seg_start
+     s_offset_physical = tl.load(
+         permutation + s_offset_logical, mask=s_offset_logical < seg_end
+     )
+
+     # Create pointers for the first block of x and weights[batch_id][n_start: n_end][:]
+     # The pointers will be advanced as we move in the K direction
+     # and accumulate
+     pid_n = tl.program_id(axis=0)
+     n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + slice_start
+     k_offset = tl.arange(0, BLOCK_K)
+
+     x_ptrs = (
+         x
+         + slice_id * cur_rank * x_stride_1
+         + (s_offset_physical[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1)
+     )
+     w_ptrs = (weights + w_index * w_stride_0) + (
+         k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
+     )
+
+     # Iterate to compute the block in output matrix
+     partial_sum = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+     for k in range(0, tl.cdiv(cur_rank, BLOCK_K)):
+         x_tile = tl.load(
+             x_ptrs,
+             mask=(s_offset_logical[:, None] < seg_end)
+             & (k_offset[None, :] < cur_rank - k * BLOCK_K),
+             other=0.0,
+         )
+         w_tile = tl.load(
+             w_ptrs,
+             mask=(k_offset[:, None] < cur_rank - k * BLOCK_K)
+             & (n_offset[None, :] < slice_end),
+             other=0.0,
+         )
+         partial_sum += tl.dot(x_tile, w_tile)
+
+         x_ptrs += BLOCK_K * x_stride_1
+         w_ptrs += BLOCK_K * w_stride_2
+
+     # Store result to output matrix
+     partial_sum *= scaling
+     partial_sum = partial_sum.to(x.dtype.element_ty)
+     output_ptr = output + (
+         s_offset_physical[:, None] * output_stride_0
+         + n_offset[None, :] * output_stride_1
+     )
+     output_mask = (s_offset_logical[:, None] < seg_end) & (
+         n_offset[None, :] < slice_end
+     )
+     partial_sum += tl.load(output_ptr, mask=output_mask, other=0.0)
+     tl.store(output_ptr, partial_sum, mask=output_mask)
+
+
+ def chunked_sgmv_lora_expand_forward(
+     x: torch.Tensor,
+     weights: torch.Tensor,
+     batch_info: LoRABatchInfo,
+     slice_offsets: torch.Tensor,
+     max_slice_size: int,
+     base_output: Optional[torch.Tensor],
+ ) -> torch.Tensor:
+
+     # x: (s, slice_num * r)
+     # weights: (num_lora, output_dim, r)
+     # slice_offsets: boundaries for different slices in the output dimension
+     # output: (s, output_dim)
+
+     # Compute lora_output with shape (s, output_dim) as follows:
+     # For each slice i, accumulates:
+     # lora_output[:, slice_offsets[i]:slice_offsets[i+1]] += scaling * sgemm(x[:, i*cur_rank:(i+1)*cur_rank], weights[:, slice_offsets[i]:slice_offsets[i+1], :])
+
+     assert x.is_contiguous()
+     assert weights.is_contiguous()
+     assert len(x.shape) == 2
+     assert len(weights.shape) == 3
+
+     # Get dims
+     M = x.shape[0]
+     input_dim = x.shape[1]
+     OUTPUT_DIM = weights.shape[1]
+     MAX_RANK = weights.shape[2]
+     num_slices = len(slice_offsets) - 1
+     assert input_dim == num_slices * MAX_RANK
+
+     # TODO (lifuhuang): fine-tune per operation
+     BLOCK_M = batch_info.max_len
+     BLOCK_K = 16
+     BLOCK_N = 64
+
+     num_segments = batch_info.num_segments
+
+     grid = (
+         triton.cdiv(max_slice_size, BLOCK_N),
+         num_slices,  # number of slices in the input/output
+         batch_info.bs if batch_info.use_cuda_graph else num_segments,
+     )
+
+     if base_output is None:
+         output = torch.zeros((M, OUTPUT_DIM), device=x.device, dtype=x.dtype)
+     else:
+         output = base_output
+
+     _chunked_lora_expand_kernel[grid](
+         x=x,
+         weights=weights,
+         output=output,
+         seg_indptr=batch_info.seg_indptr,
+         weight_indices=batch_info.weight_indices,
+         lora_ranks=batch_info.lora_ranks,
+         permutation=batch_info.permutation,
+         num_segs=num_segments,
+         scalings=batch_info.scalings,
+         slice_offsets=slice_offsets,
+         # constants
+         NUM_SLICES=num_slices,
+         OUTPUT_DIM=OUTPUT_DIM,
+         MAX_RANK=MAX_RANK,
+         BLOCK_M=BLOCK_M,
+         BLOCK_N=BLOCK_N,
+         BLOCK_K=BLOCK_K,
+     )
+
+     return output
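
Note: to make the expand kernel's contract concrete, below is a small pure-PyTorch reference (not part of the package) of the accumulation its docstring describes: for each segment, each rank-r slice of the shrunken activations is multiplied by the matching output slice of that adapter's LoRA-B weight, scaled, and added into the base output. It assumes an identity `permutation` and plain Python lists for the segment metadata; all names in the sketch are illustrative.

import torch


def reference_lora_expand(
    x, weights, seg_indptr, weight_indices, lora_ranks, scalings, slice_offsets,
    base_output=None,
):
    # x:       (s, num_slices * max_rank), output of the shrink (LoRA A) step
    # weights: (num_lora, output_dim, max_rank), LoRA B weights
    # seg_indptr / weight_indices / lora_ranks / scalings / slice_offsets: plain lists
    s, output_dim = x.shape[0], weights.shape[1]
    num_slices = len(slice_offsets) - 1
    out = torch.zeros(s, output_dim, dtype=x.dtype) if base_output is None else base_output.clone()
    for seg, w_idx in enumerate(weight_indices):
        start, end = seg_indptr[seg], seg_indptr[seg + 1]
        rank = lora_ranks[w_idx]
        if rank == 0:
            continue  # rank-0 adapters contribute nothing (kernel early-returns)
        for i in range(num_slices):
            n0, n1 = slice_offsets[i], slice_offsets[i + 1]
            x_slice = x[start:end, i * rank : (i + 1) * rank]  # (len, rank)
            w_slice = weights[w_idx, n0:n1, :rank]             # (n1 - n0, rank)
            # Scaling applies only to the LoRA contribution, as in the kernel.
            out[start:end, n0:n1] += scalings[w_idx] * (x_slice @ w_slice.T)
    return out

Note that, as in the Triton kernel, the per-adapter slices of x are packed with stride equal to that adapter's actual rank, not the padded maximum rank.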
sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py ADDED
@@ -0,0 +1,176 @@
+ import torch
+ import triton
+ import triton.language as tl
+
+ from sglang.srt.lora.utils import LoRABatchInfo
+ from sglang.srt.utils import cached_triton_kernel
+
+
+ @cached_triton_kernel(
+     lambda _, kwargs: (kwargs["K"], kwargs["NUM_SLICES"], kwargs["BLOCK_M"])
+ )
+ @triton.jit(do_not_specialize=["num_segs"])
+ def _chunked_lora_shrink_kernel(
+     # Pointers to matrices
+     x,
+     weights,
+     output,
+     # Information on sequence lengths,ranks and weight id
+     seg_indptr,
+     weight_indices,
+     lora_ranks,
+     permutation,
+     num_segs,
+     # Meta parameters
+     N: tl.constexpr,  # num_slices * r
+     K: tl.constexpr,  # input_dim
+     NUM_SLICES: tl.constexpr,
+     BLOCK_M: tl.constexpr,
+     BLOCK_N: tl.constexpr,
+     BLOCK_K: tl.constexpr,
+ ):
+     """
+     Computes a chunked SGMV for LoRA shrink operations.
+
+     The kernel ensures that output[seg_start:seg_start + seg_len, :rank * num_slices]
+     stores the product of the input `x` and the LoRA weights for the corresponding
+     sequence. This implies that when rank is 0, the kernel is essentially a no-op,
+     as output[seg_start:seg_start + seg_len, :0] is trivially correct (empty).
+
+     Args:
+         x (torch.Tensor): The input activations tensor of shape `(s, K)`, where `s`
+             is the sum of all sequence lengths in the batch.
+         weights (torch.Tensor): The LoRA A weights for all available adapters,
+             with shape `(num_lora, N, K)` where N = num_slices * r.
+         output (torch.Tensor): The output tensor of shape `(s, N)`.
+     """
+     x_stride_1: tl.constexpr = 1
+     x_stride_0: tl.constexpr = K
+
+     w_stride_0: tl.constexpr = N * K
+     w_stride_1: tl.constexpr = K
+     w_stride_2: tl.constexpr = 1
+
+     output_stride_0: tl.constexpr = N
+     output_stride_1: tl.constexpr = 1
+
+     pid_s = tl.program_id(1)
+     if pid_s >= num_segs:
+         return
+
+     pid_n = tl.program_id(0)
+
+     # Current block computes sequence with batch_id,
+     # which starts from row seg_start of x with length seg_len
+     w_index = tl.load(weight_indices + pid_s)
+     rank = tl.load(lora_ranks + w_index)
+
+     # If rank is 0, this kernel becomes a no-op as the output is always trivially correct.
+     if rank == 0:
+         return
+
+     seg_start = tl.load(seg_indptr + pid_s)
+     seg_end = tl.load(seg_indptr + pid_s + 1)
+
+     # Adjust N dim according to the specific LoRA adapter
+     cur_n = tl.minimum(N, rank * NUM_SLICES)
+
+     # Map logical sequence index to physical index
+     s_offset_logical = tl.arange(0, BLOCK_M) + seg_start
+     s_offset_physical = tl.load(
+         permutation + s_offset_logical, mask=s_offset_logical < seg_end
+     )
+
+     n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
+     k_offset = tl.arange(0, BLOCK_K)
+     x_ptrs = x + (
+         s_offset_physical[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1
+     )
+     w_ptrs = (weights + w_index * w_stride_0) + (
+         k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
+     )
+
+     # Iterate to compute the block in output matrix
+     partial_sum = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+     for k in range(0, tl.cdiv(K, BLOCK_K)):
+         x_tile = tl.load(
+             x_ptrs,
+             mask=(s_offset_logical[:, None] < seg_end)
+             & (k_offset[None, :] < K - k * BLOCK_K),
+             other=0.0,
+         )
+         w_tile = tl.load(
+             w_ptrs,
+             mask=(k_offset[:, None] < K - k * BLOCK_K) & (n_offset[None, :] < cur_n),
+             other=0.0,
+         )
+         partial_sum += tl.dot(x_tile, w_tile)
+
+         x_ptrs += BLOCK_K * x_stride_1
+         w_ptrs += BLOCK_K * w_stride_2
+
+     # Store result to output matrix
+     partial_sum = partial_sum.to(x.dtype.element_ty)
+     output_ptr = output + (
+         s_offset_physical[:, None] * output_stride_0
+         + n_offset[None, :] * output_stride_1
+     )
+     output_mask = (s_offset_logical[:, None] < seg_end) & (n_offset[None, :] < cur_n)
+     tl.store(output_ptr, partial_sum, mask=output_mask)
+
+
+ def chunked_sgmv_lora_shrink_forward(
+     x: torch.Tensor,
+     weights: torch.Tensor,
+     batch_info: LoRABatchInfo,
+     num_slices: int,
+ ) -> torch.Tensor:
+     # x: (s, input_dim)
+     # weights: (num_lora, num_slices * r, input_dim)
+     # output: (s, num_slices * r)
+     # num_slices: qkv=3, gate_up=2, others=1
+     # when called with multiple slices, the weights.shape[-2] will be num_slices * r
+     # input_dim is much larger than r
+
+     assert x.is_contiguous()
+     assert weights.is_contiguous()
+     assert len(x.shape) == 2
+     assert len(weights.shape) == 3
+
+     # Block shapes
+     # TODO (lifuhuang): experiment with split-k
+     BLOCK_M = batch_info.max_len
+     BLOCK_N = 16
+     BLOCK_K = 256
+
+     S = x.shape[0]
+     N = weights.shape[1]
+     K = weights.shape[2]
+     assert x.shape[-1] == K
+
+     num_segments = batch_info.num_segments
+     grid = (
+         triton.cdiv(N, BLOCK_N),
+         batch_info.bs if batch_info.use_cuda_graph else num_segments,
+     )
+
+     output = torch.empty((S, N), device=x.device, dtype=x.dtype)
+     _chunked_lora_shrink_kernel[grid](
+         x=x,
+         weights=weights,
+         output=output,
+         seg_indptr=batch_info.seg_indptr,
+         weight_indices=batch_info.weight_indices,
+         lora_ranks=batch_info.lora_ranks,
+         permutation=batch_info.permutation,
+         num_segs=num_segments,
+         # constants
+         N=N,
+         K=K,
+         NUM_SLICES=num_slices,
+         BLOCK_M=BLOCK_M,
+         BLOCK_N=BLOCK_N,
+         BLOCK_K=BLOCK_K,
+     )
+
+     return output
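
Note: a matching pure-PyTorch reference for the shrink side (again illustrative, assuming an identity permutation and using a zero-initialized output instead of the kernel's uninitialized buffer), which writes only the first num_slices * rank columns of each segment:

import torch


def reference_lora_shrink(x, weights, seg_indptr, weight_indices, lora_ranks, num_slices):
    # x:       (s, K), input activations
    # weights: (num_lora, num_slices * max_rank, K), stacked LoRA A weights
    # Returns: (s, num_slices * max_rank); only the first num_slices * rank columns
    # of each segment are meaningful, mirroring the kernel's docstring.
    s, n = x.shape[0], weights.shape[1]
    out = torch.zeros(s, n, dtype=x.dtype)
    for seg, w_idx in enumerate(weight_indices):
        start, end = seg_indptr[seg], seg_indptr[seg + 1]
        rank = lora_ranks[w_idx]
        if rank == 0:
            continue  # kernel early-returns for rank-0 adapters
        cur_n = min(n, num_slices * rank)
        # (len, K) @ (K, cur_n) -> (len, cur_n)
        out[start:end, :cur_n] = x[start:end] @ weights[w_idx, :cur_n, :].T
    return out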
sglang/srt/lora/utils.py CHANGED
@@ -5,7 +5,7 @@ from typing import Iterable, Optional, Set, Tuple

  import torch

- from sglang.srt.hf_transformers_utils import AutoConfig
+ from sglang.srt.utils.hf_transformers_utils import AutoConfig


  @dataclass
@@ -19,6 +19,9 @@ class LoRABatchInfo:
      # Number of segments. For triton backend, it is equal to batch size.
      num_segments: int

+     # Maximum segment length of current batch
+     max_len: int
+
      # Indice pointers of each segment in shape (num_segments + 1, )
      seg_indptr: torch.Tensor

@@ -34,9 +37,6 @@ class LoRABatchInfo:
      # Lengths of each segments in shape (num_segments,)
      seg_lens: Optional[torch.Tensor]

-     # Maximum segment length of current batch
-     max_len: Optional[int]
-
      # The logical (re)ordering of input rows (tokens), in shape (num_tokens,)
      permutation: Optional[torch.Tensor]

@@ -98,6 +98,7 @@ def get_normalized_target_modules(
  ) -> set[str]:
      """
      Mapping a list of target module name to names of the normalized LoRA weights.
+     Handles both base module names (e.g., "gate_proj") and prefixed module names (e.g., "feed_forward.gate_proj").
      """
      params_mapping = {
          "q_proj": "qkv_proj",
@@ -109,7 +110,8 @@ def get_normalized_target_modules(

      result = set()
      for name in target_modules:
-         normalized_name = params_mapping.get(name, name)
+         base_name = name.split(".")[-1]
+         normalized_name = params_mapping.get(base_name, base_name)
          result.add(normalized_name)
      return result

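
Note: with the change above, target-module names are normalized on their final path component, so prefixed names coming from adapter configs collapse to the same stacked weight names as bare ones. A rough illustration follows; the full params_mapping dict is elided in the hunk, so every entry other than "q_proj" shown below is an assumption.

# Illustrative only: mimics the normalization logic after this change.
params_mapping = {
    "q_proj": "qkv_proj",
    "k_proj": "qkv_proj",      # assumed entry, not shown in the hunk
    "v_proj": "qkv_proj",      # assumed entry, not shown in the hunk
    "gate_proj": "gate_up_proj",  # assumed entry, not shown in the hunk
    "up_proj": "gate_up_proj",    # assumed entry, not shown in the hunk
}


def normalize(target_modules):
    result = set()
    for name in target_modules:
        base_name = name.split(".")[-1]  # new: drop "feed_forward." and similar prefixes
        result.add(params_mapping.get(base_name, base_name))
    return result


# Before the change, the prefixed name would fall through unmapped;
# now both spellings normalize to the same stacked module name.
print(normalize(["feed_forward.gate_proj", "q_proj"]))  # {'gate_up_proj', 'qkv_proj'}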