sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (482)
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,348 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+
5
+ from sglang.srt.lora.backend.base_backend import BaseLoRABackend
6
+ from sglang.srt.lora.triton_ops import (
7
+ chunked_sgmv_lora_expand_forward,
8
+ chunked_sgmv_lora_shrink_forward,
9
+ )
10
+ from sglang.srt.lora.utils import LoRABatchInfo
11
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
12
+ from sglang.srt.server_args import ServerArgs
13
+
14
# Floor for the segment chunk size: `_determine_chunk_size` never returns a
# value below this, and returns exactly this when the configured maximum
# chunk size is not larger than it.
MIN_CHUNK_SIZE: int = 16
15
+
16
+
17
+ class ChunkedSgmvLoRABackend(BaseLoRABackend):
18
+ """
19
+ Chunked LoRA backend using segmented matrix-vector multiplication.
20
+
21
+ This backend is largely based on the SGMV (Segmented Gather Matrix-Vector multiplication) algorithm
22
+ introduced in the Punica paper (https://arxiv.org/pdf/2310.18547). One main variation made here is to
23
+ segment the input sequences into fixed-size chunks, which reduces excessive kernel launches especially
24
+ when the LoRA distribution is skewed.
25
+ """
26
+
27
+ name = "csgmv"
28
+
29
def __init__(
    self,
    max_loras_per_batch: int,
    device: torch.device,
    server_args: ServerArgs,
):
    """Set up the chunked-SGMV LoRA backend.

    ``max_loras_per_batch`` and ``device`` are handed unchanged to the base
    class constructor. ``server_args.max_lora_chunk_size`` is stored on the
    instance as ``max_chunk_size``, the cap consulted by
    ``_determine_chunk_size``.
    """
    super().__init__(max_loras_per_batch, device)
    configured_max_chunk = server_args.max_lora_chunk_size
    self.max_chunk_size = configured_max_chunk
37
+
38
def run_lora_a_sgemm(
    self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
) -> torch.Tensor:
    """Run the LoRA A (shrink) step for the current batch.

    Delegates to ``chunked_sgmv_lora_shrink_forward`` with ``self.batch_info``
    and a single weight slice (``num_slices=1``). Extra positional and keyword
    arguments are accepted but not used.
    """
    shrink_args = {
        "x": x,
        "weights": weights,
        "batch_info": self.batch_info,
        "num_slices": 1,
    }
    return chunked_sgmv_lora_shrink_forward(**shrink_args)
47
+
48
def run_lora_b_sgemm(
    self,
    x: torch.Tensor,
    weights: torch.Tensor,
    output_offset: torch.Tensor,
    base_output: Optional[torch.Tensor] = None,
    *args,
    **kwargs,
) -> torch.Tensor:
    """Run the LoRA B (expand) step for a simple, single-slice projection.

    Args:
        x: Input to the expand kernel (presumably the LoRA A output — confirm
            against callers).
        weights: LoRA B weights; ``weights.shape[-2]`` is taken as the output
            dimension.
        output_offset: Passed to the kernel as ``slice_offsets``.
        base_output: Optional tensor passed through to the kernel as
            ``base_output``; ``None`` when not supplied.
        *args, **kwargs: Accepted but not used.

    Returns:
        The result of ``chunked_sgmv_lora_expand_forward``.
    """
    # A plain (non-fused) LoRA B has one output slice spanning [0, output_dim],
    # so the widest slice is the whole output dimension.
    return chunked_sgmv_lora_expand_forward(
        x=x,
        weights=weights,
        batch_info=self.batch_info,
        slice_offsets=output_offset,
        max_slice_size=weights.shape[-2],
        base_output=base_output,
    )
