sglang 0.5.3rc2__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (419)
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +378 -160
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +10 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +105 -10
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +136 -25
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +63 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +83 -80
  69. sglang/srt/entrypoints/grpc_server.py +430 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +195 -102
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +58 -6
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +33 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +20 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/minimax_m2.py +367 -0
  93. sglang/srt/function_call/utils.py +2 -2
  94. sglang/srt/grpc/compile_proto.py +3 -3
  95. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  96. sglang/srt/grpc/health_servicer.py +189 -0
  97. sglang/srt/grpc/scheduler_launcher.py +181 -0
  98. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  99. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  100. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  101. sglang/srt/layers/activation.py +10 -1
  102. sglang/srt/layers/attention/aiter_backend.py +3 -3
  103. sglang/srt/layers/attention/ascend_backend.py +17 -1
  104. sglang/srt/layers/attention/attention_registry.py +43 -23
  105. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  106. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  107. sglang/srt/layers/attention/fla/chunk.py +0 -1
  108. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  109. sglang/srt/layers/attention/fla/index.py +0 -2
  110. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  111. sglang/srt/layers/attention/fla/utils.py +0 -3
  112. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  113. sglang/srt/layers/attention/flashattention_backend.py +24 -10
  114. sglang/srt/layers/attention/flashinfer_backend.py +258 -22
  115. sglang/srt/layers/attention/flashinfer_mla_backend.py +38 -28
  116. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  117. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  118. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  119. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  121. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  122. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  123. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  124. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  125. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  127. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  128. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  129. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  130. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  131. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  132. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  133. sglang/srt/layers/attention/nsa/utils.py +0 -1
  134. sglang/srt/layers/attention/nsa_backend.py +404 -90
  135. sglang/srt/layers/attention/triton_backend.py +208 -34
  136. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  137. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  138. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  139. sglang/srt/layers/attention/trtllm_mla_backend.py +362 -43
  140. sglang/srt/layers/attention/utils.py +89 -7
  141. sglang/srt/layers/attention/vision.py +3 -3
  142. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  143. sglang/srt/layers/communicator.py +12 -7
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +5 -9
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  146. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  147. sglang/srt/layers/dp_attention.py +17 -0
  148. sglang/srt/layers/layernorm.py +64 -19
  149. sglang/srt/layers/linear.py +9 -1
  150. sglang/srt/layers/logits_processor.py +152 -17
  151. sglang/srt/layers/modelopt_utils.py +11 -0
  152. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  153. sglang/srt/layers/moe/cutlass_w4a8_moe.py +351 -21
  154. sglang/srt/layers/moe/ep_moe/kernels.py +229 -457
  155. sglang/srt/layers/moe/ep_moe/layer.py +154 -625
  156. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  160. sglang/srt/layers/moe/fused_moe_triton/layer.py +79 -73
  161. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +25 -46
  162. sglang/srt/layers/moe/moe_runner/deep_gemm.py +569 -0
  163. sglang/srt/layers/moe/moe_runner/runner.py +6 -0
  164. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  165. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  166. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  167. sglang/srt/layers/moe/router.py +51 -15
  168. sglang/srt/layers/moe/token_dispatcher/__init__.py +14 -4
  169. sglang/srt/layers/moe/token_dispatcher/base.py +12 -6
  170. sglang/srt/layers/moe/token_dispatcher/deepep.py +127 -110
  171. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  172. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  173. sglang/srt/layers/moe/topk.py +7 -6
  174. sglang/srt/layers/moe/utils.py +20 -5
  175. sglang/srt/layers/quantization/__init__.py +5 -58
  176. sglang/srt/layers/quantization/awq.py +183 -9
  177. sglang/srt/layers/quantization/awq_triton.py +29 -0
  178. sglang/srt/layers/quantization/base_config.py +27 -1
  179. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  180. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  185. sglang/srt/layers/quantization/fp8.py +152 -81
  186. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  187. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  188. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  189. sglang/srt/layers/quantization/gguf.py +566 -0
  190. sglang/srt/layers/quantization/gptq.py +0 -1
  191. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  192. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  193. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  194. sglang/srt/layers/quantization/mxfp4.py +35 -68
  195. sglang/srt/layers/quantization/petit.py +1 -1
  196. sglang/srt/layers/quantization/quark/quark.py +3 -1
  197. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  199. sglang/srt/layers/quantization/unquant.py +23 -48
  200. sglang/srt/layers/quantization/utils.py +0 -1
  201. sglang/srt/layers/quantization/w4afp8.py +87 -20
  202. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  203. sglang/srt/layers/radix_attention.py +62 -9
  204. sglang/srt/layers/rotary_embedding.py +686 -17
  205. sglang/srt/layers/sampler.py +47 -16
  206. sglang/srt/layers/sparse_pooler.py +98 -0
  207. sglang/srt/layers/utils.py +0 -1
  208. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  209. sglang/srt/lora/backend/triton_backend.py +0 -1
  210. sglang/srt/lora/eviction_policy.py +139 -0
  211. sglang/srt/lora/lora_manager.py +24 -9
  212. sglang/srt/lora/lora_registry.py +1 -1
  213. sglang/srt/lora/mem_pool.py +40 -16
  214. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  215. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  216. sglang/srt/managers/cache_controller.py +48 -17
  217. sglang/srt/managers/data_parallel_controller.py +146 -42
  218. sglang/srt/managers/detokenizer_manager.py +40 -13
  219. sglang/srt/managers/io_struct.py +69 -16
  220. sglang/srt/managers/mm_utils.py +20 -18
  221. sglang/srt/managers/multi_tokenizer_mixin.py +83 -82
  222. sglang/srt/managers/overlap_utils.py +96 -19
  223. sglang/srt/managers/schedule_batch.py +241 -511
  224. sglang/srt/managers/schedule_policy.py +15 -2
  225. sglang/srt/managers/scheduler.py +420 -514
  226. sglang/srt/managers/scheduler_metrics_mixin.py +73 -18
  227. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  228. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  229. sglang/srt/managers/scheduler_profiler_mixin.py +60 -14
  230. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  231. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  232. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  233. sglang/srt/managers/tokenizer_manager.py +375 -95
  234. sglang/srt/managers/tp_worker.py +212 -161
  235. sglang/srt/managers/utils.py +78 -2
  236. sglang/srt/mem_cache/allocator.py +7 -2
  237. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  238. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  239. sglang/srt/mem_cache/chunk_cache.py +13 -2
  240. sglang/srt/mem_cache/common.py +480 -0
  241. sglang/srt/mem_cache/evict_policy.py +16 -1
  242. sglang/srt/mem_cache/hicache_storage.py +11 -2
  243. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  244. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  245. sglang/srt/mem_cache/memory_pool.py +517 -219
  246. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  247. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  248. sglang/srt/mem_cache/radix_cache.py +53 -19
  249. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  250. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  251. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  252. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  253. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  254. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  255. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  256. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  257. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  259. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  260. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  261. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  262. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  263. sglang/srt/metrics/collector.py +31 -0
  264. sglang/srt/metrics/func_timer.py +1 -1
  265. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  266. sglang/srt/model_executor/forward_batch_info.py +71 -25
  267. sglang/srt/model_executor/model_runner.py +362 -270
  268. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  269. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +549 -0
  270. sglang/srt/model_loader/__init__.py +1 -1
  271. sglang/srt/model_loader/loader.py +424 -27
  272. sglang/srt/model_loader/utils.py +0 -1
  273. sglang/srt/model_loader/weight_utils.py +47 -28
  274. sglang/srt/models/apertus.py +2 -3
  275. sglang/srt/models/arcee.py +2 -2
  276. sglang/srt/models/bailing_moe.py +13 -52
  277. sglang/srt/models/bailing_moe_nextn.py +3 -4
  278. sglang/srt/models/bert.py +1 -1
  279. sglang/srt/models/deepseek_nextn.py +19 -3
  280. sglang/srt/models/deepseek_ocr.py +1516 -0
  281. sglang/srt/models/deepseek_v2.py +418 -140
  282. sglang/srt/models/dots_ocr.py +0 -2
  283. sglang/srt/models/dots_vlm.py +0 -1
  284. sglang/srt/models/dots_vlm_vit.py +1 -1
  285. sglang/srt/models/falcon_h1.py +13 -19
  286. sglang/srt/models/gemma3_mm.py +16 -0
  287. sglang/srt/models/gemma3n_mm.py +1 -2
  288. sglang/srt/models/glm4_moe.py +327 -382
  289. sglang/srt/models/glm4_moe_nextn.py +6 -16
  290. sglang/srt/models/glm4v.py +2 -1
  291. sglang/srt/models/glm4v_moe.py +32 -199
  292. sglang/srt/models/gpt_oss.py +5 -5
  293. sglang/srt/models/grok.py +10 -23
  294. sglang/srt/models/hunyuan.py +2 -7
  295. sglang/srt/models/interns1.py +0 -1
  296. sglang/srt/models/kimi_vl.py +1 -7
  297. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  298. sglang/srt/models/llama.py +2 -2
  299. sglang/srt/models/llama_eagle3.py +1 -1
  300. sglang/srt/models/longcat_flash.py +5 -22
  301. sglang/srt/models/longcat_flash_nextn.py +3 -14
  302. sglang/srt/models/mimo.py +2 -13
  303. sglang/srt/models/mimo_mtp.py +1 -2
  304. sglang/srt/models/minicpmo.py +7 -5
  305. sglang/srt/models/minimax_m2.py +922 -0
  306. sglang/srt/models/mixtral.py +1 -4
  307. sglang/srt/models/mllama.py +1 -1
  308. sglang/srt/models/mllama4.py +13 -3
  309. sglang/srt/models/nemotron_h.py +511 -0
  310. sglang/srt/models/nvila.py +355 -0
  311. sglang/srt/models/nvila_lite.py +184 -0
  312. sglang/srt/models/olmo2.py +31 -4
  313. sglang/srt/models/opt.py +5 -5
  314. sglang/srt/models/phi.py +1 -1
  315. sglang/srt/models/phi4mm.py +1 -1
  316. sglang/srt/models/phimoe.py +0 -1
  317. sglang/srt/models/pixtral.py +0 -3
  318. sglang/srt/models/points_v15_chat.py +186 -0
  319. sglang/srt/models/qwen.py +0 -1
  320. sglang/srt/models/qwen2.py +22 -1
  321. sglang/srt/models/qwen2_5_vl.py +3 -3
  322. sglang/srt/models/qwen2_audio.py +2 -15
  323. sglang/srt/models/qwen2_moe.py +15 -12
  324. sglang/srt/models/qwen2_vl.py +5 -2
  325. sglang/srt/models/qwen3.py +34 -4
  326. sglang/srt/models/qwen3_moe.py +19 -37
  327. sglang/srt/models/qwen3_next.py +7 -12
  328. sglang/srt/models/qwen3_next_mtp.py +3 -4
  329. sglang/srt/models/qwen3_omni_moe.py +661 -0
  330. sglang/srt/models/qwen3_vl.py +37 -33
  331. sglang/srt/models/qwen3_vl_moe.py +57 -185
  332. sglang/srt/models/roberta.py +55 -3
  333. sglang/srt/models/sarashina2_vision.py +0 -1
  334. sglang/srt/models/step3_vl.py +3 -5
  335. sglang/srt/models/utils.py +11 -1
  336. sglang/srt/multimodal/processors/base_processor.py +7 -2
  337. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  338. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  339. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  340. sglang/srt/multimodal/processors/glm4v.py +2 -6
  341. sglang/srt/multimodal/processors/internvl.py +0 -2
  342. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  343. sglang/srt/multimodal/processors/mllama4.py +0 -8
  344. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  345. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  346. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  347. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  348. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  349. sglang/srt/parser/conversation.py +41 -0
  350. sglang/srt/parser/reasoning_parser.py +28 -2
  351. sglang/srt/sampling/custom_logit_processor.py +77 -2
  352. sglang/srt/sampling/sampling_batch_info.py +17 -22
  353. sglang/srt/sampling/sampling_params.py +70 -2
  354. sglang/srt/server_args.py +846 -163
  355. sglang/srt/server_args_config_parser.py +1 -1
  356. sglang/srt/single_batch_overlap.py +36 -31
  357. sglang/srt/speculative/base_spec_worker.py +34 -0
  358. sglang/srt/speculative/draft_utils.py +226 -0
  359. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  360. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  361. sglang/srt/speculative/eagle_info.py +57 -18
  362. sglang/srt/speculative/eagle_info_v2.py +458 -0
  363. sglang/srt/speculative/eagle_utils.py +138 -0
  364. sglang/srt/speculative/eagle_worker.py +83 -280
  365. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  366. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  367. sglang/srt/speculative/ngram_worker.py +12 -11
  368. sglang/srt/speculative/spec_info.py +2 -0
  369. sglang/srt/speculative/spec_utils.py +38 -3
  370. sglang/srt/speculative/standalone_worker.py +4 -14
  371. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  372. sglang/srt/two_batch_overlap.py +28 -14
  373. sglang/srt/utils/__init__.py +1 -1
  374. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  375. sglang/srt/utils/common.py +272 -82
  376. sglang/srt/utils/hf_transformers_utils.py +44 -17
  377. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  378. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  379. sglang/srt/utils/profile_merger.py +199 -0
  380. sglang/test/attention/test_flashattn_backend.py +1 -1
  381. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  382. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  383. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  384. sglang/test/few_shot_gsm8k_engine.py +2 -4
  385. sglang/test/kit_matched_stop.py +157 -0
  386. sglang/test/longbench_v2/__init__.py +1 -0
  387. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  388. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  389. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  390. sglang/test/run_eval.py +41 -0
  391. sglang/test/runners.py +2 -0
  392. sglang/test/send_one.py +42 -7
  393. sglang/test/simple_eval_common.py +3 -0
  394. sglang/test/simple_eval_gpqa.py +0 -1
  395. sglang/test/simple_eval_humaneval.py +0 -3
  396. sglang/test/simple_eval_longbench_v2.py +344 -0
  397. sglang/test/test_block_fp8.py +1 -2
  398. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  399. sglang/test/test_cutlass_moe.py +1 -2
  400. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  401. sglang/test/test_deterministic.py +463 -107
  402. sglang/test/test_deterministic_utils.py +74 -0
  403. sglang/test/test_disaggregation_utils.py +81 -0
  404. sglang/test/test_marlin_moe.py +0 -1
  405. sglang/test/test_utils.py +85 -20
  406. sglang/version.py +1 -1
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +48 -35
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +414 -350
  409. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  410. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  411. sglang/srt/models/vila.py +0 -306
  412. sglang/srt/speculative/build_eagle_tree.py +0 -427
  413. sglang/test/test_block_fp8_ep.py +0 -358
  414. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  415. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  416. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  417. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
  418. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
  419. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
--- a/sglang/srt/layers/sampler.py
+++ b/sglang/srt/layers/sampler.py
@@ -11,8 +11,8 @@ from sglang.srt.layers.dp_attention import (
     is_dp_attention_enabled,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import crash_on_warnings, get_bool_env_var, is_cuda
 
 if is_cuda():
@@ -27,13 +27,13 @@ if is_cuda():
 logger = logging.getLogger(__name__)
 
 SYNC_TOKEN_IDS_ACROSS_TP = get_bool_env_var("SYNC_TOKEN_IDS_ACROSS_TP")
-RETURN_ORIGINAL_LOGPROB = get_bool_env_var("RETURN_ORIGINAL_LOGPROB")
+SGLANG_RETURN_ORIGINAL_LOGPROB = get_bool_env_var("SGLANG_RETURN_ORIGINAL_LOGPROB")
 
 
 class Sampler(nn.Module):
     def __init__(self):
         super().__init__()
-        self.use_nan_detection = global_server_args_dict["enable_nan_detection"]
+        self.use_nan_detection = get_global_server_args().enable_nan_detection
         self.tp_sync_group = get_tp_group().device_group
 
         if is_dp_attention_enabled():
@@ -91,20 +91,40 @@ class Sampler(nn.Module):
             batch_next_token_ids = torch.argmax(logits, -1)
             if return_logprob:
                 logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
-
         else:
+            can_sample_directly_from_probs = (
+                not sampling_info.need_top_p_sampling
+                and not sampling_info.need_top_k_sampling
+                and not sampling_info.need_min_p_sampling
+            )
+
             # If requested, cache probabilities from original logits before temperature scaling.
-            if return_logprob and RETURN_ORIGINAL_LOGPROB:
+            if return_logprob and SGLANG_RETURN_ORIGINAL_LOGPROB:
                 probs_without_temp_scaling = torch.softmax(logits, dim=-1)
 
+            if get_global_server_args().rl_on_policy_target == "fsdp":
+                logits_div_temperature = (
+                    logits.bfloat16().div(sampling_info.temperatures).bfloat16()
+                )
+                logprobs_via_logsoftmax_kernel = torch.log_softmax(
+                    logits_div_temperature, dim=-1
+                )
+
             # Post process logits
             logits.div_(sampling_info.temperatures)
             logits[:] = torch.softmax(logits, dim=-1)
             probs = logits
             del logits
 
-            if True:  # Keep this redundant check to simplify some internal code sync
-                if global_server_args_dict["sampling_backend"] == "flashinfer":
+            if can_sample_directly_from_probs:
+                # when we don't need top-k, top-p, or min-p sampling, we can directly sample from the probs
+                batch_next_token_ids = sampling_from_probs_torch(
+                    probs,
+                    sampling_seed=sampling_info.sampling_seed,
+                    positions=positions,
+                )
+            else:
+                if get_global_server_args().sampling_backend == "flashinfer":
                     if sampling_info.need_min_p_sampling:
                         probs = top_k_renorm_prob(probs, sampling_info.top_ks)
                         probs = top_p_renorm_prob(probs, sampling_info.top_ps)
@@ -119,7 +139,7 @@ class Sampler(nn.Module):
                         filter_apply_order="joint",
                         check_nan=self.use_nan_detection,
                     )
-                elif global_server_args_dict["sampling_backend"] == "pytorch":
+                elif get_global_server_args().sampling_backend == "pytorch":
                     # A slower fallback implementation with torch native operations.
                     batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
                         probs,
@@ -132,12 +152,15 @@ class Sampler(nn.Module):
                     )
                 else:
                     raise ValueError(
-                        f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
+                        f"Invalid sampling backend: {get_global_server_args().sampling_backend}"
                    )
 
             if return_logprob:
+                if get_global_server_args().rl_on_policy_target == "fsdp":
+                    logprobs = logprobs_via_logsoftmax_kernel
+                    del logprobs_via_logsoftmax_kernel
                 # clamp to avoid -inf
-                if RETURN_ORIGINAL_LOGPROB:
+                elif SGLANG_RETURN_ORIGINAL_LOGPROB:
                     logprobs = torch.log(probs_without_temp_scaling).clamp(
                         min=torch.finfo(probs_without_temp_scaling.dtype).min
                     )
@@ -288,21 +311,29 @@ def multinomial_with_seed(
     """
     n, m = inputs.shape
     col_indices = torch.arange(m, device=inputs.device).unsqueeze(0)
-    step_seed = seed * 19349663 ^ positions * 73856093
+    step_seed = (seed * 19349663) ^ (positions * 73856093)
     seed_expanded = step_seed.unsqueeze(-1)
-    hashed = seed_expanded * 8589934591 ^ col_indices * 479001599
+    hashed = (seed_expanded * 8589934591) ^ (col_indices * 479001599)
     uniform_samples = (hashed % (2**24)).float() / (2**24)
-    epsilon = 1e-9
-    gumbel_noise = -torch.log(-torch.log(uniform_samples + epsilon) + epsilon)
+    epsilon = 1e-10
+    uniform_samples = uniform_samples.clamp(epsilon, 1.0 - epsilon)
+    gumbel_noise = -torch.log(-torch.log(uniform_samples))
     log_probs = torch.log(inputs + epsilon)
     perturbed_log_probs = log_probs + gumbel_noise
     return torch.argmax(perturbed_log_probs, dim=1, keepdim=True)
 
 
-def sampling_from_probs_torch(probs: torch.Tensor):
+def sampling_from_probs_torch(
+    probs: torch.Tensor,
+    sampling_seed: Optional[torch.Tensor] = None,
+    positions: Optional[torch.Tensor] = None,
+):
     """A sampling implementation with native pytorch operations, without
     top-k, top-p, or min-p filtering."""
-    sampled_index = torch.multinomial(probs, num_samples=1)
+    if sampling_seed is not None:
+        sampled_index = multinomial_with_seed(probs, sampling_seed, positions)
+    else:
+        sampled_index = torch.multinomial(probs, num_samples=1)
    batch_next_token_ids = sampled_index.view(-1).to(torch.int32)
    return batch_next_token_ids
 
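The seeded path above replaces `torch.multinomial` with a hash-based Gumbel-max draw, so a request with a fixed `sampling_seed` reproduces the same token at the same position. The following standalone sketch (illustrative names, not sglang's API) shows the same trick in isolation:

import torch

def gumbel_max_sample(probs: torch.Tensor, seed: torch.Tensor) -> torch.Tensor:
    # probs: [n, m] row-normalized probabilities; seed: [n] integer seeds.
    m = probs.shape[1]
    cols = torch.arange(m).unsqueeze(0)                              # [1, m]
    # Cheap integer hash mixing each row's seed with the column index.
    hashed = (seed.unsqueeze(-1) * 8589934591) ^ (cols * 479001599)  # [n, m]
    u = (hashed % (2**24)).float() / (2**24)                         # uniforms in [0, 1)
    u = u.clamp(1e-10, 1.0 - 1e-10)                                  # keep the logs finite
    gumbel = -torch.log(-torch.log(u))
    # argmax(log p + g) with g ~ Gumbel(0, 1) is a draw from Categorical(p).
    return torch.argmax(torch.log(probs + 1e-10) + gumbel, dim=1)

probs = torch.tensor([[0.1, 0.2, 0.7], [0.5, 0.25, 0.25]])
seed = torch.tensor([42, 7])
# Deterministic: the same seeds always yield the same sampled indices.
assert torch.equal(gumbel_max_sample(probs, seed), gumbel_max_sample(probs, seed))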
--- /dev/null
+++ b/sglang/srt/layers/sparse_pooler.py
@@ -0,0 +1,98 @@
+from dataclasses import dataclass
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers import PretrainedConfig
+
+from sglang.srt.model_executor.model_runner import ForwardBatch
+
+
+@dataclass
+class SparseEmbeddingOutput:
+    embeddings: torch.Tensor  # [batch_size, vocab_size]
+
+
+class SparsePooler(nn.Module):
+    """A layer that pools hidden states into sparse vocabulary-space embeddings.
+
+    This layer does the following:
+    1. Applies a linear transformation + ReLU to get token-level weights
+    2. Maps these weights to vocabulary positions using token IDs
+    3. Aggregates weights for repeated tokens using max pooling
+    4. Returns sparse embeddings in vocabulary space
+
+    Attributes:
+        config: Model configuration containing vocab_size and hidden_size
+        sparse_linear: Linear layer for computing token weights
+        vocab_size: Size of vocabulary for output embeddings
+    """
+
+    def __init__(self, config: PretrainedConfig):
+        super().__init__()
+
+        # Validate required attributes
+        if not hasattr(config, "vocab_size"):
+            raise AttributeError(
+                f"Config {type(config)} missing required 'vocab_size' attribute"
+            )
+        if not hasattr(config, "hidden_size"):
+            raise AttributeError(
+                f"Config {type(config)} missing required 'hidden_size' attribute"
+            )
+
+        self.vocab_size = config.vocab_size
+        self.sparse_linear = nn.Linear(config.hidden_size, 1)
+        self._weights_loaded = False
+
+    def forward(
+        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
+    ) -> SparseEmbeddingOutput:
+        """
+        Forward pass for sparse pooling.
+
+        Args:
+            hidden_states: Packed sequence hidden states [total_tokens, hidden_size]
+            forward_batch: Batch information with sequence lengths and input_ids
+
+        Returns:
+            SparseEmbeddingOutput with embeddings of shape [batch_size, vocab_size]
+        """
+        if not self._weights_loaded:
+            raise ValueError(
+                "Sparse pooling weights not loaded. Call load_weights() first"
+            )
+
+        # Apply sparse linear + ReLU to get token weights
+        token_weights = F.relu(self.sparse_linear(hidden_states)).squeeze(
+            -1
+        )  # [total_tokens]
+
+        # Create batch indices for packed sequences
+        batch_indices = torch.repeat_interleave(
+            torch.arange(
+                len(forward_batch.extend_seq_lens), device=hidden_states.device
+            ),
+            forward_batch.extend_seq_lens,
+        )
+
+        # Initialize sparse embedding output
+        sparse_embedding = torch.zeros(
+            len(forward_batch.extend_seq_lens),
+            self.vocab_size,
+            dtype=token_weights.dtype,
+            device=token_weights.device,
+        )
+
+        # Map to vocabulary space using scatter_reduce with amax
+        flat_indices = batch_indices * self.vocab_size + forward_batch.input_ids
+        sparse_embedding.view(-1).scatter_reduce_(
+            0, flat_indices, token_weights, reduce="amax"
+        )
+
+        return SparseEmbeddingOutput(embeddings=sparse_embedding)
+
+    def load_weights(self, state_dict: dict):
+        """Load weights from state dict (called by the model)."""
+        self.sparse_linear.load_state_dict(state_dict)
+        self._weights_loaded = True
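The heart of `SparsePooler.forward` is the flattened `scatter_reduce_` with `reduce="amax"`: batch index and token ID are folded into one flat index, so repeated tokens within a sequence are max-pooled in a single call. A toy demonstration with plain tensors (no sglang `ForwardBatch`):

import torch

# Two packed sequences of lengths 3 and 2 over a vocab of size 5.
seq_lens = torch.tensor([3, 2])
input_ids = torch.tensor([1, 4, 1, 0, 4])                # packed token ids
token_weights = torch.tensor([0.2, 0.9, 0.5, 0.3, 0.1])  # post-ReLU weights
vocab_size = 5

batch_indices = torch.repeat_interleave(torch.arange(len(seq_lens)), seq_lens)
out = torch.zeros(len(seq_lens), vocab_size)
flat = batch_indices * vocab_size + input_ids
out.view(-1).scatter_reduce_(0, flat, token_weights, reduce="amax")

print(out)
# tensor([[0.0000, 0.5000, 0.0000, 0.0000, 0.9000],   <- token 1 max-pooled: max(0.2, 0.5)
#         [0.3000, 0.0000, 0.0000, 0.0000, 0.1000]])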
--- a/sglang/srt/layers/utils.py
+++ b/sglang/srt/layers/utils.py
@@ -1,6 +1,5 @@
 import logging
 import re
-from functools import lru_cache
 
 import torch
 
--- a/sglang/srt/layers/vocab_parallel_embedding.py
+++ b/sglang/srt/layers/vocab_parallel_embedding.py
@@ -540,7 +540,10 @@ class ParallelLMHead(VocabParallelEmbedding):
 
         # We only support pack LMHead if it's not quantized.
         if _is_cpu and _is_cpu_amx_available:
-            if hasattr(self, "weight") and self.weight.dtype == torch.bfloat16:
+            if hasattr(self, "weight") and self.weight.dtype in [
+                torch.bfloat16,
+                torch.float16,
+            ]:
                 self.quant_method = PackWeightMethod(weight_names=["weight"])
 
         if bias:
--- a/sglang/srt/lora/backend/triton_backend.py
+++ b/sglang/srt/lora/backend/triton_backend.py
@@ -11,7 +11,6 @@ from sglang.srt.lora.triton_ops import (
 )
 from sglang.srt.lora.utils import LoRABatchInfo
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.server_args import ServerArgs
 
 
 class TritonLoRABackend(BaseLoRABackend):
--- /dev/null
+++ b/sglang/srt/lora/eviction_policy.py
@@ -0,0 +1,139 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""
+Eviction policies for LoRA adapter memory management.
+"""
+
+import logging
+import time
+from abc import ABC, abstractmethod
+from collections import OrderedDict
+from typing import Optional, Set
+
+logger = logging.getLogger(__name__)
+
+
+class EvictionPolicy(ABC):
+    """Abstract base class for LoRA adapter eviction policies."""
+
+    @abstractmethod
+    def mark_used(self, uid: Optional[str]) -> None:
+        """Marks an adapter as used."""
+        pass
+
+    @abstractmethod
+    def select_victim(self, candidates: Set[Optional[str]]) -> Optional[str]:
+        """Selects an adapter to evict from candidates."""
+        pass
+
+    @abstractmethod
+    def remove(self, uid: Optional[str]) -> None:
+        """Removes an adapter from the policy's tracking."""
+        pass
+
+
+class LRUEvictionPolicy(EvictionPolicy):
+    """LRU eviction policy - evicts the least recently used adapter."""
+
+    def __init__(self):
+        self.access_order = OrderedDict()  # key=uid, value=last_access_time
+        self.total_accesses = 0
+        self.eviction_count = 0
+
+    def mark_used(self, uid: Optional[str]) -> None:
+        if uid is not None:
+            current_time = time.monotonic()
+            # Remove and re-add to move to end (most recent)
+            self.access_order.pop(uid, None)
+            self.access_order[uid] = current_time
+            self.total_accesses += 1
+            logger.debug(f"LoRA {uid} marked as used at {current_time}")
+
+    def select_victim(self, candidates: Set[Optional[str]]) -> Optional[str]:
+        """Select the least recently used adapter from candidates."""
+        # Base model (currently None, will be replaced with special UID in future)
+        # always has lowest priority - evict it first if available
+        BASE_MODEL_UID = None  # TODO: Replace with special UID constant
+        if BASE_MODEL_UID in candidates:
+            logger.debug(f"Selected base model for eviction (LRU)")
+            self.eviction_count += 1
+            return BASE_MODEL_UID
+
+        # Iterate through access_order (oldest first) to find LRU victim
+        for uid in list(self.access_order.keys()):
+            if uid in candidates:
+                logger.debug(f"Selected LoRA {uid} for eviction (LRU)")
+                self.eviction_count += 1
+                return uid
+
+        # Should never reach here if candidates is non-empty
+        assert False, f"Failed to select LRU victim from candidates: {candidates}"
+
+    def remove(self, uid: Optional[str]) -> None:
+        if uid is not None:
+            self.access_order.pop(uid, None)
+            logger.debug(f"Removed LoRA {uid} from LRU tracking")
+
+
+class FIFOEvictionPolicy(EvictionPolicy):
+    """FIFO eviction policy - for backward compatibility."""
+
+    def __init__(self):
+        self.insertion_order = (
+            OrderedDict()
+        )  # key=uid, OrderedDict maintains insertion order
+        self.eviction_count = 0
+
+    def mark_used(self, uid: Optional[str]) -> None:
+        """For FIFO, we only track insertion order (not access time)."""
+        if uid is not None and uid not in self.insertion_order:
+            self.insertion_order[uid] = (
+                True  # Value unused, OrderedDict tracks insertion order
+            )
+
+    def select_victim(self, candidates: Set[Optional[str]]) -> Optional[str]:
+        """Select the first inserted adapter from candidates."""
+        # Base model (currently None, will be replaced with special UID in future)
+        # always has lowest priority - evict it first if available
+        BASE_MODEL_UID = None  # TODO: Replace with special UID constant
+        if BASE_MODEL_UID in candidates:
+            logger.debug(f"Selected base model for eviction (FIFO)")
+            self.eviction_count += 1
+            return BASE_MODEL_UID
+
+        # Iterate through insertion_order (oldest first) to find FIFO victim
+        for uid in list(self.insertion_order.keys()):
+            if uid in candidates:
+                logger.debug(f"Selected LoRA {uid} for eviction (FIFO)")
+                self.eviction_count += 1
+                return uid
+
+        # Should never reach here if candidates is non-empty
+        assert False, f"Failed to select FIFO victim from candidates: {candidates}"
+
+    def remove(self, uid: Optional[str]) -> None:
+        if uid is not None:
+            self.insertion_order.pop(uid, None)
+
+
+def get_eviction_policy(policy_name: str) -> EvictionPolicy:
+    """Factory function to create eviction policy instances."""
+    policies = {
+        "fifo": FIFOEvictionPolicy,
+        "lru": LRUEvictionPolicy,
+    }
+    if policy_name not in policies:
+        raise ValueError(f"Unknown eviction policy: {policy_name}")
+    return policies[policy_name]()
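Both policies can be exercised standalone; a quick sketch of the LRU behavior, assuming the module path introduced above:

from sglang.srt.lora.eviction_policy import get_eviction_policy

policy = get_eviction_policy("lru")
policy.mark_used("adapter_a")
policy.mark_used("adapter_b")
policy.mark_used("adapter_a")  # refreshes a; b is now the least recently used

# The base-model slot (uid=None) is always the preferred victim when present.
assert policy.select_victim({"adapter_a", "adapter_b", None}) is None
assert policy.select_victim({"adapter_a", "adapter_b"}) == "adapter_b"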
--- a/sglang/srt/lora/lora_manager.py
+++ b/sglang/srt/lora/lora_manager.py
@@ -16,7 +16,7 @@
 # and "Punica: Multi-Tenant LoRA Serving"
 
 import logging
-from typing import Dict, Iterable, List, Optional, Set, Tuple
+from typing import Dict, Iterable, List, Optional
 
 import torch
 
@@ -68,6 +68,9 @@ class LoRAManager:
         self.tp_size: int = tp_size
         self.tp_rank: int = tp_rank
 
+        # Store eviction policy from server args
+        self.eviction_policy = server_args.lora_eviction_policy
+
         # LoRA backend for running sgemm kernels
         logger.info(f"Using {lora_backend} as backend of LoRA kernels.")
         backend_type = get_backend_from_name(lora_backend)
@@ -131,6 +134,16 @@ class LoRAManager:
             lora_ref.lora_id not in self.loras
         ), f"LoRA adapter with ID {lora_ref.lora_id} is already loaded. This should have been verified before request is sent to the backend."
 
+        if lora_ref.pinned and self.num_pinned_loras >= self.max_loras_per_batch - 1:
+            return self.create_lora_update_result(
+                success=False,
+                error_message=(
+                    f"Already have {self.num_pinned_loras} pinned adapters, "
+                    f"max allowed is {self.max_loras_per_batch - 1} (reserving 1 slot for dynamic use). "
+                    f"Please unpin some adapters or increase max_loras_per_batch."
+                ),
+            )
+
         try:
             # load configs
             new_adapter = LoRAConfig(lora_ref.lora_path)
@@ -156,6 +169,15 @@ class LoRAManager:
         Validate if an adapter can be loaded into the current LoRA memory pool and generate error if it is incompatible.
         """
 
+        # Check if this LoRA adapter is already loaded
+        if any(
+            lora_ref.lora_name == existing_lora_ref.lora_name
+            for existing_lora_ref in self.lora_refs.values()
+        ):
+            raise ValueError(
+                f"Failed to load LoRA adapter {lora_ref.lora_name} because it is already loaded"
+            )
+
         # Check if the LoRA adapter shape is compatible with the current LoRA memory pool configuration.
         memory_pool = getattr(self, "memory_pool", None)
         incompatible = memory_pool and not memory_pool.can_support(lora_config)
@@ -411,6 +433,7 @@ class LoRAManager:
             max_lora_rank=self.max_lora_rank,
             target_modules=self.target_modules,
             base_model=self.base_model,
+            eviction_policy=self.eviction_policy,
         )
 
     def set_lora_module(self, module_name, module):
@@ -418,10 +441,6 @@ class LoRAManager:
             replace_submodule(self.base_model, module_name, lora_module)
         return lora_module
 
-    def should_skip_lora_for_vision_model(self, module_name):
-        # TODO: support different vision models
-        return module_name.find("vision_model.model") != -1
-
     def init_lora_modules(self):
         # Look-up table that essentially maps (layer_index, module_name) to the corresponding LoRA module.
         self.lora_modules: List[Dict[str, BaseLayerWithLoRA]] = [
@@ -439,10 +458,6 @@ class LoRAManager:
             ) and not self.base_model.should_apply_lora(module_name):
                 continue
 
-            # Skip vision model
-            if self.should_skip_lora_for_vision_model(module_name):
-                continue
-
             # The module should be converted if it is included in target_names
             if module_name.split(".")[-1] in self.target_modules:
                 layer_id = get_layer_id(module_name)
--- a/sglang/srt/lora/lora_registry.py
+++ b/sglang/srt/lora/lora_registry.py
@@ -18,8 +18,8 @@ from dataclasses import dataclass, field, fields
 from typing import Dict, List, Optional, Union
 from uuid import uuid4
 
-from sglang.srt.aio_rwlock import RWLock
 from sglang.srt.utils import ConcurrentCounter
+from sglang.srt.utils.aio_rwlock import RWLock
 
 
 @dataclass(frozen=True)
--- a/sglang/srt/lora/mem_pool.py
+++ b/sglang/srt/lora/mem_pool.py
@@ -4,6 +4,7 @@ from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple, Union
 import torch
 
 from sglang.srt.distributed import divide
+from sglang.srt.lora.eviction_policy import get_eviction_policy
 from sglang.srt.lora.layers import BaseLayerWithLoRA
 from sglang.srt.lora.lora import LoRAAdapter
 from sglang.srt.lora.lora_config import LoRAConfig
@@ -54,6 +55,7 @@ class LoRAMemoryPool:
         max_lora_rank: int,
         target_modules: Set[str],
         base_model: torch.nn.Module,
+        eviction_policy: str,
     ):
         self.base_hf_config: AutoConfig = base_hf_config
         self.num_layer: int = base_hf_config.num_hidden_layers
@@ -64,6 +66,9 @@ class LoRAMemoryPool:
         self.max_lora_rank: int = max_lora_rank
         self.target_modules: Set[str] = target_modules
 
+        # Initialize eviction policy
+        self.eviction_policy = get_eviction_policy(eviction_policy)
+
         # Both A_buffer and B_buffer maps lora weight names to its buffer space.
         # A_buffer contains num_layer number of row-major tensors with shape
         # (max_loras_per_batch, stacked_num * max_lora_dim, input_dim)
@@ -189,31 +194,50 @@ class LoRAMemoryPool:
         lora_refs: Dict[str, LoRARef],
     ):
         def get_available_buffer_slot():
+            # 1. Prioritize empty slots
             for buffer_id in range(self.max_loras_per_batch):
-                # Prioritize empty slots
                 if self.buffer_id_to_uid[buffer_id] == EMPTY_SLOT:
                     return buffer_id
 
+            # 2. Memory pool is full, need to evict using policy
+            candidates = set()
+
             for buffer_id in range(self.max_loras_per_batch):
                 uid = self.buffer_id_to_uid[buffer_id]
 
-                # Evict unneeded lora
-                if uid not in cur_uids:
-                    # Skip pinned LoRAs
-                    # TODO (lifuhuang): we might consider supporting pinning base model (uid == None) in the future.
-                    if uid is not None:
-                        lora_ref = lora_refs.get(uid)
-                        if lora_ref is not None and lora_ref.pinned:
-                            continue
-
-                    self.uid_to_buffer_id.pop(uid)
-                    logger.debug(f"Evicting LoRA {uid} from buffer slot {buffer_id}.")
-                    self.buffer_id_to_uid[buffer_id] = EMPTY_SLOT
-                    return buffer_id
+                # Skip if this adapter is needed by current batch
+                # TODO (lifuhuang): we might consider supporting pinning base model (uid == None) in the future.
+                if uid in cur_uids:
+                    continue
+
+                # Skip if this adapter is pinned (base model cannot be pinned, so can be evicted)
+                if uid is not None:
+                    lora_ref = lora_refs.get(uid)
+                    if lora_ref and lora_ref.pinned:
+                        continue
+                candidates.add(uid)
 
-            raise ValueError(
-                "No available buffer slots found. Please ensure the number of active loras is less than max_loras_per_batch."
+            if not candidates:
+                raise ValueError(
+                    "No available buffer slots found. Please ensure the number of active (pinned) loras is less than max_loras_per_batch."
+                )
+
+            # Select victim using eviction policy
+            victim_uid = self.eviction_policy.select_victim(candidates)
+
+            # Evict the selected victim
+            victim_buffer_id = self.uid_to_buffer_id[victim_uid]
+            self.uid_to_buffer_id.pop(victim_uid)
+            self.eviction_policy.remove(victim_uid)
+            self.buffer_id_to_uid[victim_buffer_id] = EMPTY_SLOT
+            logger.debug(
+                f"Evicting LoRA {victim_uid} from buffer slot {victim_buffer_id}."
             )
+            return victim_buffer_id
+
+        # Mark all adapters in current batch as used (for LRU tracking)
+        for uid in cur_uids:
+            self.eviction_policy.mark_used(uid)
 
         for uid in cur_uids:
             if uid not in self.uid_to_buffer_id:
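The rewritten `get_available_buffer_slot` splits slot selection into two phases: take any empty slot, otherwise collect the evictable slots (not needed by the current batch, not pinned) and delegate the choice to the policy. A simplified standalone rendering of that control flow (hypothetical helper, not sglang's code):

from typing import Callable, List, Optional, Set

EMPTY_SLOT = "<empty>"  # stand-in sentinel; sglang defines its own

def pick_slot(
    slots: List[Optional[str]],
    cur_uids: Set[Optional[str]],
    pinned: Set[str],
    select_victim: Callable[[Set[Optional[str]]], Optional[str]],
) -> int:
    # Phase 1: any empty slot wins immediately.
    for buffer_id, uid in enumerate(slots):
        if uid == EMPTY_SLOT:
            return buffer_id
    # Phase 2: gather evictable candidates, then let the policy pick the victim.
    candidates = {u for u in slots if u not in cur_uids and u not in pinned}
    if not candidates:
        raise ValueError("all slots are in use or pinned")
    return slots.index(select_victim(candidates))

# Toy policy: evict the lexicographically smallest candidate.
# "c" is in the current batch and "a" is pinned, so "b" (slot 1) is evicted.
print(pick_slot(["a", "b", "c"], {"c"}, {"a"}, lambda c: sorted(c)[0]))  # -> 1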
--- a/sglang/srt/lora/triton_ops/chunked_sgmv_expand.py
+++ b/sglang/srt/lora/triton_ops/chunked_sgmv_expand.py
@@ -9,7 +9,7 @@ from sglang.srt.utils import cached_triton_kernel
 
 
 @cached_triton_kernel(lambda _, kwargs: (kwargs["NUM_SLICES"], kwargs["BLOCK_M"]))
-@triton.jit
+@triton.jit(do_not_specialize=["num_segs"])
 def _chunked_lora_expand_kernel(
     # Pointers to matrices
     x,
--- a/sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py
+++ b/sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py
@@ -6,8 +6,10 @@ from sglang.srt.lora.utils import LoRABatchInfo
 from sglang.srt.utils import cached_triton_kernel
 
 
-@cached_triton_kernel(lambda _, kwargs: (kwargs["NUM_SLICES"], kwargs["BLOCK_M"]))
-@triton.jit
+@cached_triton_kernel(
+    lambda _, kwargs: (kwargs["K"], kwargs["NUM_SLICES"], kwargs["BLOCK_M"])
+)
+@triton.jit(do_not_specialize=["num_segs"])
 def _chunked_lora_shrink_kernel(
     # Pointers to matrices
     x,
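`do_not_specialize` is a standard `triton.jit` option: Triton normally specializes (and recompiles) a kernel for notable integer argument values such as 1 or multiples of 16, so listing `num_segs` avoids a recompile whenever the segment count changes across batches; the shrink kernel's cache key also now includes `K`, so variants tuned for different reduction widths don't collide. A toy kernel illustrating the flag (not sglang's kernel):

import torch
import triton
import triton.language as tl

@triton.jit(do_not_specialize=["n"])  # one compiled variant regardless of n's value
def scale_kernel(x_ptr, out_ptr, scale, n, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n  # guard the tail block when n is not a multiple of BLOCK
    tl.store(out_ptr + offs, tl.load(x_ptr + offs, mask=mask) * scale, mask=mask)

x = torch.randn(100, device="cuda")
out = torch.empty_like(x)
scale_kernel[(triton.cdiv(100, 64),)](x, out, 2.0, 100, BLOCK=64)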