sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (377)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +8 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +360 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/model_config.py +181 -82
  17. sglang/srt/configs/qwen3_next.py +326 -0
  18. sglang/srt/configs/qwen3_vl.py +586 -0
  19. sglang/srt/connector/__init__.py +8 -1
  20. sglang/srt/connector/remote_instance.py +82 -0
  21. sglang/srt/constrained/base_grammar_backend.py +49 -12
  22. sglang/srt/constrained/llguidance_backend.py +0 -1
  23. sglang/srt/constrained/outlines_backend.py +0 -1
  24. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  25. sglang/srt/constrained/xgrammar_backend.py +30 -9
  26. sglang/srt/custom_op.py +11 -1
  27. sglang/srt/debug_utils/dump_comparator.py +81 -44
  28. sglang/srt/debug_utils/dump_loader.py +97 -0
  29. sglang/srt/debug_utils/dumper.py +21 -6
  30. sglang/srt/debug_utils/text_comparator.py +73 -11
  31. sglang/srt/disaggregation/ascend/conn.py +2 -2
  32. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  33. sglang/srt/disaggregation/base/conn.py +1 -1
  34. sglang/srt/disaggregation/common/conn.py +279 -108
  35. sglang/srt/disaggregation/decode.py +71 -19
  36. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  37. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  38. sglang/srt/disaggregation/fake/conn.py +1 -1
  39. sglang/srt/disaggregation/mini_lb.py +6 -445
  40. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  41. sglang/srt/disaggregation/nixl/conn.py +326 -53
  42. sglang/srt/disaggregation/prefill.py +36 -17
  43. sglang/srt/disaggregation/utils.py +40 -54
  44. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  45. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  46. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  47. sglang/srt/distributed/parallel_state.py +156 -80
  48. sglang/srt/entrypoints/engine.py +59 -18
  49. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  50. sglang/srt/entrypoints/grpc_server.py +810 -0
  51. sglang/srt/entrypoints/http_server.py +130 -59
  52. sglang/srt/entrypoints/openai/protocol.py +112 -4
  53. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  54. sglang/srt/entrypoints/openai/serving_chat.py +204 -55
  55. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  56. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  57. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  58. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  59. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  60. sglang/srt/environ.py +285 -0
  61. sglang/srt/eplb/eplb_manager.py +2 -2
  62. sglang/srt/eplb/expert_distribution.py +26 -13
  63. sglang/srt/eplb/expert_location.py +38 -8
  64. sglang/srt/eplb/expert_location_updater.py +1 -1
  65. sglang/srt/function_call/base_format_detector.py +3 -6
  66. sglang/srt/function_call/ebnf_composer.py +11 -9
  67. sglang/srt/function_call/function_call_parser.py +9 -2
  68. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  69. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  70. sglang/srt/function_call/json_array_parser.py +63 -0
  71. sglang/srt/function_call/kimik2_detector.py +17 -4
  72. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  73. sglang/srt/function_call/utils.py +96 -5
  74. sglang/srt/grpc/__init__.py +1 -0
  75. sglang/srt/grpc/compile_proto.py +245 -0
  76. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  77. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  79. sglang/srt/layers/activation.py +143 -9
  80. sglang/srt/layers/attention/aiter_backend.py +14 -15
  81. sglang/srt/layers/attention/ascend_backend.py +115 -9
  82. sglang/srt/layers/attention/attention_registry.py +206 -0
  83. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  84. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  85. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  86. sglang/srt/layers/attention/fla/chunk.py +242 -0
  87. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  88. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  89. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  90. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  91. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  92. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  93. sglang/srt/layers/attention/fla/index.py +37 -0
  94. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  95. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  96. sglang/srt/layers/attention/fla/op.py +66 -0
  97. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  98. sglang/srt/layers/attention/fla/utils.py +331 -0
  99. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  100. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  101. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  102. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  103. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  104. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  105. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  106. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  107. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  108. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  109. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  110. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  111. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  112. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  113. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  114. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  115. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  119. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  120. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  121. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  122. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  123. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  124. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  125. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  126. sglang/srt/layers/attention/nsa/utils.py +24 -0
  127. sglang/srt/layers/attention/nsa_backend.py +887 -0
  128. sglang/srt/layers/attention/tbo_backend.py +6 -6
  129. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  130. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  131. sglang/srt/layers/attention/triton_backend.py +57 -7
  132. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  133. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  134. sglang/srt/layers/attention/vision.py +58 -0
  135. sglang/srt/layers/attention/wave_backend.py +4 -4
  136. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  137. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  138. sglang/srt/layers/communicator.py +8 -0
  139. sglang/srt/layers/dp_attention.py +41 -2
  140. sglang/srt/layers/elementwise.py +3 -1
  141. sglang/srt/layers/layernorm.py +34 -15
  142. sglang/srt/layers/linear.py +55 -7
  143. sglang/srt/layers/logits_processor.py +44 -12
  144. sglang/srt/layers/moe/__init__.py +2 -1
  145. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  146. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  147. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  148. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  149. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  150. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  151. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  164. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  165. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  166. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  167. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  168. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  169. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  170. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  171. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  172. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  173. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  174. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  175. sglang/srt/layers/moe/topk.py +30 -9
  176. sglang/srt/layers/moe/utils.py +22 -6
  177. sglang/srt/layers/parameter.py +23 -6
  178. sglang/srt/layers/quantization/awq.py +19 -7
  179. sglang/srt/layers/quantization/base_config.py +11 -6
  180. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  182. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  185. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  186. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  187. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  188. sglang/srt/layers/quantization/fp8.py +78 -49
  189. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  190. sglang/srt/layers/quantization/gptq.py +25 -17
  191. sglang/srt/layers/quantization/modelopt_quant.py +190 -55
  192. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  193. sglang/srt/layers/quantization/mxfp4.py +74 -42
  194. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  195. sglang/srt/layers/quantization/unquant.py +135 -47
  196. sglang/srt/layers/quantization/w4afp8.py +26 -17
  197. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  198. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  199. sglang/srt/layers/rotary_embedding.py +78 -31
  200. sglang/srt/layers/sampler.py +213 -21
  201. sglang/srt/layers/utils.py +23 -0
  202. sglang/srt/lora/backend/base_backend.py +50 -8
  203. sglang/srt/lora/backend/chunked_backend.py +348 -0
  204. sglang/srt/lora/backend/triton_backend.py +99 -5
  205. sglang/srt/lora/layers.py +32 -0
  206. sglang/srt/lora/lora.py +8 -3
  207. sglang/srt/lora/lora_manager.py +52 -118
  208. sglang/srt/lora/mem_pool.py +25 -11
  209. sglang/srt/lora/triton_ops/__init__.py +4 -0
  210. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  211. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  212. sglang/srt/lora/utils.py +22 -11
  213. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  214. sglang/srt/managers/cache_controller.py +199 -301
  215. sglang/srt/managers/data_parallel_controller.py +115 -80
  216. sglang/srt/managers/detokenizer_manager.py +19 -15
  217. sglang/srt/managers/disagg_service.py +46 -0
  218. sglang/srt/managers/io_struct.py +340 -109
  219. sglang/srt/managers/mm_utils.py +44 -6
  220. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  221. sglang/srt/managers/multimodal_processor.py +1 -2
  222. sglang/srt/managers/overlap_utils.py +53 -0
  223. sglang/srt/managers/schedule_batch.py +240 -138
  224. sglang/srt/managers/schedule_policy.py +144 -17
  225. sglang/srt/managers/scheduler.py +502 -209
  226. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  227. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  228. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  229. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  230. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  231. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  232. sglang/srt/managers/tokenizer_manager.py +320 -632
  233. sglang/srt/managers/tp_worker.py +81 -22
  234. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  235. sglang/srt/managers/utils.py +1 -45
  236. sglang/srt/mem_cache/allocator.py +14 -20
  237. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  238. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  239. sglang/srt/mem_cache/chunk_cache.py +8 -1
  240. sglang/srt/mem_cache/evict_policy.py +23 -0
  241. sglang/srt/mem_cache/hicache_storage.py +43 -24
  242. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  243. sglang/srt/mem_cache/memory_pool.py +535 -58
  244. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  245. sglang/srt/mem_cache/radix_cache.py +222 -73
  246. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  247. sglang/srt/mem_cache/storage/__init__.py +10 -0
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  249. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  250. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  251. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  252. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  253. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  254. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  255. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  256. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  257. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  259. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  260. sglang/srt/metrics/collector.py +511 -132
  261. sglang/srt/metrics/func_timer.py +2 -7
  262. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  263. sglang/srt/metrics/utils.py +8 -1
  264. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  265. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  266. sglang/srt/model_executor/forward_batch_info.py +82 -40
  267. sglang/srt/model_executor/model_runner.py +432 -157
  268. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  269. sglang/srt/model_loader/__init__.py +9 -3
  270. sglang/srt/model_loader/loader.py +133 -5
  271. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  272. sglang/srt/model_loader/weight_utils.py +158 -3
  273. sglang/srt/models/apertus.py +686 -0
  274. sglang/srt/models/bailing_moe.py +820 -217
  275. sglang/srt/models/bailing_moe_nextn.py +168 -0
  276. sglang/srt/models/deepseek_nextn.py +6 -1
  277. sglang/srt/models/deepseek_v2.py +607 -130
  278. sglang/srt/models/dots_ocr.py +173 -0
  279. sglang/srt/models/dots_vlm.py +174 -0
  280. sglang/srt/models/dots_vlm_vit.py +337 -0
  281. sglang/srt/models/ernie4.py +1 -1
  282. sglang/srt/models/falcon_h1.py +576 -0
  283. sglang/srt/models/gemma3_causal.py +0 -2
  284. sglang/srt/models/gemma3_mm.py +1 -1
  285. sglang/srt/models/gemma3n_mm.py +2 -2
  286. sglang/srt/models/glm4_moe.py +4 -4
  287. sglang/srt/models/glm4_moe_nextn.py +2 -2
  288. sglang/srt/models/glm4v.py +5 -3
  289. sglang/srt/models/glm4v_moe.py +4 -1
  290. sglang/srt/models/gpt_oss.py +8 -31
  291. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  292. sglang/srt/models/llama.py +4 -0
  293. sglang/srt/models/llama4.py +9 -0
  294. sglang/srt/models/llama_eagle3.py +13 -0
  295. sglang/srt/models/longcat_flash.py +3 -3
  296. sglang/srt/models/longcat_flash_nextn.py +1 -1
  297. sglang/srt/models/mllama4.py +40 -4
  298. sglang/srt/models/opt.py +637 -0
  299. sglang/srt/models/qwen2_5_vl.py +29 -5
  300. sglang/srt/models/qwen2_audio.py +1 -1
  301. sglang/srt/models/qwen2_moe.py +120 -13
  302. sglang/srt/models/qwen2_vl.py +1 -1
  303. sglang/srt/models/qwen3.py +18 -3
  304. sglang/srt/models/qwen3_moe.py +32 -4
  305. sglang/srt/models/qwen3_next.py +1069 -0
  306. sglang/srt/models/qwen3_next_mtp.py +112 -0
  307. sglang/srt/models/qwen3_vl.py +787 -0
  308. sglang/srt/models/qwen3_vl_moe.py +471 -0
  309. sglang/srt/models/registry.py +15 -3
  310. sglang/srt/models/sarashina2_vision.py +269 -0
  311. sglang/srt/models/solar.py +505 -0
  312. sglang/srt/models/starcoder2.py +357 -0
  313. sglang/srt/models/step3_vl.py +1 -1
  314. sglang/srt/models/torch_native_llama.py +9 -2
  315. sglang/srt/models/utils.py +51 -0
  316. sglang/srt/multimodal/processors/base_processor.py +15 -7
  317. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  318. sglang/srt/multimodal/processors/glm4v.py +9 -9
  319. sglang/srt/multimodal/processors/internvl.py +153 -129
  320. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  321. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  322. sglang/srt/offloader.py +27 -3
  323. sglang/srt/parser/jinja_template_utils.py +6 -0
  324. sglang/srt/sampling/sampling_batch_info.py +38 -17
  325. sglang/srt/sampling/sampling_params.py +7 -0
  326. sglang/srt/server_args.py +966 -267
  327. sglang/srt/server_args_config_parser.py +146 -0
  328. sglang/srt/single_batch_overlap.py +151 -0
  329. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  330. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  331. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  332. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  333. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  334. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  335. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  336. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  337. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  338. sglang/srt/speculative/eagle_worker.py +99 -28
  339. sglang/srt/speculative/ngram_utils.py +428 -0
  340. sglang/srt/speculative/ngram_worker.py +245 -0
  341. sglang/srt/speculative/spec_info.py +52 -0
  342. sglang/srt/speculative/spec_utils.py +606 -0
  343. sglang/srt/speculative/standalone_worker.py +109 -0
  344. sglang/srt/torch_memory_saver_adapter.py +5 -7
  345. sglang/srt/tracing/trace.py +578 -0
  346. sglang/srt/two_batch_overlap.py +8 -5
  347. sglang/srt/utils/__init__.py +2 -0
  348. sglang/srt/{utils.py → utils/common.py} +433 -77
  349. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  350. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  351. sglang/srt/utils/rpd_utils.py +452 -0
  352. sglang/srt/utils/slow_rank_detector.py +71 -0
  353. sglang/srt/warmup.py +8 -4
  354. sglang/srt/weight_sync/utils.py +2 -2
  355. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  356. sglang/test/get_logits_ut.py +57 -0
  357. sglang/test/run_eval.py +79 -11
  358. sglang/test/runners.py +5 -1
  359. sglang/test/simple_eval_common.py +5 -2
  360. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  361. sglang/test/test_block_fp8.py +2 -2
  362. sglang/test/test_cutlass_moe.py +24 -6
  363. sglang/test/test_deterministic.py +297 -0
  364. sglang/test/test_disaggregation_utils.py +77 -0
  365. sglang/test/test_fp4_moe.py +370 -1
  366. sglang/test/test_programs.py +1 -1
  367. sglang/test/test_utils.py +383 -5
  368. sglang/utils.py +21 -1
  369. sglang/version.py +1 -1
  370. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +69 -124
  371. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +375 -245
  372. sglang/srt/disaggregation/launch_lb.py +0 -118
  373. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  374. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  375. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  376. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  377. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
sglang/srt/lora/backend/chunked_backend.py ADDED
@@ -0,0 +1,348 @@
+ from typing import Optional
+
+ import torch
+
+ from sglang.srt.lora.backend.base_backend import BaseLoRABackend
+ from sglang.srt.lora.triton_ops import (
+     chunked_sgmv_lora_expand_forward,
+     chunked_sgmv_lora_shrink_forward,
+ )
+ from sglang.srt.lora.utils import LoRABatchInfo
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.server_args import ServerArgs
+
+ MIN_CHUNK_SIZE = 16
+
+
+ class ChunkedSgmvLoRABackend(BaseLoRABackend):
+     """
+     Chunked LoRA backend using segmented matrix-vector multiplication.
+
+     This backend is largely based on the SGMV (Segmented Gather Matrix-Vector multiplication) algorithm
+     introduced in the Punica paper (https://arxiv.org/pdf/2310.18547). One main variation made here is to
+     segment the input sequences into fixed-size chunks, which reduces excessive kernel launches especially
+     when the LoRA distribution is skewed.
+     """
+
+     name = "csgmv"
+
+     def __init__(
+         self,
+         max_loras_per_batch: int,
+         device: torch.device,
+         server_args: ServerArgs,
+     ):
+         super().__init__(max_loras_per_batch, device)
+         self.max_chunk_size = server_args.max_lora_chunk_size
+
+     def run_lora_a_sgemm(
+         self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
+     ) -> torch.Tensor:
+         return chunked_sgmv_lora_shrink_forward(
+             x=x,
+             weights=weights,
+             batch_info=self.batch_info,
+             num_slices=1,
+         )
+
+     def run_lora_b_sgemm(
+         self,
+         x: torch.Tensor,
+         weights: torch.Tensor,
+         output_offset: torch.Tensor,
+         base_output: torch.Tensor = None,
+         *args,
+         **kwargs
+     ) -> torch.Tensor:
+         # For simple lora B, we use slice offsets [0, output_dim]
+         output_dim = weights.shape[-2]
+         max_slice_size = output_dim
+         return chunked_sgmv_lora_expand_forward(
+             x=x,
+             weights=weights,
+             batch_info=self.batch_info,
+             slice_offsets=output_offset,
+             max_slice_size=max_slice_size,
+             base_output=base_output,
+         )
+
+     def run_qkv_lora(
+         self,
+         x: torch.Tensor,
+         qkv_lora_a: torch.Tensor,
+         qkv_lora_b: torch.Tensor,
+         output_offset: torch.Tensor,
+         max_qkv_out_dim: int,
+         base_output: torch.Tensor = None,
+         *args,
+         **kwargs
+     ) -> torch.Tensor:
+
+         # x: (s, input_dim)
+         # qkv_lora_a: (num_lora, 3 * r, input_dim)
+         # qkv_lora_b: (num_lora, output_dim_q + 2 * output_dim_kv, r)
+         assert isinstance(qkv_lora_b, torch.Tensor)
+
+         lora_a_output = chunked_sgmv_lora_shrink_forward(
+             x=x,
+             weights=qkv_lora_a,
+             batch_info=self.batch_info,
+             num_slices=3,
+         )
+         lora_output = chunked_sgmv_lora_expand_forward(
+             x=lora_a_output,
+             weights=qkv_lora_b,
+             batch_info=self.batch_info,
+             slice_offsets=output_offset,
+             max_slice_size=max_qkv_out_dim,
+             base_output=base_output,
+         )
+         return lora_output
+
+     def run_gate_up_lora(
+         self,
+         x: torch.Tensor,
+         gate_up_lora_a: torch.Tensor,
+         gate_up_lora_b: torch.Tensor,
+         output_offset: torch.Tensor,
+         base_output: torch.Tensor = None,
+         *args,
+         **kwargs
+     ) -> torch.Tensor:
+
+         # x: (s, input_dim)
+         # gate_up_lora_a: (num_lora, 2 * r, input_dim)
+         # gate_up_lora_b: (num_lora, 2 * output_dim, r)
+         assert isinstance(gate_up_lora_b, torch.Tensor)
+         output_dim = gate_up_lora_b.shape[-2] // 2
+
+         # lora_a_output: (s, 2 * r)
+         lora_a_output = chunked_sgmv_lora_shrink_forward(
+             x=x,
+             weights=gate_up_lora_a,
+             batch_info=self.batch_info,
+             num_slices=2,
+         )
+         lora_output = chunked_sgmv_lora_expand_forward(
+             x=lora_a_output,
+             weights=gate_up_lora_b,
+             batch_info=self.batch_info,
+             slice_offsets=output_offset,
+             max_slice_size=output_dim,
+             base_output=base_output,
+         )
+         return lora_output
+
+     def _determine_chunk_size(self, forward_batch: ForwardBatch) -> int:
+         """
+         Heuristically determine the chunk size based on the number of tokens in a batch.
+
+         Args:
+             forward_batch (ForwardBatch): The batch information containing sequence lengths.
+
+         Returns:
+             The determined chunk size
+         """
+
+         if self.max_chunk_size <= MIN_CHUNK_SIZE:
+             return MIN_CHUNK_SIZE
+
+         num_tokens = (
+             forward_batch.extend_num_tokens
+             if forward_batch.forward_mode.is_extend()
+             else forward_batch.batch_size
+         )
+         if num_tokens >= 256:
+             chunk_size = 128
+         elif num_tokens >= 64:
+             chunk_size = 32
+         else:  # num_tokens < 64
+             chunk_size = 16
+         return min(self.max_chunk_size, chunk_size)
+
+     def prepare_lora_batch(
+         self,
+         forward_batch: ForwardBatch,
+         weight_indices: list[int],
+         lora_ranks: list[int],
+         scalings: list[float],
+         batch_info: Optional[LoRABatchInfo] = None,
+     ):
+         chunk_size = self._determine_chunk_size(forward_batch)
+
+         permutation, weight_indices_reordered = ChunkedSgmvLoRABackend._get_permutation(
+             seq_weight_indices=weight_indices,
+             forward_batch=forward_batch,
+         )
+
+         seg_weight_indices, seg_indptr = self._get_segments_info(
+             weights_reordered=weight_indices_reordered,
+             chunk_size=chunk_size,
+         )
+         num_segments = len(seg_weight_indices)
+
+         lora_ranks_tensor = torch.tensor(
+             lora_ranks, dtype=torch.int32, pin_memory=True, device="cpu"
+         )
+         scalings_tensor = torch.tensor(
+             scalings, dtype=torch.float, pin_memory=True, device="cpu"
+         )
+
+         if batch_info is None:
+             batch_info = LoRABatchInfo(
+                 bs=forward_batch.batch_size,
+                 num_segments=num_segments,
+                 max_len=chunk_size,
+                 use_cuda_graph=False,
+                 seg_indptr=torch.empty(
+                     (num_segments + 1,), dtype=torch.int32, device=self.device
+                 ),
+                 weight_indices=torch.empty(
+                     (num_segments,), dtype=torch.int32, device=self.device
+                 ),
+                 lora_ranks=torch.empty(
+                     (self.max_loras_per_batch,), dtype=torch.int32, device=self.device
+                 ),
+                 scalings=torch.empty(
+                     (self.max_loras_per_batch,), dtype=torch.float, device=self.device
+                 ),
+                 permutation=torch.empty(
+                     (len(permutation),), dtype=torch.int32, device=self.device
+                 ),
+                 # Not used in chunked kernels
+                 seg_lens=None,
+             )
+         else:
+             batch_info.bs = forward_batch.batch_size
+             batch_info.num_segments = num_segments
+             batch_info.max_len = chunk_size
+
+         # Copy to device asynchronously
+         batch_info.lora_ranks[: self.max_loras_per_batch].copy_(
+             lora_ranks_tensor, non_blocking=True
+         )
+         batch_info.scalings[: self.max_loras_per_batch].copy_(
+             scalings_tensor, non_blocking=True
+         )
+         batch_info.weight_indices[:num_segments].copy_(
+             seg_weight_indices, non_blocking=True
+         )
+         batch_info.seg_indptr[: num_segments + 1].copy_(seg_indptr, non_blocking=True)
+         batch_info.permutation[: len(permutation)].copy_(permutation, non_blocking=True)
+
+         self.batch_info = batch_info
+
+     @staticmethod
+     def _get_permutation(seq_weight_indices, forward_batch: ForwardBatch):
+         """
+         Computes permutation indices for reordering tokens by their LoRA adapter assignments.
+
+         This function implements the "gather" step in Chunked Segmented Gather Matrix Vector
+         multiplication by creating a permutation that groups tokens by their LoRA adapter.
+         Tokens using the same LoRA adapter are placed together to enable efficient batched
+         computation.
+
+         Example:
+             seq_weight_indices = [0, 1, 0]  # 3 sequences using adapters [0, 1, 0]
+             extend_seq_lens = [2, 1, 3]  # sequence lengths [2, 1, 3 tokens]
+
+             # Creates row_weight_indices: [0, 0, 1, 0, 0, 0] (6 tokens total)
+             # Returns permutation: [0, 1, 3, 4, 5, 2] (groups adapter 0 tokens together)
+             # weights_reordered: [0, 0, 0, 0, 0, 1] (sorted by adapter)
+
+         Args:
+             seq_weight_indices: List of LoRA adapter indices for each sequence
+             forward_batch (ForwardBatch): Batch information containing sequence lengths
+
+         Returns:
+             tuple: (permutation, weights_reordered) where:
+                 - permutation: Token reordering indices to group by adapter
+                 - weights_reordered: Sorted adapter indices for each token
+         """
+         with torch.device("cpu"):
+             seq_weight_indices = torch.tensor(seq_weight_indices, dtype=torch.int32)
+
+             seg_lens_cpu = (
+                 torch.tensor(
+                     forward_batch.extend_seq_lens_cpu,
+                     dtype=torch.int32,
+                 )
+                 if forward_batch.forward_mode.is_extend()
+                 else torch.ones(forward_batch.batch_size, dtype=torch.int32)
+             )
+
+             row_weight_indices = torch.repeat_interleave(
+                 seq_weight_indices, seg_lens_cpu
+             )
+             permutation = torch.empty(
+                 (len(row_weight_indices),), dtype=torch.long, pin_memory=True
+             )
+             torch.argsort(row_weight_indices, stable=True, out=permutation)
+             weights_reordered = row_weight_indices[permutation]
+
+             return permutation, weights_reordered
+
+     def _get_segments_info(self, weights_reordered: torch.Tensor, chunk_size: int):
+         """
+         Computes segment information for chunked SGMV operations.
+
+         This function takes the reordered weight indices and creates segments of fixed size
+         (chunk_size) for efficient kernel execution. Each segment contains tokens
+         that use the same LoRA adapter, enabling vectorized computation.
+
+         The segmentation is necessary because:
+         1. GPU kernels work efficiently on fixed-size blocks
+         2. Large groups of tokens using the same adapter are split into manageable chunks
+         3. Each segment can be processed independently in parallel
+
+         Example:
+             weights_reordered = [0, 0, 0, 0, 0, 1]  # 5 tokens with adapter 0, 1 with adapter 1
+             chunk_size = 3
+
+             # Creates segments:
+             # Segment 0: tokens 0-2 (adapter 0), length=3
+             # Segment 1: tokens 3-4 (adapter 0), length=2
+             # Segment 2: token 5 (adapter 1), length=1
+
+             # Returns:
+             # weight_indices_list: [0, 0, 1] (adapter for each segment)
+             # seg_indptr: [0, 3, 5, 6] (cumulative segment boundaries)
+
+         Args:
+             weights_reordered (torch.Tensor): Sorted adapter indices for each token
+             chunk_size (int): Fixed size for each segment
+
+         Returns:
+             tuple: (weight_indices_list, seg_indptr) where:
+                 - weight_indices_list: LoRA adapter index for each segment
+                 - seg_indptr: Cumulative segment boundaries (CSR-style indptr)
+         """
+         with torch.device("cpu"):
+             unique_weights, counts = torch.unique_consecutive(
+                 weights_reordered, return_counts=True
+             )
+
+             weight_indices_list = []
+             seg_lens_list = []
+
+             for weight_idx, group_len in zip(unique_weights, counts):
+                 group_len = group_len.item()
+                 num_segs = (group_len + chunk_size - 1) // chunk_size
+
+                 weight_indices_list.extend([weight_idx.item()] * num_segs)
+                 seg_lens_list.extend([chunk_size] * (num_segs - 1))
+                 seg_lens_list.append(group_len - (num_segs - 1) * chunk_size)
+
+             seg_lens = torch.tensor(seg_lens_list, dtype=torch.int32)
+
+             weight_indices_list = torch.tensor(
+                 weight_indices_list, dtype=torch.int32, pin_memory=True
+             )
+
+             seg_indptr = torch.empty(
+                 (len(seg_lens) + 1,), dtype=torch.int32, pin_memory=True
+             )
+             seg_indptr[0] = 0
+             seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
+
+             return weight_indices_list, seg_indptr
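The docstrings above describe the gather-and-segment scheme in prose. The following minimal sketch (written for this review, not part of the package) reproduces the docstring example with plain PyTorch ops, mirroring the logic of _get_permutation and _get_segments_info; the chunk_size of 3 is chosen only for illustration.

import torch

seq_weight_indices = torch.tensor([0, 1, 0])  # adapter index per sequence
seg_lens = torch.tensor([2, 1, 3])            # tokens per sequence

# "Gather" step: one adapter index per token, then a stable sort groups tokens by adapter.
row_weight_indices = torch.repeat_interleave(seq_weight_indices, seg_lens)
permutation = torch.argsort(row_weight_indices, stable=True)
weights_reordered = row_weight_indices[permutation]
print(permutation.tolist())        # [0, 1, 3, 4, 5, 2]
print(weights_reordered.tolist())  # [0, 0, 0, 0, 0, 1]

# Segmentation step: split each same-adapter run into chunks of at most chunk_size tokens.
chunk_size = 3
seg_adapters, seg_lens_out = [], []
for adapter, count in zip(*torch.unique_consecutive(weights_reordered, return_counts=True)):
    count = count.item()
    num_segs = (count + chunk_size - 1) // chunk_size
    seg_adapters.extend([adapter.item()] * num_segs)
    seg_lens_out.extend([chunk_size] * (num_segs - 1) + [count - (num_segs - 1) * chunk_size])

# CSR-style boundaries: segment i covers rows seg_indptr[i]:seg_indptr[i+1] of the permuted input.
seg_indptr = [0]
for n in seg_lens_out:
    seg_indptr.append(seg_indptr[-1] + n)
print(seg_adapters)  # [0, 0, 1]
print(seg_indptr)    # [0, 3, 5, 6]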
sglang/srt/lora/backend/triton_backend.py CHANGED
@@ -1,3 +1,5 @@
+ from typing import Optional
+
  import torch

  from sglang.srt.lora.backend.base_backend import BaseLoRABackend
@@ -8,12 +10,20 @@ from sglang.srt.lora.triton_ops import (
      sgemm_lora_b_fwd,
  )
  from sglang.srt.lora.utils import LoRABatchInfo
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.server_args import ServerArgs


  class TritonLoRABackend(BaseLoRABackend):
+     name = "triton"

-     def __init__(self, name: str, batch_info: LoRABatchInfo = None):
-         super().__init__(name, batch_info)
+     def __init__(
+         self,
+         max_loras_per_batch: int,
+         device: torch.device,
+         **kwargs,
+     ):
+         super().__init__(max_loras_per_batch, device)

      def run_lora_a_sgemm(
          self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
@@ -26,7 +36,7 @@ class TritonLoRABackend(BaseLoRABackend):
          weights: torch.Tensor,
          base_output: torch.Tensor = None,
          *args,
-         **kwargs
+         **kwargs,
      ) -> torch.Tensor:
          return sgemm_lora_b_fwd(x, weights, self.batch_info, base_output)

@@ -39,7 +49,7 @@ class TritonLoRABackend(BaseLoRABackend):
          max_qkv_out_dim: int,
          base_output: torch.Tensor = None,
          *args,
-         **kwargs
+         **kwargs,
      ) -> torch.Tensor:

          # x: (s, input_dim)
@@ -65,7 +75,7 @@ class TritonLoRABackend(BaseLoRABackend):
          gate_up_lora_b: torch.Tensor,
          base_output: torch.Tensor = None,
          *args,
-         **kwargs
+         **kwargs,
      ) -> torch.Tensor:

          # x: (s, input_dim)
@@ -86,3 +96,87 @@ class TritonLoRABackend(BaseLoRABackend):
              base_output,
          )
          return lora_output
+
+     def init_cuda_graph_batch_info(
+         self, cuda_graph_batch_info: LoRABatchInfo, max_bs_in_cuda_graph: int
+     ):
+         # Initialize seg_lens and seg_indptr for CUDA graph as they remain constant
+         # across batches.
+         cuda_graph_batch_info.seg_lens[:max_bs_in_cuda_graph].fill_(1)
+         torch.cumsum(
+             cuda_graph_batch_info.seg_lens[:max_bs_in_cuda_graph],
+             dim=0,
+             out=cuda_graph_batch_info.seg_indptr[1 : max_bs_in_cuda_graph + 1],
+         )
+
+     def prepare_lora_batch(
+         self,
+         forward_batch: ForwardBatch,
+         weight_indices: list[int],
+         lora_ranks: list[int],
+         scalings: list[float],
+         batch_info: Optional[LoRABatchInfo] = None,
+     ):
+         # Use pinned memory to avoid synchronizations during host-to-device transfer
+         weight_indices_tensor = torch.tensor(
+             weight_indices, dtype=torch.int32, pin_memory=True, device="cpu"
+         )
+         lora_ranks_tensor = torch.tensor(
+             lora_ranks, dtype=torch.int32, pin_memory=True, device="cpu"
+         )
+         scalings_tensor = torch.tensor(
+             scalings, dtype=torch.float, pin_memory=True, device="cpu"
+         )
+
+         bs = forward_batch.batch_size
+
+         if batch_info is not None:
+             assert (
+                 batch_info.use_cuda_graph
+             ), "batch_info.use_cuda_graph must be True when batch_info is provided"
+             batch_info.bs = forward_batch.batch_size
+             batch_info.num_segments = forward_batch.batch_size
+         else:
+             max_len = (
+                 # Calculate max_len from the CPU copy to avoid D2H transfer.
+                 max(forward_batch.extend_seq_lens_cpu)
+                 if forward_batch.forward_mode.is_extend()
+                 else 1
+             )
+             seg_lens = (
+                 forward_batch.extend_seq_lens
+                 if forward_batch.forward_mode.is_extend()
+                 else torch.ones(bs, device=self.device)
+             )
+             seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
+             seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
+
+             batch_info = LoRABatchInfo(
+                 bs=forward_batch.batch_size,
+                 num_segments=forward_batch.batch_size,
+                 max_len=max_len,
+                 use_cuda_graph=False,
+                 seg_lens=seg_lens,
+                 seg_indptr=seg_indptr,
+                 weight_indices=torch.empty(
+                     (bs,), dtype=torch.int32, device=self.device
+                 ),
+                 lora_ranks=torch.empty(
+                     (self.max_loras_per_batch,), dtype=torch.int64, device=self.device
+                 ),
+                 scalings=torch.empty(
+                     (self.max_loras_per_batch,), dtype=torch.float, device=self.device
+                 ),
+                 permutation=None,
+             )
+
+         # Copy to device asynchronously
+         batch_info.lora_ranks[: self.max_loras_per_batch].copy_(
+             lora_ranks_tensor, non_blocking=True
+         )
+         batch_info.scalings[: self.max_loras_per_batch].copy_(
+             scalings_tensor, non_blocking=True
+         )
+         batch_info.weight_indices[:bs].copy_(weight_indices_tensor, non_blocking=True)
+
+         self.batch_info = batch_info
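For reference, prepare_lora_batch above encodes per-sequence segment boundaries as a CSR-style indptr array (one segment per sequence in this backend). A small standalone sketch, using hypothetical sequence lengths rather than values from the package, of that layout:

import torch

extend_seq_lens = torch.tensor([4, 2, 3], dtype=torch.int32)  # hypothetical batch of 3 sequences
bs = extend_seq_lens.numel()

seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32)
seg_indptr[1:] = torch.cumsum(extend_seq_lens, dim=0)
max_len = int(extend_seq_lens.max())

# Rows seg_indptr[i]:seg_indptr[i+1] of the token matrix belong to sequence i.
print(seg_indptr.tolist())  # [0, 4, 6, 9]
print(max_len)              # 4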
sglang/srt/lora/layers.py CHANGED
@@ -66,6 +66,15 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
          lora_backend: BaseLoRABackend,
      ) -> None:
          super().__init__(base_layer, lora_backend)
+         shard_size = self.base_layer.output_partition_sizes[0]
+         self.output_offset = torch.tensor(
+             [
+                 0,
+                 shard_size,
+             ],
+             dtype=torch.int32,
+             device=next(self.base_layer.parameters()).device,
+         )

      def set_lora_info(
          self,
@@ -81,6 +90,7 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
          lora_output = self.lora_backend.run_lora_b_sgemm(
              x=lora_a_output,
              weights=self.B_buffer,
+             output_offset=self.output_offset,
              base_output=base_output,
          )
          return lora_output
@@ -130,11 +140,23 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
          self.A_buffer_gate_up = A_buffer
          self.B_buffer_gate_up = B_buffer

+         shard_size = self.base_layer.output_partition_sizes[0]
+         self.output_offset = torch.tensor(
+             [
+                 0,
+                 shard_size,
+                 2 * shard_size,
+             ],
+             dtype=torch.int32,
+             device=next(self.base_layer.parameters()).device,
+         )
+
      def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
          lora_output = self.lora_backend.run_gate_up_lora(
              x=x,
              gate_up_lora_a=self.A_buffer_gate_up,
              gate_up_lora_b=self.B_buffer_gate_up,
+             output_offset=self.output_offset,
              base_output=base_output,
          )
          return lora_output
@@ -243,12 +265,22 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
          self.set_lora = True
          self.A_buffer = A_buffer
          self.B_buffer = B_buffer
+         output_size = self.base_layer.output_size
+         self.output_offset = torch.tensor(
+             [
+                 0,
+                 output_size,
+             ],
+             dtype=torch.int32,
+             device=next(self.base_layer.parameters()).device,
+         )

      def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
          lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
          lora_output = self.lora_backend.run_lora_b_sgemm(
              x=lora_a_output,
              weights=self.B_buffer,
+             output_offset=self.output_offset,
              base_output=base_output,
          )
          return lora_output
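The output_offset tensors registered above ([0, shard_size] for a single output slice, [0, shard_size, 2 * shard_size] for fused gate/up) mark where each fused output slice begins and ends. A toy sketch, with hypothetical sizes and eager tensor ops standing in for the expand kernel, of how such offsets partition the fused output:

import torch

shard_size = 8  # hypothetical per-partition output size of gate_proj / up_proj
gate_up_offsets = torch.tensor([0, shard_size, 2 * shard_size], dtype=torch.int32)

tokens = 2
fused_out = torch.zeros(tokens, 2 * shard_size)  # fused gate|up output columns
lora_slices = [torch.ones(tokens, shard_size), 2 * torch.ones(tokens, shard_size)]

# Eager equivalent of the per-slice write: each LoRA-B slice lands in the column
# range delimited by two consecutive offsets.
for i in range(len(gate_up_offsets) - 1):
    start, end = int(gate_up_offsets[i]), int(gate_up_offsets[i + 1])
    fused_out[:, start:end] += lora_slices[i]

print(fused_out[0, :shard_size].unique().tolist())  # [1.0] -> gate slice
print(fused_out[0, shard_size:].unique().tolist())  # [2.0] -> up slice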
sglang/srt/lora/lora.py CHANGED
@@ -26,13 +26,17 @@ import torch
  from torch import nn

  from sglang.srt.configs.load_config import LoadConfig
- from sglang.srt.hf_transformers_utils import AutoConfig
  from sglang.srt.lora.backend.base_backend import BaseLoRABackend
+ from sglang.srt.lora.backend.chunked_backend import ChunkedSgmvLoRABackend
+ from sglang.srt.lora.backend.triton_backend import TritonLoRABackend
  from sglang.srt.lora.lora_config import LoRAConfig
  from sglang.srt.model_loader.loader import DefaultModelLoader
+ from sglang.srt.utils.hf_transformers_utils import AutoConfig

  logger = logging.getLogger(__name__)

+ SUPPORTED_BACKENDS = (TritonLoRABackend, ChunkedSgmvLoRABackend)
+

  class LoRALayer(nn.Module):
      def __init__(self, config: LoRAConfig, base_hf_config: AutoConfig):
@@ -45,6 +49,7 @@ class LoRALayer(nn.Module):


  class LoRAAdapter(nn.Module):
+
      def __init__(
          self,
          uid: str,
@@ -156,8 +161,8 @@ class LoRAAdapter(nn.Module):
                  gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
                  if up_name not in weights:
                      weights[up_name] = torch.zeros_like(weights[weight_name])
-                     assert self.lora_backend.name == "triton", (
-                         f"LoRA weight initialization currently only supported for 'triton' backend. "
+                     assert isinstance(self.lora_backend, SUPPORTED_BACKENDS), (
+                         f"LoRA weight initialization currently only supported for LoRA backends: {', '.join(b.name for b in SUPPORTED_BACKENDS)}"
                          f"Received backend: {self.lora_backend.name}. Please verify your backend configuration "
                          f"or consider implementing custom initialization logic for other backends."
                      )