68
+
69
def run_qkv_lora(
    self,
    x: torch.Tensor,
    qkv_lora_a: torch.Tensor,
    qkv_lora_b: torch.Tensor,
    output_offset: torch.Tensor,
    max_qkv_out_dim: int,
    base_output: Optional[torch.Tensor] = None,
    *args,
    **kwargs,
) -> torch.Tensor:
    """Apply the stacked q/k/v LoRA: one 3-slice shrink, then one expand.

    Expected shapes:
        x: (s, input_dim)
        qkv_lora_a: (num_lora, 3 * r, input_dim)
        qkv_lora_b: (num_lora, output_dim_q + 2 * output_dim_kv, r)

    Args:
        x: Input activations.
        qkv_lora_a: Stacked LoRA A weights, shrunk with ``num_slices=3``.
        qkv_lora_b: Stacked LoRA B weights; must be a single ``torch.Tensor``.
        output_offset: Passed to the expand kernel as ``slice_offsets``.
        max_qkv_out_dim: Passed to the expand kernel as ``max_slice_size``.
        base_output: Optional tensor passed through to the expand kernel as
            ``base_output``; ``None`` when not supplied.
        *args, **kwargs: Accepted but not used.

    Returns:
        The result of ``chunked_sgmv_lora_expand_forward``.
    """
    # This backend only handles qkv_lora_b given as one stacked tensor.
    assert isinstance(qkv_lora_b, torch.Tensor)

    # lora_a_output holds the three rank-r projections side by side.
    lora_a_output = chunked_sgmv_lora_shrink_forward(
        x=x,
        weights=qkv_lora_a,
        batch_info=self.batch_info,
        num_slices=3,
    )
    return chunked_sgmv_lora_expand_forward(
        x=lora_a_output,
        weights=qkv_lora_b,
        batch_info=self.batch_info,
        slice_offsets=output_offset,
        max_slice_size=max_qkv_out_dim,
        base_output=base_output,
    )
101
+
102
def run_gate_up_lora(
    self,
    x: torch.Tensor,
    gate_up_lora_a: torch.Tensor,
    gate_up_lora_b: torch.Tensor,
    output_offset: torch.Tensor,
    base_output: Optional[torch.Tensor] = None,
    *args,
    **kwargs,
) -> torch.Tensor:
    """Apply the stacked gate/up LoRA: one 2-slice shrink, then one expand.

    Expected shapes:
        x: (s, input_dim)
        gate_up_lora_a: (num_lora, 2 * r, input_dim)
        gate_up_lora_b: (num_lora, 2 * output_dim, r)

    Args:
        x: Input activations.
        gate_up_lora_a: Stacked LoRA A weights, shrunk with ``num_slices=2``.
        gate_up_lora_b: Stacked LoRA B weights; must be a single
            ``torch.Tensor``. Half of its second-to-last dimension is used as
            the per-slice output size.
        output_offset: Passed to the expand kernel as ``slice_offsets``.
        base_output: Optional tensor passed through to the expand kernel as
            ``base_output``; ``None`` when not supplied.
        *args, **kwargs: Accepted but not used.

    Returns:
        The result of ``chunked_sgmv_lora_expand_forward``.
    """
    # This backend only handles gate_up_lora_b given as one stacked tensor.
    assert isinstance(gate_up_lora_b, torch.Tensor)
    # gate and up share the stacked B dimension equally.
    output_dim = gate_up_lora_b.shape[-2] // 2

    # lora_a_output: (s, 2 * r)
    lora_a_output = chunked_sgmv_lora_shrink_forward(
        x=x,
        weights=gate_up_lora_a,
        batch_info=self.batch_info,
        num_slices=2,
    )
    return chunked_sgmv_lora_expand_forward(
        x=lora_a_output,
        weights=gate_up_lora_b,
        batch_info=self.batch_info,
        slice_offsets=output_offset,
        max_slice_size=output_dim,
        base_output=base_output,
    )
135
+
136
def _determine_chunk_size(self, forward_batch: ForwardBatch) -> int:
    """Pick the segment chunk size for a batch from its number of tokens.

    Extend batches are measured by ``forward_batch.extend_num_tokens``; all
    other forward modes by ``forward_batch.batch_size``. Larger batches get
    larger chunks: 128 for at least 256 tokens, 32 for at least 64 tokens,
    otherwise 16. The result is capped at ``self.max_chunk_size``. If
    ``self.max_chunk_size`` is not larger than ``MIN_CHUNK_SIZE``, then
    ``MIN_CHUNK_SIZE`` is returned regardless of batch size.

    Args:
        forward_batch: The batch whose token count drives the heuristic.

    Returns:
        The chunk size to use for this batch.
    """
    if self.max_chunk_size <= MIN_CHUNK_SIZE:
        return MIN_CHUNK_SIZE

    if forward_batch.forward_mode.is_extend():
        token_count = forward_batch.extend_num_tokens
    else:
        token_count = forward_batch.batch_size

    # (minimum token count, chunk size) pairs, checked from the largest down.
    for threshold, candidate in ((256, 128), (64, 32)):
        if token_count >= threshold:
            return min(self.max_chunk_size, candidate)
    return min(self.max_chunk_size, 16)
