sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (282)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +321 -31
  3. sglang/bench_serving.py +10 -3
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +4 -0
  11. sglang/srt/configs/dots_ocr.py +64 -0
  12. sglang/srt/configs/falcon_h1.py +360 -0
  13. sglang/srt/configs/load_config.py +8 -0
  14. sglang/srt/configs/model_config.py +160 -105
  15. sglang/srt/configs/qwen3_vl.py +586 -0
  16. sglang/srt/constrained/base_grammar_backend.py +1 -0
  17. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  18. sglang/srt/constrained/xgrammar_backend.py +6 -4
  19. sglang/srt/debug_utils/dumper.py +10 -3
  20. sglang/srt/disaggregation/ascend/conn.py +2 -2
  21. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  22. sglang/srt/disaggregation/common/conn.py +266 -98
  23. sglang/srt/disaggregation/decode.py +50 -9
  24. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  25. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  26. sglang/srt/disaggregation/mooncake/conn.py +51 -541
  27. sglang/srt/disaggregation/nixl/conn.py +148 -39
  28. sglang/srt/disaggregation/prefill.py +31 -14
  29. sglang/srt/disaggregation/utils.py +36 -5
  30. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  31. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  32. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  33. sglang/srt/distributed/parallel_state.py +135 -80
  34. sglang/srt/entrypoints/engine.py +23 -3
  35. sglang/srt/entrypoints/grpc_request_manager.py +330 -55
  36. sglang/srt/entrypoints/grpc_server.py +232 -102
  37. sglang/srt/entrypoints/http_server.py +49 -9
  38. sglang/srt/entrypoints/openai/protocol.py +110 -5
  39. sglang/srt/entrypoints/openai/serving_base.py +25 -6
  40. sglang/srt/entrypoints/openai/serving_chat.py +178 -49
  41. sglang/srt/entrypoints/openai/serving_completions.py +5 -3
  42. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  43. sglang/srt/entrypoints/openai/serving_responses.py +42 -0
  44. sglang/srt/environ.py +285 -0
  45. sglang/srt/eplb/expert_location.py +30 -5
  46. sglang/srt/function_call/function_call_parser.py +3 -2
  47. sglang/srt/function_call/glm4_moe_detector.py +3 -3
  48. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  49. sglang/srt/function_call/json_array_parser.py +63 -0
  50. sglang/srt/function_call/kimik2_detector.py +17 -4
  51. sglang/srt/function_call/utils.py +96 -5
  52. sglang/srt/grpc/compile_proto.py +245 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
  56. sglang/srt/layers/activation.py +7 -6
  57. sglang/srt/layers/attention/aiter_backend.py +14 -15
  58. sglang/srt/layers/attention/ascend_backend.py +108 -9
  59. sglang/srt/layers/attention/attention_registry.py +206 -0
  60. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  61. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  62. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  63. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  66. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  67. sglang/srt/layers/attention/flashinfer_backend.py +112 -194
  68. sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
  69. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  70. sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
  71. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
  72. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
  73. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
  74. sglang/srt/layers/attention/mamba/mamba.py +566 -1
  75. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  76. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  77. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  78. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  79. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  80. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  81. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  82. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  83. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  84. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  85. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  86. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  87. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  88. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  89. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  90. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  91. sglang/srt/layers/attention/nsa/utils.py +24 -0
  92. sglang/srt/layers/attention/nsa_backend.py +887 -0
  93. sglang/srt/layers/attention/tbo_backend.py +6 -6
  94. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  95. sglang/srt/layers/attention/triton_backend.py +42 -9
  96. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  97. sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
  98. sglang/srt/layers/attention/vision.py +58 -0
  99. sglang/srt/layers/attention/wave_backend.py +4 -4
  100. sglang/srt/layers/communicator.py +8 -0
  101. sglang/srt/layers/dp_attention.py +11 -1
  102. sglang/srt/layers/elementwise.py +3 -1
  103. sglang/srt/layers/layernorm.py +2 -0
  104. sglang/srt/layers/linear.py +21 -4
  105. sglang/srt/layers/logits_processor.py +15 -2
  106. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  107. sglang/srt/layers/moe/ep_moe/layer.py +147 -74
  108. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  109. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  110. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  111. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  112. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
  113. sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
  114. sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
  115. sglang/srt/layers/moe/utils.py +10 -0
  116. sglang/srt/layers/parameter.py +23 -6
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  119. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  121. sglang/srt/layers/quantization/fp8.py +2 -2
  122. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  123. sglang/srt/layers/quantization/modelopt_quant.py +44 -9
  124. sglang/srt/layers/quantization/mxfp4.py +12 -4
  125. sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
  126. sglang/srt/layers/quantization/w4afp8.py +0 -4
  127. sglang/srt/layers/quantization/w8a8_int8.py +15 -3
  128. sglang/srt/layers/rotary_embedding.py +78 -31
  129. sglang/srt/layers/sampler.py +52 -4
  130. sglang/srt/layers/utils.py +23 -0
  131. sglang/srt/lora/backend/base_backend.py +3 -3
  132. sglang/srt/lora/backend/chunked_backend.py +348 -0
  133. sglang/srt/lora/backend/triton_backend.py +10 -4
  134. sglang/srt/lora/lora.py +7 -5
  135. sglang/srt/lora/lora_manager.py +17 -6
  136. sglang/srt/lora/mem_pool.py +1 -1
  137. sglang/srt/lora/triton_ops/__init__.py +4 -0
  138. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  139. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  140. sglang/srt/lora/utils.py +7 -5
  141. sglang/srt/managers/cache_controller.py +42 -142
  142. sglang/srt/managers/data_parallel_controller.py +11 -46
  143. sglang/srt/managers/detokenizer_manager.py +11 -11
  144. sglang/srt/managers/io_struct.py +162 -118
  145. sglang/srt/managers/mm_utils.py +43 -6
  146. sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
  147. sglang/srt/managers/multimodal_processor.py +1 -2
  148. sglang/srt/managers/overlap_utils.py +53 -0
  149. sglang/srt/managers/schedule_batch.py +167 -86
  150. sglang/srt/managers/schedule_policy.py +143 -16
  151. sglang/srt/managers/scheduler.py +359 -214
  152. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  153. sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
  154. sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
  155. sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
  156. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  157. sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
  158. sglang/srt/managers/tokenizer_manager.py +84 -136
  159. sglang/srt/managers/tp_worker.py +39 -29
  160. sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
  161. sglang/srt/managers/utils.py +1 -45
  162. sglang/srt/mem_cache/allocator.py +14 -20
  163. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  164. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  165. sglang/srt/mem_cache/chunk_cache.py +8 -1
  166. sglang/srt/mem_cache/evict_policy.py +23 -0
  167. sglang/srt/mem_cache/hicache_storage.py +40 -1
  168. sglang/srt/mem_cache/hiradix_cache.py +119 -32
  169. sglang/srt/mem_cache/memory_pool.py +188 -10
  170. sglang/srt/mem_cache/memory_pool_host.py +134 -182
  171. sglang/srt/mem_cache/radix_cache.py +222 -71
  172. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  173. sglang/srt/mem_cache/storage/__init__.py +10 -0
  174. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  175. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  176. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  177. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  178. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  179. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
  180. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
  181. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
  182. sglang/srt/mem_cache/swa_radix_cache.py +25 -34
  183. sglang/srt/metrics/collector.py +82 -120
  184. sglang/srt/metrics/func_timer.py +2 -7
  185. sglang/srt/metrics/utils.py +8 -1
  186. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  187. sglang/srt/model_executor/cuda_graph_runner.py +39 -32
  188. sglang/srt/model_executor/forward_batch_info.py +23 -38
  189. sglang/srt/model_executor/model_runner.py +131 -183
  190. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  191. sglang/srt/model_loader/loader.py +14 -10
  192. sglang/srt/model_loader/weight_utils.py +156 -2
  193. sglang/srt/models/bailing_moe.py +27 -4
  194. sglang/srt/models/deepseek_nextn.py +6 -1
  195. sglang/srt/models/deepseek_v2.py +536 -153
  196. sglang/srt/models/dots_ocr.py +173 -0
  197. sglang/srt/models/falcon_h1.py +576 -0
  198. sglang/srt/models/gemma3_causal.py +0 -2
  199. sglang/srt/models/gemma3_mm.py +1 -1
  200. sglang/srt/models/gemma3n_mm.py +1 -1
  201. sglang/srt/models/glm4_moe.py +3 -3
  202. sglang/srt/models/glm4_moe_nextn.py +2 -2
  203. sglang/srt/models/glm4v.py +1 -1
  204. sglang/srt/models/glm4v_moe.py +1 -1
  205. sglang/srt/models/gpt_oss.py +7 -30
  206. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  207. sglang/srt/models/llama.py +4 -0
  208. sglang/srt/models/longcat_flash.py +1 -1
  209. sglang/srt/models/longcat_flash_nextn.py +1 -1
  210. sglang/srt/models/mllama4.py +15 -4
  211. sglang/srt/models/qwen2.py +0 -7
  212. sglang/srt/models/qwen2_5_vl.py +2 -2
  213. sglang/srt/models/qwen2_audio.py +1 -1
  214. sglang/srt/models/qwen2_moe.py +64 -1
  215. sglang/srt/models/qwen2_vl.py +1 -1
  216. sglang/srt/models/qwen3.py +18 -3
  217. sglang/srt/models/qwen3_moe.py +31 -3
  218. sglang/srt/models/qwen3_next.py +36 -9
  219. sglang/srt/models/qwen3_vl.py +787 -0
  220. sglang/srt/models/qwen3_vl_moe.py +471 -0
  221. sglang/srt/models/registry.py +15 -3
  222. sglang/srt/models/sarashina2_vision.py +269 -0
  223. sglang/srt/models/solar.py +505 -0
  224. sglang/srt/models/starcoder2.py +357 -0
  225. sglang/srt/models/torch_native_llama.py +9 -2
  226. sglang/srt/models/utils.py +51 -0
  227. sglang/srt/multimodal/processors/base_processor.py +15 -7
  228. sglang/srt/multimodal/processors/dots_vlm.py +2 -3
  229. sglang/srt/multimodal/processors/internvl.py +20 -8
  230. sglang/srt/multimodal/processors/qwen_vl.py +8 -1
  231. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  232. sglang/srt/parser/jinja_template_utils.py +6 -0
  233. sglang/srt/sampling/sampling_batch_info.py +20 -2
  234. sglang/srt/sampling/sampling_params.py +7 -0
  235. sglang/srt/server_args.py +753 -295
  236. sglang/srt/server_args_config_parser.py +146 -0
  237. sglang/srt/single_batch_overlap.py +151 -0
  238. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  239. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  240. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  241. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  242. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  243. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  244. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
  245. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
  246. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
  247. sglang/srt/speculative/eagle_worker.py +57 -25
  248. sglang/srt/speculative/ngram_utils.py +428 -0
  249. sglang/srt/speculative/ngram_worker.py +245 -0
  250. sglang/srt/speculative/spec_info.py +47 -0
  251. sglang/srt/speculative/spec_utils.py +606 -0
  252. sglang/srt/torch_memory_saver_adapter.py +5 -7
  253. sglang/srt/tracing/trace.py +32 -6
  254. sglang/srt/two_batch_overlap.py +8 -5
  255. sglang/srt/utils/__init__.py +2 -0
  256. sglang/srt/{utils.py → utils/common.py} +399 -74
  257. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
  258. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  259. sglang/srt/utils/rpd_utils.py +452 -0
  260. sglang/srt/utils/slow_rank_detector.py +71 -0
  261. sglang/srt/warmup.py +8 -4
  262. sglang/srt/weight_sync/utils.py +1 -1
  263. sglang/test/get_logits_ut.py +57 -0
  264. sglang/test/run_eval.py +79 -11
  265. sglang/test/runners.py +1 -1
  266. sglang/test/simple_eval_common.py +5 -2
  267. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  268. sglang/test/test_block_fp8.py +2 -2
  269. sglang/test/test_deterministic.py +297 -0
  270. sglang/test/test_disaggregation_utils.py +12 -1
  271. sglang/test/test_programs.py +1 -1
  272. sglang/test/test_utils.py +355 -4
  273. sglang/utils.py +10 -1
  274. sglang/version.py +1 -1
  275. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
  276. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
  277. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  278. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  279. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  280. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  281. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  282. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
sglang/srt/lora/backend/chunked_backend.py ADDED
@@ -0,0 +1,348 @@
+ from typing import Optional
+
+ import torch
+
+ from sglang.srt.lora.backend.base_backend import BaseLoRABackend
+ from sglang.srt.lora.triton_ops import (
+     chunked_sgmv_lora_expand_forward,
+     chunked_sgmv_lora_shrink_forward,
+ )
+ from sglang.srt.lora.utils import LoRABatchInfo
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.server_args import ServerArgs
+
+ MIN_CHUNK_SIZE = 16
+
+
+ class ChunkedSgmvLoRABackend(BaseLoRABackend):
+     """
+     Chunked LoRA backend using segmented matrix-vector multiplication.
+
+     This backend is largely based on the SGMV (Segmented Gather Matrix-Vector multiplication) algorithm
+     introduced in the Punica paper (https://arxiv.org/pdf/2310.18547). One main variation made here is to
+     segment the input sequences into fixed-size chunks, which reduces excessive kernel launches especially
+     when the LoRA distribution is skewed.
+     """
+
+     name = "csgmv"
+
+     def __init__(
+         self,
+         max_loras_per_batch: int,
+         device: torch.device,
+         server_args: ServerArgs,
+     ):
+         super().__init__(max_loras_per_batch, device)
+         self.max_chunk_size = server_args.max_lora_chunk_size
+
+     def run_lora_a_sgemm(
+         self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
+     ) -> torch.Tensor:
+         return chunked_sgmv_lora_shrink_forward(
+             x=x,
+             weights=weights,
+             batch_info=self.batch_info,
+             num_slices=1,
+         )
+
+     def run_lora_b_sgemm(
+         self,
+         x: torch.Tensor,
+         weights: torch.Tensor,
+         output_offset: torch.Tensor,
+         base_output: torch.Tensor = None,
+         *args,
+         **kwargs
+     ) -> torch.Tensor:
+         # For simple lora B, we use slice offsets [0, output_dim]
+         output_dim = weights.shape[-2]
+         max_slice_size = output_dim
+         return chunked_sgmv_lora_expand_forward(
+             x=x,
+             weights=weights,
+             batch_info=self.batch_info,
+             slice_offsets=output_offset,
+             max_slice_size=max_slice_size,
+             base_output=base_output,
+         )
+
+     def run_qkv_lora(
+         self,
+         x: torch.Tensor,
+         qkv_lora_a: torch.Tensor,
+         qkv_lora_b: torch.Tensor,
+         output_offset: torch.Tensor,
+         max_qkv_out_dim: int,
+         base_output: torch.Tensor = None,
+         *args,
+         **kwargs
+     ) -> torch.Tensor:
+
+         # x: (s, input_dim)
+         # qkv_lora_a: (num_lora, 3 * r, input_dim)
+         # qkv_lora_b: (num_lora, output_dim_q + 2 * output_dim_kv, r)
+         assert isinstance(qkv_lora_b, torch.Tensor)
+
+         lora_a_output = chunked_sgmv_lora_shrink_forward(
+             x=x,
+             weights=qkv_lora_a,
+             batch_info=self.batch_info,
+             num_slices=3,
+         )
+         lora_output = chunked_sgmv_lora_expand_forward(
+             x=lora_a_output,
+             weights=qkv_lora_b,
+             batch_info=self.batch_info,
+             slice_offsets=output_offset,
+             max_slice_size=max_qkv_out_dim,
+             base_output=base_output,
+         )
+         return lora_output
+
+     def run_gate_up_lora(
+         self,
+         x: torch.Tensor,
+         gate_up_lora_a: torch.Tensor,
+         gate_up_lora_b: torch.Tensor,
+         output_offset: torch.Tensor,
+         base_output: torch.Tensor = None,
+         *args,
+         **kwargs
+     ) -> torch.Tensor:
+
+         # x: (s, input_dim)
+         # gate_up_lora_a: (num_lora, 2 * r, input_dim)
+         # gate_up_lora_b: (num_lora, 2 * output_dim, r)
+         assert isinstance(gate_up_lora_b, torch.Tensor)
+         output_dim = gate_up_lora_b.shape[-2] // 2
+
+         # lora_a_output: (s, 2 * r)
+         lora_a_output = chunked_sgmv_lora_shrink_forward(
+             x=x,
+             weights=gate_up_lora_a,
+             batch_info=self.batch_info,
+             num_slices=2,
+         )
+         lora_output = chunked_sgmv_lora_expand_forward(
+             x=lora_a_output,
+             weights=gate_up_lora_b,
+             batch_info=self.batch_info,
+             slice_offsets=output_offset,
+             max_slice_size=output_dim,
+             base_output=base_output,
+         )
+         return lora_output
+
+     def _determine_chunk_size(self, forward_batch: ForwardBatch) -> int:
+         """
+         Heuristically determine the chunk size based on the number of tokens in the batch.
+
+         Args:
+             forward_batch (ForwardBatch): The batch information containing sequence lengths.
+
+         Returns:
+             The determined chunk size
+         """
+
+         if self.max_chunk_size <= MIN_CHUNK_SIZE:
+             return MIN_CHUNK_SIZE
+
+         num_tokens = (
+             forward_batch.extend_num_tokens
+             if forward_batch.forward_mode.is_extend()
+             else forward_batch.batch_size
+         )
+         if num_tokens >= 256:
+             chunk_size = 128
+         elif num_tokens >= 64:
+             chunk_size = 32
+         else:  # num_tokens < 64
+             chunk_size = 16
+         return min(self.max_chunk_size, chunk_size)
+
+     def prepare_lora_batch(
+         self,
+         forward_batch: ForwardBatch,
+         weight_indices: list[int],
+         lora_ranks: list[int],
+         scalings: list[float],
+         batch_info: Optional[LoRABatchInfo] = None,
+     ):
+         chunk_size = self._determine_chunk_size(forward_batch)
+
+         permutation, weight_indices_reordered = ChunkedSgmvLoRABackend._get_permutation(
+             seq_weight_indices=weight_indices,
+             forward_batch=forward_batch,
+         )
+
+         seg_weight_indices, seg_indptr = self._get_segments_info(
+             weights_reordered=weight_indices_reordered,
+             chunk_size=chunk_size,
+         )
+         num_segments = len(seg_weight_indices)
+
+         lora_ranks_tensor = torch.tensor(
+             lora_ranks, dtype=torch.int32, pin_memory=True, device="cpu"
+         )
+         scalings_tensor = torch.tensor(
+             scalings, dtype=torch.float, pin_memory=True, device="cpu"
+         )
+
+         if batch_info is None:
+             batch_info = LoRABatchInfo(
+                 bs=forward_batch.batch_size,
+                 num_segments=num_segments,
+                 max_len=chunk_size,
+                 use_cuda_graph=False,
+                 seg_indptr=torch.empty(
+                     (num_segments + 1,), dtype=torch.int32, device=self.device
+                 ),
+                 weight_indices=torch.empty(
+                     (num_segments,), dtype=torch.int32, device=self.device
+                 ),
+                 lora_ranks=torch.empty(
+                     (self.max_loras_per_batch,), dtype=torch.int32, device=self.device
+                 ),
+                 scalings=torch.empty(
+                     (self.max_loras_per_batch,), dtype=torch.float, device=self.device
+                 ),
+                 permutation=torch.empty(
+                     (len(permutation),), dtype=torch.int32, device=self.device
+                 ),
+                 # Not used in chunked kernels
+                 seg_lens=None,
+             )
+         else:
+             batch_info.bs = forward_batch.batch_size
+             batch_info.num_segments = num_segments
+             batch_info.max_len = chunk_size
+
+         # Copy to device asynchronously
+         batch_info.lora_ranks[: self.max_loras_per_batch].copy_(
+             lora_ranks_tensor, non_blocking=True
+         )
+         batch_info.scalings[: self.max_loras_per_batch].copy_(
+             scalings_tensor, non_blocking=True
+         )
+         batch_info.weight_indices[:num_segments].copy_(
+             seg_weight_indices, non_blocking=True
+         )
+         batch_info.seg_indptr[: num_segments + 1].copy_(seg_indptr, non_blocking=True)
+         batch_info.permutation[: len(permutation)].copy_(permutation, non_blocking=True)
+
+         self.batch_info = batch_info
+
+     @staticmethod
+     def _get_permutation(seq_weight_indices, forward_batch: ForwardBatch):
+         """
+         Computes permutation indices for reordering tokens by their LoRA adapter assignments.
+
+         This function implements the "gather" step in Chunked Segmented Gather Matrix Vector
+         multiplication by creating a permutation that groups tokens by their LoRA adapter.
+         Tokens using the same LoRA adapter are placed together to enable efficient batched
+         computation.
+
+         Example:
+             seq_weight_indices = [0, 1, 0]  # 3 sequences using adapters [0, 1, 0]
+             extend_seq_lens = [2, 1, 3]     # sequence lengths [2, 1, 3 tokens]
+
+             # Creates row_weight_indices: [0, 0, 1, 0, 0, 0] (6 tokens total)
+             # Returns permutation: [0, 1, 3, 4, 5, 2] (groups adapter 0 tokens together)
+             # weights_reordered: [0, 0, 0, 0, 0, 1] (sorted by adapter)
+
+         Args:
+             seq_weight_indices: List of LoRA adapter indices for each sequence
+             forward_batch (ForwardBatch): Batch information containing sequence lengths
+
+         Returns:
+             tuple: (permutation, weights_reordered) where:
+                 - permutation: Token reordering indices to group by adapter
+                 - weights_reordered: Sorted adapter indices for each token
+         """
+         with torch.device("cpu"):
+             seq_weight_indices = torch.tensor(seq_weight_indices, dtype=torch.int32)
+
+             seg_lens_cpu = (
+                 torch.tensor(
+                     forward_batch.extend_seq_lens_cpu,
+                     dtype=torch.int32,
+                 )
+                 if forward_batch.forward_mode.is_extend()
+                 else torch.ones(forward_batch.batch_size, dtype=torch.int32)
+             )
+
+             row_weight_indices = torch.repeat_interleave(
+                 seq_weight_indices, seg_lens_cpu
+             )
+             permutation = torch.empty(
+                 (len(row_weight_indices),), dtype=torch.long, pin_memory=True
+             )
+             torch.argsort(row_weight_indices, stable=True, out=permutation)
+             weights_reordered = row_weight_indices[permutation]
+
+         return permutation, weights_reordered
+
+     def _get_segments_info(self, weights_reordered: torch.Tensor, chunk_size: int):
+         """
+         Computes segment information for chunked SGMV operations.
+
+         This function takes the reordered weight indices and creates segments of fixed size
+         (chunk_size) for efficient kernel execution. Each segment contains tokens
+         that use the same LoRA adapter, enabling vectorized computation.
+
+         The segmentation is necessary because:
+         1. GPU kernels work efficiently on fixed-size blocks
+         2. Large groups of tokens using the same adapter are split into manageable chunks
+         3. Each segment can be processed independently in parallel
+
+         Example:
+             weights_reordered = [0, 0, 0, 0, 0, 1]  # 5 tokens with adapter 0, 1 with adapter 1
+             chunk_size = 3
+
+             # Creates segments:
+             # Segment 0: tokens 0-2 (adapter 0), length=3
+             # Segment 1: tokens 3-4 (adapter 0), length=2
+             # Segment 2: token 5 (adapter 1), length=1
+
+             # Returns:
+             # weight_indices_list: [0, 0, 1] (adapter for each segment)
+             # seg_indptr: [0, 3, 5, 6] (cumulative segment boundaries)
+
+         Args:
+             weights_reordered (torch.Tensor): Sorted adapter indices for each token
+             chunk_size (int): Fixed size for each segment
+
+         Returns:
+             tuple: (weight_indices_list, seg_indptr) where:
+                 - weight_indices_list: LoRA adapter index for each segment
+                 - seg_indptr: Cumulative segment boundaries (CSR-style indptr)
+         """
+         with torch.device("cpu"):
+             unique_weights, counts = torch.unique_consecutive(
+                 weights_reordered, return_counts=True
+             )
+
+             weight_indices_list = []
+             seg_lens_list = []
+
+             for weight_idx, group_len in zip(unique_weights, counts):
+                 group_len = group_len.item()
+                 num_segs = (group_len + chunk_size - 1) // chunk_size
+
+                 weight_indices_list.extend([weight_idx.item()] * num_segs)
+                 seg_lens_list.extend([chunk_size] * (num_segs - 1))
+                 seg_lens_list.append(group_len - (num_segs - 1) * chunk_size)
+
+             seg_lens = torch.tensor(seg_lens_list, dtype=torch.int32)
+
+             weight_indices_list = torch.tensor(
+                 weight_indices_list, dtype=torch.int32, pin_memory=True
+             )
+
+             seg_indptr = torch.empty(
+                 (len(seg_lens) + 1,), dtype=torch.int32, pin_memory=True
+             )
+             seg_indptr[0] = 0
+             seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
+
+         return weight_indices_list, seg_indptr
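
The CPU-side bookkeeping of this new backend can be checked in isolation. The following standalone sketch is not code shipped in the wheel; it is a plain-PyTorch illustration that reproduces the gather and segmentation steps on the exact example given in the docstrings above and asserts the documented outputs (variable names are chosen for illustration).

# Standalone sketch: reproduce the chunked-SGMV gather/segmentation bookkeeping.
import torch

# Three sequences using adapters [0, 1, 0] with lengths [2, 1, 3] tokens.
seq_weight_indices = torch.tensor([0, 1, 0], dtype=torch.int32)
seq_lens = torch.tensor([2, 1, 3], dtype=torch.int32)

# Gather step: one adapter index per token, then a stable sort so that
# tokens sharing an adapter become contiguous.
row_weight_indices = torch.repeat_interleave(seq_weight_indices, seq_lens)
permutation = torch.sort(row_weight_indices, stable=True).indices
weights_reordered = row_weight_indices[permutation]
assert permutation.tolist() == [0, 1, 3, 4, 5, 2]
assert weights_reordered.tolist() == [0, 0, 0, 0, 0, 1]

# Segmentation step: split each adapter's contiguous run into chunks of at
# most chunk_size tokens and record CSR-style segment boundaries.
chunk_size = 3
seg_adapters, seg_lens_list = [], []
values, counts = torch.unique_consecutive(weights_reordered, return_counts=True)
for adapter, count in zip(values, counts):
    n = count.item()
    num_segs = (n + chunk_size - 1) // chunk_size
    seg_adapters.extend([adapter.item()] * num_segs)
    seg_lens_list.extend([chunk_size] * (num_segs - 1) + [n - (num_segs - 1) * chunk_size])

seg_indptr = [0]
for length in seg_lens_list:
    seg_indptr.append(seg_indptr[-1] + length)

assert seg_adapters == [0, 0, 1]       # adapter index per segment
assert seg_indptr == [0, 3, 5, 6]      # cumulative segment boundaries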
sglang/srt/lora/backend/triton_backend.py CHANGED
@@ -11,12 +11,18 @@ from sglang.srt.lora.triton_ops import (
  )
  from sglang.srt.lora.utils import LoRABatchInfo
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.server_args import ServerArgs
 
 
  class TritonLoRABackend(BaseLoRABackend):
      name = "triton"
 
-     def __init__(self, max_loras_per_batch: int, device: torch.device):
+     def __init__(
+         self,
+         max_loras_per_batch: int,
+         device: torch.device,
+         **kwargs,
+     ):
          super().__init__(max_loras_per_batch, device)
 
      def run_lora_a_sgemm(
@@ -30,7 +36,7 @@ class TritonLoRABackend(BaseLoRABackend):
          weights: torch.Tensor,
          base_output: torch.Tensor = None,
          *args,
-         **kwargs
+         **kwargs,
      ) -> torch.Tensor:
          return sgemm_lora_b_fwd(x, weights, self.batch_info, base_output)
 
@@ -43,7 +49,7 @@ class TritonLoRABackend(BaseLoRABackend):
          max_qkv_out_dim: int,
          base_output: torch.Tensor = None,
          *args,
-         **kwargs
+         **kwargs,
      ) -> torch.Tensor:
 
          # x: (s, input_dim)
@@ -69,7 +75,7 @@ class TritonLoRABackend(BaseLoRABackend):
          gate_up_lora_b: torch.Tensor,
          base_output: torch.Tensor = None,
          *args,
-         **kwargs
+         **kwargs,
      ) -> torch.Tensor:
 
          # x: (s, input_dim)
sglang/srt/lora/lora.py CHANGED
@@ -26,16 +26,17 @@ import torch
  from torch import nn
 
  from sglang.srt.configs.load_config import LoadConfig
- from sglang.srt.hf_transformers_utils import AutoConfig
  from sglang.srt.lora.backend.base_backend import BaseLoRABackend
-
- # from sglang.srt.lora.backend.chunked_backend import ChunkedSgmvLoRABackend
+ from sglang.srt.lora.backend.chunked_backend import ChunkedSgmvLoRABackend
  from sglang.srt.lora.backend.triton_backend import TritonLoRABackend
  from sglang.srt.lora.lora_config import LoRAConfig
  from sglang.srt.model_loader.loader import DefaultModelLoader
+ from sglang.srt.utils.hf_transformers_utils import AutoConfig
 
  logger = logging.getLogger(__name__)
 
+ SUPPORTED_BACKENDS = (TritonLoRABackend, ChunkedSgmvLoRABackend)
+
 
  class LoRALayer(nn.Module):
      def __init__(self, config: LoRAConfig, base_hf_config: AutoConfig):
@@ -48,6 +49,7 @@ class LoRALayer(nn.Module):
 
 
  class LoRAAdapter(nn.Module):
+
      def __init__(
          self,
          uid: str,
@@ -159,8 +161,8 @@ class LoRAAdapter(nn.Module):
              gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
              if up_name not in weights:
                  weights[up_name] = torch.zeros_like(weights[weight_name])
-             assert isinstance(self.lora_backend, TritonLoRABackend), (
-                 f"LoRA weight initialization currently only supported for 'triton' backend. "
+             assert isinstance(self.lora_backend, SUPPORTED_BACKENDS), (
+                 f"LoRA weight initialization currently only supported for LoRA backends: {', '.join(b.name for b in SUPPORTED_BACKENDS)}"
                  f"Received backend: {self.lora_backend.name}. Please verify your backend configuration "
                  f"or consider implementing custom initialization logic for other backends."
              )
sglang/srt/lora/lora_manager.py CHANGED
@@ -21,7 +21,6 @@ from typing import Dict, Iterable, List, Optional, Set, Tuple
  import torch
 
  from sglang.srt.configs.load_config import LoadConfig
- from sglang.srt.hf_transformers_utils import AutoConfig
  from sglang.srt.lora.backend.base_backend import BaseLoRABackend, get_backend_from_name
  from sglang.srt.lora.layers import BaseLayerWithLoRA, get_lora_layer
  from sglang.srt.lora.lora import LoRAAdapter
@@ -35,9 +34,11 @@ from sglang.srt.lora.utils import (
      get_normalized_target_modules,
      get_target_module_name,
  )
- from sglang.srt.managers.io_struct import LoRAUpdateResult
+ from sglang.srt.managers.io_struct import LoRAUpdateOutput
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.server_args import ServerArgs
  from sglang.srt.utils import replace_submodule
+ from sglang.srt.utils.hf_transformers_utils import AutoConfig
 
  logger = logging.getLogger(__name__)
 
@@ -56,6 +57,7 @@ class LoRAManager:
          max_lora_rank: Optional[int] = None,
          target_modules: Optional[Iterable[str]] = None,
          lora_paths: Optional[List[LoRARef]] = None,
+         server_args: Optional[ServerArgs] = None,
      ):
          self.base_model: torch.nn.Module = base_model
          self.base_hf_config: AutoConfig = base_hf_config
@@ -72,6 +74,7 @@ class LoRAManager:
          self.lora_backend: BaseLoRABackend = backend_type(
              max_loras_per_batch=max_loras_per_batch,
              device=self.device,
+             server_args=server_args,
          )
 
          # Initialize mutable internal state of the LoRAManager.
@@ -104,8 +107,8 @@ class LoRAManager:
 
      def create_lora_update_result(
          self, success: bool, error_message: str = ""
-     ) -> LoRAUpdateResult:
-         return LoRAUpdateResult(
+     ) -> LoRAUpdateOutput:
+         return LoRAUpdateOutput(
              success=success,
              error_message=error_message,
              loaded_adapters={
@@ -114,7 +117,7 @@ class LoRAManager:
              },
          )
 
-     def load_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateResult:
+     def load_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateOutput:
          """
          Load a single LoRA adapter from the specified path.
 
@@ -171,7 +174,7 @@ class LoRAManager:
              "`--max-loras-per-batch` or load it as unpinned LoRA adapters."
          )
 
-     def unload_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateResult:
+     def unload_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateOutput:
          """
          Unload LoRA adapters by their names. This will remove the adapters from the memory pool and
          delete the corresponding LoRA modules.
@@ -415,6 +418,10 @@ class LoRAManager:
          replace_submodule(self.base_model, module_name, lora_module)
          return lora_module
 
+     def should_skip_lora_for_vision_model(self, module_name):
+         # TODO: support different vision models
+         return module_name.find("vision_model.model") != -1
+
      def init_lora_modules(self):
          # Look-up table that essentially maps (layer_index, module_name) to the corresponding LoRA module.
          self.lora_modules: List[Dict[str, BaseLayerWithLoRA]] = [
@@ -432,6 +439,10 @@ class LoRAManager:
              ) and not self.base_model.should_apply_lora(module_name):
                  continue
 
+             # Skip vision model
+             if self.should_skip_lora_for_vision_model(module_name):
+                 continue
+
              # The module should be converted if it is included in target_names
              if module_name.split(".")[-1] in self.target_modules:
                  layer_id = get_layer_id(module_name)
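
The hunks above thread server_args from LoRAManager into whichever backend class was selected, while TritonLoRABackend absorbs the extra keyword via **kwargs. The following is a minimal sketch of that constructor-compatibility pattern; the classes are simplified stand-ins, not the real implementations.

# Sketch: one call site can construct either backend with the same kwargs.
from dataclasses import dataclass


@dataclass
class FakeServerArgs:  # stand-in for sglang.srt.server_args.ServerArgs
    max_lora_chunk_size: int = 16


class TritonBackendStub:
    def __init__(self, max_loras_per_batch: int, device: str, **kwargs):
        # Extra keyword arguments (e.g. server_args) are accepted and ignored.
        self.max_loras_per_batch = max_loras_per_batch


class ChunkedBackendStub:
    def __init__(self, max_loras_per_batch: int, device: str, server_args: FakeServerArgs):
        # The chunked backend actually consumes the server args.
        self.max_chunk_size = server_args.max_lora_chunk_size


for backend_type in (TritonBackendStub, ChunkedBackendStub):
    backend = backend_type(
        max_loras_per_batch=8,
        device="cpu",
        server_args=FakeServerArgs(max_lora_chunk_size=128),
    )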
sglang/srt/lora/mem_pool.py CHANGED
@@ -4,7 +4,6 @@ from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple, Union
  import torch
 
  from sglang.srt.distributed import divide
- from sglang.srt.hf_transformers_utils import AutoConfig
  from sglang.srt.lora.layers import BaseLayerWithLoRA
  from sglang.srt.lora.lora import LoRAAdapter
  from sglang.srt.lora.lora_config import LoRAConfig
@@ -17,6 +16,7 @@ from sglang.srt.lora.utils import (
      get_stacked_multiply,
      get_target_module_name,
  )
+ from sglang.srt.utils.hf_transformers_utils import AutoConfig
 
  logger = logging.getLogger(__name__)
 
sglang/srt/lora/triton_ops/__init__.py CHANGED
@@ -1,3 +1,5 @@
+ from .chunked_sgmv_expand import chunked_sgmv_lora_expand_forward
+ from .chunked_sgmv_shrink import chunked_sgmv_lora_shrink_forward
  from .gate_up_lora_b import gate_up_lora_b_fwd
  from .qkv_lora_b import qkv_lora_b_fwd
  from .sgemm_lora_a import sgemm_lora_a_fwd
@@ -8,4 +10,6 @@ __all__ = [
      "qkv_lora_b_fwd",
      "sgemm_lora_a_fwd",
      "sgemm_lora_b_fwd",
+     "chunked_sgmv_lora_shrink_forward",
+     "chunked_sgmv_lora_expand_forward",
  ]
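
With this change the chunked kernels are re-exported from sglang.srt.lora.triton_ops, and the new backend registers under the name "csgmv" (see ChunkedSgmvLoRABackend.name above). A brief usage sketch of the new exports follows; how that name is resolved by get_backend_from_name and the corresponding server flag live in files not shown in this diff.

# Sketch: importing the newly exported chunked Triton op wrappers.
from sglang.srt.lora.triton_ops import (
    chunked_sgmv_lora_expand_forward,
    chunked_sgmv_lora_shrink_forward,
)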