162
+
163
def prepare_lora_batch(
    self,
    forward_batch: ForwardBatch,
    weight_indices: list[int],
    lora_ranks: list[int],
    scalings: list[float],
    batch_info: Optional[LoRABatchInfo] = None,
):
    """Build or refresh ``self.batch_info`` for the chunked SGMV kernels.

    Tokens are reordered so that tokens sharing an adapter are contiguous. Each
    adapter group is then cut into segments of at most the chosen chunk size, and
    the resulting metadata is copied to ``self.device`` with non-blocking copies.

    Args:
        forward_batch: current batch (batch size, forward mode, sequence lengths).
        weight_indices: adapter slot index for each sequence.
        lora_ranks: rank for each adapter slot; copied into the first
            ``self.max_loras_per_batch`` entries of ``batch_info.lora_ranks``.
        scalings: scaling for each adapter slot; copied likewise into
            ``batch_info.scalings``.
        batch_info: optional preallocated LoRABatchInfo whose buffers are filled
            in place (presumably the CUDA-graph buffers; a freshly allocated one
            is marked ``use_cuda_graph=False``). When None, new device buffers are
            allocated.
    """
    chunk_size = self._determine_chunk_size(forward_batch)

    # Group tokens by adapter, then split each group into chunk-sized segments.
    permutation, reordered_weight_indices = ChunkedSgmvLoRABackend._get_permutation(
        seq_weight_indices=weight_indices,
        forward_batch=forward_batch,
    )
    seg_weight_indices, seg_indptr = self._get_segments_info(
        weights_reordered=reordered_weight_indices,
        chunk_size=chunk_size,
    )
    num_segments = len(seg_weight_indices)
    num_rows = len(permutation)

    # Pinned host staging tensors, copied to the device asynchronously below.
    lora_ranks_tensor = torch.tensor(
        lora_ranks, dtype=torch.int32, pin_memory=True, device="cpu"
    )
    scalings_tensor = torch.tensor(
        scalings, dtype=torch.float, pin_memory=True, device="cpu"
    )

    max_slots = self.max_loras_per_batch
    if batch_info is not None:
        # Reuse the caller's buffers; only the scalar metadata is updated here.
        batch_info.bs = forward_batch.batch_size
        batch_info.num_segments = num_segments
        batch_info.max_len = chunk_size
    else:
        dev = self.device
        batch_info = LoRABatchInfo(
            bs=forward_batch.batch_size,
            num_segments=num_segments,
            max_len=chunk_size,
            use_cuda_graph=False,
            seg_indptr=torch.empty((num_segments + 1,), dtype=torch.int32, device=dev),
            weight_indices=torch.empty((num_segments,), dtype=torch.int32, device=dev),
            lora_ranks=torch.empty((max_slots,), dtype=torch.int32, device=dev),
            scalings=torch.empty((max_slots,), dtype=torch.float, device=dev),
            permutation=torch.empty((num_rows,), dtype=torch.int32, device=dev),
            # Not used in chunked kernels
            seg_lens=None,
        )

    # Copy to device asynchronously, in the same order as the fields are listed.
    staged_copies = (
        (batch_info.lora_ranks[:max_slots], lora_ranks_tensor),
        (batch_info.scalings[:max_slots], scalings_tensor),
        (batch_info.weight_indices[:num_segments], seg_weight_indices),
        (batch_info.seg_indptr[: num_segments + 1], seg_indptr),
        (batch_info.permutation[:num_rows], permutation),
    )
    for device_view, host_tensor in staged_copies:
        device_view.copy_(host_tensor, non_blocking=True)

    self.batch_info = batch_info
234
+
235
@staticmethod
def _get_permutation(seq_weight_indices, forward_batch: ForwardBatch):
    """Compute the token order that groups tokens by their LoRA adapter.

    This is the "gather" step of chunked SGMV. Each sequence's adapter index is
    expanded to one entry per token, and a stable argsort of those entries yields
    a permutation that places all tokens of the same adapter next to each other.
    Within an adapter, the original token order is kept.

    Example:
        seq_weight_indices = [0, 1, 0], extend_seq_lens = [2, 1, 3]
        per-token adapters  -> [0, 0, 1, 0, 0, 0]
        permutation         -> [0, 1, 3, 4, 5, 2]
        weights_reordered   -> [0, 0, 0, 0, 0, 1]

    Args:
        seq_weight_indices: adapter index for each sequence.
        forward_batch (ForwardBatch): supplies ``extend_seq_lens_cpu`` in extend
            mode. In any other mode, each sequence is counted as a single token
            (``batch_size`` ones).

    Returns:
        tuple: ``(permutation, weights_reordered)``.
            ``permutation`` is a pinned int64 CPU tensor of token indices.
            ``weights_reordered`` is the per-token adapter indices in permuted
            (sorted) order.
    """
    with torch.device("cpu"):
        per_seq_adapter = torch.tensor(seq_weight_indices, dtype=torch.int32)

        if forward_batch.forward_mode.is_extend():
            tokens_per_seq = torch.tensor(
                forward_batch.extend_seq_lens_cpu,
                dtype=torch.int32,
            )
        else:
            tokens_per_seq = torch.ones(forward_batch.batch_size, dtype=torch.int32)

        per_token_adapter = torch.repeat_interleave(per_seq_adapter, tokens_per_seq)

        # Pinned so it can later be copied to the device with non_blocking=True.
        permutation = torch.empty(
            (len(per_token_adapter),), dtype=torch.long, pin_memory=True
        )
        torch.argsort(per_token_adapter, stable=True, out=permutation)

        return permutation, per_token_adapter[permutation]
284
+
285
def _get_segments_info(self, weights_reordered: torch.Tensor, chunk_size: int):
    """Split runs of identical adapter indices into segments of at most ``chunk_size`` tokens.

    Every segment covers tokens that share a single LoRA adapter, so the kernel
    can process segments independently. A run of ``n`` tokens for one adapter
    becomes ``n // chunk_size`` full segments, plus one shorter segment for any
    remainder.

    Example:
        weights_reordered = [0, 0, 0, 0, 0, 1], chunk_size = 3
        segments: tokens 0-2 (adapter 0), tokens 3-4 (adapter 0), token 5 (adapter 1)
        returns:  weight indices [0, 0, 1], seg_indptr [0, 3, 5, 6]

    Args:
        weights_reordered (torch.Tensor): per-token adapter indices in which equal
            values are contiguous (the second output of ``_get_permutation``).
        chunk_size (int): maximum number of tokens per segment.

    Returns:
        tuple: ``(weight_indices, seg_indptr)``, both pinned int32 CPU tensors.
            ``weight_indices`` holds the adapter index of each segment.
            ``seg_indptr`` holds the CSR-style cumulative segment boundaries.
    """
    with torch.device("cpu"):
        run_adapters, run_lengths = torch.unique_consecutive(
            weights_reordered, return_counts=True
        )

        segment_adapters = []
        segment_lengths = []
        for adapter, run_len in zip(run_adapters, run_lengths):
            full_chunks, remainder = divmod(run_len.item(), chunk_size)
            pieces = [chunk_size] * full_chunks
            if remainder:
                pieces.append(remainder)
            segment_adapters.extend([adapter.item()] * len(pieces))
            segment_lengths.extend(pieces)

        lengths = torch.tensor(segment_lengths, dtype=torch.int32)

        weight_indices = torch.tensor(
            segment_adapters, dtype=torch.int32, pin_memory=True
        )

        seg_indptr = torch.empty(
            (len(lengths) + 1,), dtype=torch.int32, pin_memory=True
        )
        seg_indptr[0] = 0
        seg_indptr[1:] = torch.cumsum(lengths, dim=0)

        return weight_indices, seg_indptr
@@ -16,7 +16,12 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
16
16
  class TritonLoRABackend(BaseLoRABackend):
17
17
  name = "triton"
18
18
 
19
- def __init__(self, max_loras_per_batch: int, device: torch.device):
19
+ def __init__(
20
+ self,
21
+ max_loras_per_batch: int,
22
+ device: torch.device,
23
+ **kwargs,
24
+ ):
20
25
  super().__init__(max_loras_per_batch, device)
21
26
 
22
27
  def run_lora_a_sgemm(
@@ -30,7 +35,7 @@ class TritonLoRABackend(BaseLoRABackend):
30
35
  weights: torch.Tensor,
31
36
  base_output: torch.Tensor = None,
32
37
  *args,
33
- **kwargs
38
+ **kwargs,
34
39
  ) -> torch.Tensor:
35
40
  return sgemm_lora_b_fwd(x, weights, self.batch_info, base_output)
36
41
 
@@ -43,7 +48,7 @@ class TritonLoRABackend(BaseLoRABackend):
43
48
  max_qkv_out_dim: int,
44
49
  base_output: torch.Tensor = None,
45
50
  *args,
46
- **kwargs
51
+ **kwargs,
47
52
  ) -> torch.Tensor:
48
53
 
49
54
  # x: (s, input_dim)
@@ -69,7 +74,7 @@ class TritonLoRABackend(BaseLoRABackend):
69
74
  gate_up_lora_b: torch.Tensor,
70
75
  base_output: torch.Tensor = None,
71
76
  *args,
72
- **kwargs
77
+ **kwargs,
73
78
  ) -> torch.Tensor:
74
79
 
75
80
  # x: (s, input_dim)
@@ -0,0 +1,139 @@
1
+ # Copyright 2023-2024 SGLang Team
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ==============================================================================
14
+
15
+ """
16
+ Eviction policies for LoRA adapter memory management.
17
+ """
18
+
19
+ import logging
20
+ import time
21
+ from abc import ABC, abstractmethod
22
+ from collections import OrderedDict
23
+ from typing import Optional, Set
24
+
25
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
26
+
27
+
28
class EvictionPolicy(ABC):
    """Interface for LoRA adapter eviction policies.

    A policy is told which adapters are used (``mark_used``) and which are no
    longer tracked (``remove``). When memory must be freed, it picks one adapter
    out of a set of candidates (``select_victim``). Adapter uids are optional
    strings, so ``None`` is an accepted value.
    """

    @abstractmethod
    def mark_used(self, uid: Optional[str]) -> None:
        """Record that the adapter identified by ``uid`` has been used."""

    @abstractmethod
    def select_victim(self, candidates: Set[Optional[str]]) -> Optional[str]:
        """Return the uid, chosen from ``candidates``, that should be evicted."""

    @abstractmethod
    def remove(self, uid: Optional[str]) -> None:
        """Stop tracking the adapter identified by ``uid``."""
45
+
46
+
47
class LRUEvictionPolicy(EvictionPolicy):
    """LRU eviction policy - evicts the least recently used adapter.

    Attributes:
        access_order: OrderedDict mapping uid -> ``time.monotonic()`` of its last
            use, ordered from least to most recently used. ``None`` (the base
            model) is never tracked.
        total_accesses: number of ``mark_used`` calls with a non-None uid.
        eviction_count: number of victims selected so far.
    """

    # Base model uid (currently None; a dedicated constant may replace it in
    # the future). It always has the lowest priority, i.e. it is evicted first.
    _BASE_MODEL_UID: Optional[str] = None

    def __init__(self):
        self.access_order = OrderedDict()  # key=uid, value=last_access_time
        self.total_accesses = 0
        self.eviction_count = 0

    def mark_used(self, uid: Optional[str]) -> None:
        """Make ``uid`` the most recently used entry; ``None`` is ignored."""
        if uid is None:
            return
        current_time = time.monotonic()
        # Assigning to an existing key keeps its position, so move it to the
        # end (most recent) explicitly; new keys are appended at the end anyway.
        self.access_order[uid] = current_time
        self.access_order.move_to_end(uid)
        self.total_accesses += 1
        logger.debug("LoRA %s marked as used at %s", uid, current_time)

    def select_victim(self, candidates: Set[Optional[str]]) -> Optional[str]:
        """Select the least recently used adapter from ``candidates``.

        Returns:
            The base-model uid (``None``) if it is a candidate. Otherwise, the
            tracked candidate whose last recorded use is oldest.

        Raises:
            AssertionError: if no candidate is tracked by this policy (including
                an empty ``candidates``). It is raised explicitly rather than via
                ``assert False``: under ``python -O`` the assert would be stripped
                and the implicit ``None`` return would be mistaken for the base
                model.
        """
        if self._BASE_MODEL_UID in candidates:
            logger.debug("Selected base model for eviction (LRU)")
            self.eviction_count += 1
            return self._BASE_MODEL_UID

        # access_order is oldest-first, so the first tracked candidate is the LRU one.
        for uid in self.access_order:
            if uid in candidates:
                logger.debug("Selected LoRA %s for eviction (LRU)", uid)
                self.eviction_count += 1
                return uid

        raise AssertionError(
            f"Failed to select LRU victim from candidates: {candidates}"
        )

    def remove(self, uid: Optional[str]) -> None:
        """Stop tracking ``uid``; ``None`` and unknown uids are ignored."""
        if uid is not None:
            self.access_order.pop(uid, None)
            logger.debug("Removed LoRA %s from LRU tracking", uid)
88
+
89
+
90
class FIFOEvictionPolicy(EvictionPolicy):
    """FIFO eviction policy - for backward compatibility.

    Evicts the candidate that ``mark_used`` first recorded. Later uses of an
    already-tracked uid do not change its position.

    Attributes:
        insertion_order: OrderedDict keyed by uid; only the key order matters.
            ``None`` (the base model) is never tracked.
        eviction_count: number of victims selected so far.
    """

    # Base model uid (currently None; a dedicated constant may replace it in
    # the future). It always has the lowest priority, i.e. it is evicted first.
    _BASE_MODEL_UID: Optional[str] = None

    def __init__(self):
        # key=uid; the value is unused, OrderedDict tracks insertion order.
        self.insertion_order = OrderedDict()
        self.eviction_count = 0

    def mark_used(self, uid: Optional[str]) -> None:
        """For FIFO, we only track insertion order (not access time)."""
        if uid is not None:
            # setdefault leaves an already-tracked uid in its original position.
            self.insertion_order.setdefault(uid, True)

    def select_victim(self, candidates: Set[Optional[str]]) -> Optional[str]:
        """Select the first inserted adapter from ``candidates``.

        Returns:
            The base-model uid (``None``) if it is a candidate. Otherwise, the
            tracked candidate that was inserted earliest.

        Raises:
            AssertionError: if no candidate is tracked by this policy (including
                an empty ``candidates``). It is raised explicitly rather than via
                ``assert False``: under ``python -O`` the assert would be stripped
                and the implicit ``None`` return would be mistaken for the base
                model.
        """
        if self._BASE_MODEL_UID in candidates:
            logger.debug("Selected base model for eviction (FIFO)")
            self.eviction_count += 1
            return self._BASE_MODEL_UID

        # insertion_order is oldest-first, so the first tracked candidate wins.
        for uid in self.insertion_order:
            if uid in candidates:
                logger.debug("Selected LoRA %s for eviction (FIFO)", uid)
                self.eviction_count += 1
                return uid

        raise AssertionError(
            f"Failed to select FIFO victim from candidates: {candidates}"
        )

    def remove(self, uid: Optional[str]) -> None:
        """Stop tracking ``uid``; ``None`` and unknown uids are ignored."""
        if uid is not None:
            self.insertion_order.pop(uid, None)
129
+
130
+
131
def get_eviction_policy(policy_name: str) -> EvictionPolicy:
    """Create a new eviction policy instance by name.

    Args:
        policy_name: ``"fifo"`` or ``"lru"`` (case-sensitive).

    Returns:
        A fresh ``FIFOEvictionPolicy`` or ``LRUEvictionPolicy``.

    Raises:
        ValueError: if ``policy_name`` is not a known policy.
    """
    registry = {
        "fifo": FIFOEvictionPolicy,
        "lru": LRUEvictionPolicy,
    }
    policy_cls = registry.get(policy_name)
    if policy_cls is None:
        raise ValueError(f"Unknown eviction policy: {policy_name}")
    return policy_cls()
sglang/srt/lora/lora.py CHANGED
@@ -26,16 +26,17 @@ import torch
26
26
  from torch import nn
27
27
 
28
28
  from sglang.srt.configs.load_config import LoadConfig
29
- from sglang.srt.hf_transformers_utils import AutoConfig
30
29
  from sglang.srt.lora.backend.base_backend import BaseLoRABackend
31
-
32
- # from sglang.srt.lora.backend.chunked_backend import ChunkedSgmvLoRABackend
30
+ from sglang.srt.lora.backend.chunked_backend import ChunkedSgmvLoRABackend
33
31
  from sglang.srt.lora.backend.triton_backend import TritonLoRABackend
34
32
  from sglang.srt.lora.lora_config import LoRAConfig
35
33
  from sglang.srt.model_loader.loader import DefaultModelLoader
34
+ from sglang.srt.utils.hf_transformers_utils import AutoConfig
36
35
 
37
36
  logger = logging.getLogger(__name__)
38
37
 
38
+ SUPPORTED_BACKENDS = (TritonLoRABackend, ChunkedSgmvLoRABackend)
39
+
39
40
 
40
41
  class LoRALayer(nn.Module):
41
42
  def __init__(self, config: LoRAConfig, base_hf_config: AutoConfig):
@@ -48,6 +49,7 @@ class LoRALayer(nn.Module):
48
49
 
49
50
 
50
51
  class LoRAAdapter(nn.Module):
52
+
51
53
  def __init__(
52
54
  self,
53
55
  uid: str,
@@ -159,8 +161,8 @@ class LoRAAdapter(nn.Module):
159
161
  gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
160
162
  if up_name not in weights:
161
163
  weights[up_name] = torch.zeros_like(weights[weight_name])
162
- assert isinstance(self.lora_backend, TritonLoRABackend), (
163
- f"LoRA weight initialization currently only supported for 'triton' backend. "
164
+ assert isinstance(self.lora_backend, SUPPORTED_BACKENDS), (
165
+ f"LoRA weight initialization currently only supported for LoRA backends: {', '.join(b.name for b in SUPPORTED_BACKENDS)}"
164
166
  f"Received backend: {self.lora_backend.name}. Please verify your backend configuration "
165
167
  f"or consider implementing custom initialization logic for other backends."
166
168
  )
@@ -16,12 +16,11 @@
16
16
  # and "Punica: Multi-Tenant LoRA Serving"
17
17
 
18
18
  import logging
19
- from typing import Dict, Iterable, List, Optional, Set, Tuple
19
+ from typing import Dict, Iterable, List, Optional
20
20
 
21
21
  import torch
22
22
 
23
23
  from sglang.srt.configs.load_config import LoadConfig
24
- from sglang.srt.hf_transformers_utils import AutoConfig
25
24
  from sglang.srt.lora.backend.base_backend import BaseLoRABackend, get_backend_from_name
26
25
  from sglang.srt.lora.layers import BaseLayerWithLoRA, get_lora_layer
27
26
  from sglang.srt.lora.lora import LoRAAdapter
@@ -35,9 +34,11 @@ from sglang.srt.lora.utils import (
35
34
  get_normalized_target_modules,
36
35
  get_target_module_name,
37
36
  )
38
- from sglang.srt.managers.io_struct import LoRAUpdateResult
37
+ from sglang.srt.managers.io_struct import LoRAUpdateOutput
39
38
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
39
+ from sglang.srt.server_args import ServerArgs
40
40
  from sglang.srt.utils import replace_submodule
41
+ from sglang.srt.utils.hf_transformers_utils import AutoConfig
41
42
 
42
43
  logger = logging.getLogger(__name__)
43
44
 
@@ -56,6 +57,7 @@ class LoRAManager:
56
57
  max_lora_rank: Optional[int] = None,
57
58
  target_modules: Optional[Iterable[str]] = None,
58
59
  lora_paths: Optional[List[LoRARef]] = None,
60
+ server_args: Optional[ServerArgs] = None,
59
61
  ):
60
62
  self.base_model: torch.nn.Module = base_model
61
63
  self.base_hf_config: AutoConfig = base_hf_config
@@ -66,12 +68,16 @@ class LoRAManager:
66
68
  self.tp_size: int = tp_size
67
69
  self.tp_rank: int = tp_rank
68
70
 
71
+ # Store eviction policy from server args
72
+ self.eviction_policy = server_args.lora_eviction_policy
73
+
69
74
  # LoRA backend for running sgemm kernels
70
75
  logger.info(f"Using {lora_backend} as backend of LoRA kernels.")
71
76
  backend_type = get_backend_from_name(lora_backend)
72
77
  self.lora_backend: BaseLoRABackend = backend_type(
73
78
  max_loras_per_batch=max_loras_per_batch,
74
79
  device=self.device,
80
+ server_args=server_args,
75
81
  )
76
82
 
77
83
  # Initialize mutable internal state of the LoRAManager.
@@ -104,8 +110,8 @@ class LoRAManager:
104
110
 
105
111
  def create_lora_update_result(
106
112
  self, success: bool, error_message: str = ""
107
- ) -> LoRAUpdateResult:
108
- return LoRAUpdateResult(
113
+ ) -> LoRAUpdateOutput:
114
+ return LoRAUpdateOutput(
109
115
  success=success,
110
116
  error_message=error_message,
111
117
  loaded_adapters={
@@ -114,7 +120,7 @@ class LoRAManager:
114
120
  },
115
121
  )
116
122
 
117
- def load_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateResult:
123
+ def load_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateOutput:
118
124
  """
119
125
  Load a single LoRA adapter from the specified path.
120
126
 
@@ -128,6 +134,16 @@ class LoRAManager:
128
134
  lora_ref.lora_id not in self.loras
129
135
  ), f"LoRA adapter with ID {lora_ref.lora_id} is already loaded. This should have been verified before request is sent to the backend."
130
136
 
137
+ if lora_ref.pinned and self.num_pinned_loras >= self.max_loras_per_batch - 1:
138
+ return self.create_lora_update_result(
139
+ success=False,
140
+ error_message=(
141
+ f"Already have {self.num_pinned_loras} pinned adapters, "
142
+ f"max allowed is {self.max_loras_per_batch - 1} (reserving 1 slot for dynamic use). "
143
+ f"Please unpin some adapters or increase max_loras_per_batch."
144
+ ),
145
+ )
146
+
131
147
  try:
132
148
  # load configs
133
149
  new_adapter = LoRAConfig(lora_ref.lora_path)
@@ -153,6 +169,15 @@ class LoRAManager:
153
169
  Validate if an adapter can be loaded into the current LoRA memory pool and generate error if it is incompatible.
154
170
  """
155
171
 
172
+ # Check if this LoRA adapter is already loaded
173
+ if any(
174
+ lora_ref.lora_name == existing_lora_ref.lora_name
175
+ for existing_lora_ref in self.lora_refs.values()
176
+ ):
177
+ raise ValueError(
178
+ f"Failed to load LoRA adapter {lora_ref.lora_name} because it is already loaded"
179
+ )
180
+
156
181
  # Check if the LoRA adapter shape is compatible with the current LoRA memory pool configuration.
157
182
  memory_pool = getattr(self, "memory_pool", None)
158
183
  incompatible = memory_pool and not memory_pool.can_support(lora_config)
@@ -171,7 +196,7 @@ class LoRAManager:
171
196
  "`--max-loras-per-batch` or load it as unpinned LoRA adapters."
172
197
  )
173
198
 
174
- def unload_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateResult:
199
+ def unload_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateOutput:
175
200
  """
176
201
  Unload LoRA adapters by their names. This will remove the adapters from the memory pool and
177
202
  delete the corresponding LoRA modules.
@@ -408,6 +433,7 @@ class LoRAManager:
408
433
  max_lora_rank=self.max_lora_rank,
409
434
  target_modules=self.target_modules,
410
435
  base_model=self.base_model,
436
+ eviction_policy=self.eviction_policy,
411
437
  )
412
438
 
413
439
  def set_lora_module(self, module_name, module):
@@ -18,8 +18,8 @@ from dataclasses import dataclass, field, fields
18
18
  from typing import Dict, List, Optional, Union
19
19
  from uuid import uuid4
20
20
 
21
- from sglang.srt.aio_rwlock import RWLock
22
21
  from sglang.srt.utils import ConcurrentCounter
22
+ from sglang.srt.utils.aio_rwlock import RWLock
23
23
 
24
24
 
25
25
  @dataclass(frozen=True)