sglang 0.5.1.post2__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (256)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +89 -54
  3. sglang/bench_serving.py +437 -40
  4. sglang/lang/interpreter.py +1 -1
  5. sglang/profiler.py +0 -1
  6. sglang/srt/configs/__init__.py +4 -0
  7. sglang/srt/configs/internvl.py +6 -0
  8. sglang/srt/configs/longcat_flash.py +104 -0
  9. sglang/srt/configs/model_config.py +37 -7
  10. sglang/srt/configs/qwen3_next.py +326 -0
  11. sglang/srt/connector/__init__.py +1 -1
  12. sglang/srt/connector/base_connector.py +1 -2
  13. sglang/srt/connector/redis.py +2 -2
  14. sglang/srt/connector/serde/__init__.py +1 -1
  15. sglang/srt/connector/serde/safe_serde.py +4 -3
  16. sglang/srt/custom_op.py +11 -1
  17. sglang/srt/debug_utils/dump_comparator.py +81 -44
  18. sglang/srt/debug_utils/dump_loader.py +97 -0
  19. sglang/srt/debug_utils/dumper.py +11 -3
  20. sglang/srt/debug_utils/text_comparator.py +73 -11
  21. sglang/srt/disaggregation/ascend/conn.py +75 -0
  22. sglang/srt/disaggregation/base/conn.py +1 -1
  23. sglang/srt/disaggregation/common/conn.py +15 -12
  24. sglang/srt/disaggregation/decode.py +6 -4
  25. sglang/srt/disaggregation/fake/conn.py +1 -1
  26. sglang/srt/disaggregation/mini_lb.py +6 -420
  27. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  28. sglang/srt/disaggregation/nixl/conn.py +180 -16
  29. sglang/srt/disaggregation/prefill.py +6 -4
  30. sglang/srt/disaggregation/utils.py +5 -50
  31. sglang/srt/distributed/parallel_state.py +94 -58
  32. sglang/srt/entrypoints/engine.py +34 -14
  33. sglang/srt/entrypoints/http_server.py +172 -47
  34. sglang/srt/entrypoints/openai/protocol.py +90 -27
  35. sglang/srt/entrypoints/openai/serving_base.py +6 -2
  36. sglang/srt/entrypoints/openai/serving_chat.py +82 -26
  37. sglang/srt/entrypoints/openai/serving_completions.py +25 -4
  38. sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
  39. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  40. sglang/srt/eplb/eplb_manager.py +28 -4
  41. sglang/srt/eplb/expert_distribution.py +55 -15
  42. sglang/srt/eplb/expert_location.py +8 -3
  43. sglang/srt/eplb/expert_location_updater.py +1 -1
  44. sglang/srt/function_call/deepseekv31_detector.py +222 -0
  45. sglang/srt/function_call/ebnf_composer.py +11 -9
  46. sglang/srt/function_call/function_call_parser.py +2 -0
  47. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  48. sglang/srt/function_call/gpt_oss_detector.py +144 -256
  49. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  50. sglang/srt/hf_transformers_utils.py +28 -7
  51. sglang/srt/layers/activation.py +44 -9
  52. sglang/srt/layers/attention/aiter_backend.py +93 -68
  53. sglang/srt/layers/attention/ascend_backend.py +381 -136
  54. sglang/srt/layers/attention/fla/chunk.py +242 -0
  55. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  56. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  57. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  58. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  59. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  60. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  61. sglang/srt/layers/attention/fla/index.py +37 -0
  62. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  63. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  64. sglang/srt/layers/attention/fla/op.py +66 -0
  65. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  66. sglang/srt/layers/attention/fla/utils.py +331 -0
  67. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  68. sglang/srt/layers/attention/flashattention_backend.py +241 -7
  69. sglang/srt/layers/attention/flashinfer_backend.py +11 -6
  70. sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -14
  71. sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
  72. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
  73. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  74. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  75. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  76. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  77. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  78. sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
  79. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  80. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  81. sglang/srt/layers/communicator.py +45 -8
  82. sglang/srt/layers/layernorm.py +54 -12
  83. sglang/srt/layers/logits_processor.py +10 -3
  84. sglang/srt/layers/moe/__init__.py +2 -1
  85. sglang/srt/layers/moe/cutlass_moe.py +0 -8
  86. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
  87. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  88. sglang/srt/layers/moe/ep_moe/layer.py +111 -56
  89. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  90. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  91. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
  94. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  95. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  96. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  97. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  98. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  99. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
  100. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  101. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
  102. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
  103. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  104. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  105. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  106. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  107. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  108. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  109. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  110. sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
  111. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  112. sglang/srt/layers/moe/topk.py +43 -12
  113. sglang/srt/layers/moe/utils.py +6 -5
  114. sglang/srt/layers/quantization/awq.py +19 -7
  115. sglang/srt/layers/quantization/base_config.py +11 -6
  116. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  119. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +141 -235
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
  121. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +31 -22
  122. sglang/srt/layers/quantization/fp8.py +78 -48
  123. sglang/srt/layers/quantization/fp8_kernel.py +2 -2
  124. sglang/srt/layers/quantization/fp8_utils.py +45 -31
  125. sglang/srt/layers/quantization/gptq.py +25 -17
  126. sglang/srt/layers/quantization/modelopt_quant.py +107 -40
  127. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  128. sglang/srt/layers/quantization/mxfp4.py +93 -68
  129. sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
  130. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  131. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  132. sglang/srt/layers/quantization/quark/utils.py +97 -0
  133. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  134. sglang/srt/layers/quantization/unquant.py +135 -47
  135. sglang/srt/layers/quantization/utils.py +13 -0
  136. sglang/srt/layers/quantization/w4afp8.py +60 -42
  137. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  138. sglang/srt/layers/quantization/w8a8_int8.py +83 -41
  139. sglang/srt/layers/rocm_linear_utils.py +44 -0
  140. sglang/srt/layers/rotary_embedding.py +28 -19
  141. sglang/srt/layers/sampler.py +29 -5
  142. sglang/srt/layers/utils.py +0 -14
  143. sglang/srt/lora/backend/base_backend.py +50 -8
  144. sglang/srt/lora/backend/triton_backend.py +90 -2
  145. sglang/srt/lora/layers.py +32 -0
  146. sglang/srt/lora/lora.py +4 -1
  147. sglang/srt/lora/lora_manager.py +35 -112
  148. sglang/srt/lora/mem_pool.py +24 -10
  149. sglang/srt/lora/utils.py +18 -9
  150. sglang/srt/managers/cache_controller.py +396 -365
  151. sglang/srt/managers/data_parallel_controller.py +30 -15
  152. sglang/srt/managers/detokenizer_manager.py +18 -2
  153. sglang/srt/managers/disagg_service.py +46 -0
  154. sglang/srt/managers/io_struct.py +190 -11
  155. sglang/srt/managers/mm_utils.py +6 -1
  156. sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
  157. sglang/srt/managers/schedule_batch.py +27 -44
  158. sglang/srt/managers/schedule_policy.py +4 -3
  159. sglang/srt/managers/scheduler.py +148 -122
  160. sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
  161. sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
  162. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  163. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  164. sglang/srt/managers/template_manager.py +3 -3
  165. sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
  166. sglang/srt/managers/tokenizer_manager.py +77 -480
  167. sglang/srt/managers/tp_worker.py +16 -4
  168. sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
  169. sglang/srt/mem_cache/allocator.py +1 -1
  170. sglang/srt/mem_cache/chunk_cache.py +1 -1
  171. sglang/srt/mem_cache/hicache_storage.py +53 -40
  172. sglang/srt/mem_cache/hiradix_cache.py +196 -104
  173. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  174. sglang/srt/mem_cache/memory_pool.py +395 -53
  175. sglang/srt/mem_cache/memory_pool_host.py +27 -19
  176. sglang/srt/mem_cache/radix_cache.py +6 -6
  177. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  178. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  179. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  180. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  181. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +152 -23
  182. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  183. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  184. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +154 -95
  185. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  186. sglang/srt/mem_cache/swa_radix_cache.py +1 -3
  187. sglang/srt/metrics/collector.py +484 -63
  188. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  189. sglang/srt/metrics/utils.py +48 -0
  190. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  191. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  192. sglang/srt/model_executor/forward_batch_info.py +72 -18
  193. sglang/srt/model_executor/model_runner.py +190 -32
  194. sglang/srt/model_loader/__init__.py +9 -3
  195. sglang/srt/model_loader/loader.py +33 -28
  196. sglang/srt/model_loader/utils.py +12 -0
  197. sglang/srt/model_loader/weight_utils.py +2 -1
  198. sglang/srt/models/deepseek_v2.py +323 -53
  199. sglang/srt/models/gemma3n_mm.py +1 -1
  200. sglang/srt/models/glm4_moe.py +10 -1
  201. sglang/srt/models/glm4v.py +4 -2
  202. sglang/srt/models/gpt_oss.py +7 -19
  203. sglang/srt/models/internvl.py +28 -0
  204. sglang/srt/models/llama4.py +9 -0
  205. sglang/srt/models/llama_eagle3.py +17 -0
  206. sglang/srt/models/longcat_flash.py +1026 -0
  207. sglang/srt/models/longcat_flash_nextn.py +699 -0
  208. sglang/srt/models/minicpmv.py +165 -3
  209. sglang/srt/models/mllama4.py +25 -0
  210. sglang/srt/models/opt.py +637 -0
  211. sglang/srt/models/qwen2.py +33 -3
  212. sglang/srt/models/qwen2_5_vl.py +91 -42
  213. sglang/srt/models/qwen2_moe.py +79 -14
  214. sglang/srt/models/qwen3.py +8 -2
  215. sglang/srt/models/qwen3_moe.py +39 -8
  216. sglang/srt/models/qwen3_next.py +1039 -0
  217. sglang/srt/models/qwen3_next_mtp.py +109 -0
  218. sglang/srt/models/torch_native_llama.py +1 -1
  219. sglang/srt/models/transformers.py +1 -1
  220. sglang/srt/multimodal/processors/base_processor.py +4 -2
  221. sglang/srt/multimodal/processors/glm4v.py +9 -9
  222. sglang/srt/multimodal/processors/internvl.py +141 -129
  223. sglang/srt/{conversation.py → parser/conversation.py} +38 -5
  224. sglang/srt/parser/harmony_parser.py +588 -0
  225. sglang/srt/parser/reasoning_parser.py +309 -0
  226. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  227. sglang/srt/sampling/sampling_batch_info.py +18 -15
  228. sglang/srt/server_args.py +307 -80
  229. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  230. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  231. sglang/srt/speculative/eagle_worker.py +216 -120
  232. sglang/srt/speculative/spec_info.py +5 -0
  233. sglang/srt/speculative/standalone_worker.py +109 -0
  234. sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
  235. sglang/srt/utils.py +96 -7
  236. sglang/srt/weight_sync/utils.py +1 -1
  237. sglang/test/attention/test_trtllm_mla_backend.py +181 -8
  238. sglang/test/few_shot_gsm8k.py +1 -0
  239. sglang/test/runners.py +4 -0
  240. sglang/test/test_cutlass_moe.py +24 -6
  241. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  242. sglang/test/test_disaggregation_utils.py +66 -0
  243. sglang/test/test_utils.py +25 -1
  244. sglang/utils.py +5 -0
  245. sglang/version.py +1 -1
  246. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/METADATA +13 -10
  247. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/RECORD +253 -201
  248. sglang/srt/disaggregation/launch_lb.py +0 -131
  249. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  250. sglang/srt/reasoning_parser.py +0 -553
  251. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  252. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  253. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  254. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
  255. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
  256. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/lora/layers.py CHANGED
@@ -66,6 +66,15 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
         lora_backend: BaseLoRABackend,
     ) -> None:
         super().__init__(base_layer, lora_backend)
+        shard_size = self.base_layer.output_partition_sizes[0]
+        self.output_offset = torch.tensor(
+            [
+                0,
+                shard_size,
+            ],
+            dtype=torch.int32,
+            device=next(self.base_layer.parameters()).device,
+        )
 
     def set_lora_info(
         self,
@@ -81,6 +90,7 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
         lora_output = self.lora_backend.run_lora_b_sgemm(
             x=lora_a_output,
             weights=self.B_buffer,
+            output_offset=self.output_offset,
             base_output=base_output,
         )
         return lora_output
@@ -130,11 +140,23 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         self.A_buffer_gate_up = A_buffer
         self.B_buffer_gate_up = B_buffer
 
+        shard_size = self.base_layer.output_partition_sizes[0]
+        self.output_offset = torch.tensor(
+            [
+                0,
+                shard_size,
+                2 * shard_size,
+            ],
+            dtype=torch.int32,
+            device=next(self.base_layer.parameters()).device,
+        )
+
     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
         lora_output = self.lora_backend.run_gate_up_lora(
             x=x,
             gate_up_lora_a=self.A_buffer_gate_up,
             gate_up_lora_b=self.B_buffer_gate_up,
+            output_offset=self.output_offset,
             base_output=base_output,
         )
         return lora_output
@@ -243,12 +265,22 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
         self.set_lora = True
         self.A_buffer = A_buffer
         self.B_buffer = B_buffer
+        output_size = self.base_layer.output_size
+        self.output_offset = torch.tensor(
+            [
+                0,
+                output_size,
+            ],
+            dtype=torch.int32,
+            device=next(self.base_layer.parameters()).device,
+        )
 
     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
         lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
         lora_output = self.lora_backend.run_lora_b_sgemm(
             x=lora_a_output,
             weights=self.B_buffer,
+            output_offset=self.output_offset,
             base_output=base_output,
         )
         return lora_output
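
Note on the change above: every LoRA-enabled linear layer now precomputes an output_offset tensor and passes it to the backend kernels (run_lora_b_sgemm / run_gate_up_lora). The offsets are cumulative column boundaries of the stacked output slices. The snippet below is a standalone illustration of that layout, not sglang code; the real backend fuses the per-slice writes inside its kernels, and the single shared lora_a_out here is a simplification (the merged gate_up path keeps separate LoRA-A/B ranks per slice).

    import torch

    num_tokens, rank, shard_size = 4, 8, 16
    lora_a_out = torch.randn(num_tokens, rank)      # output of the LoRA-A GEMM
    lora_b = torch.randn(2 * shard_size, rank)      # gate and up LoRA-B stacked on dim 0
    base_output = torch.zeros(num_tokens, 2 * shard_size)

    # Same layout as MergedColumnParallelLinearWithLoRA.output_offset: [0, shard, 2*shard]
    output_offset = torch.tensor([0, shard_size, 2 * shard_size], dtype=torch.int32)

    for i in range(output_offset.numel() - 1):
        lo, hi = int(output_offset[i]), int(output_offset[i + 1])
        # Output columns [lo, hi) receive the contribution of their own LoRA-B sub-matrix.
        base_output[:, lo:hi] += lora_a_out @ lora_b[lo:hi].T

    print(base_output.shape)  # torch.Size([4, 32])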
sglang/srt/lora/lora.py CHANGED
@@ -28,6 +28,9 @@ from torch import nn
 from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.hf_transformers_utils import AutoConfig
 from sglang.srt.lora.backend.base_backend import BaseLoRABackend
+
+# from sglang.srt.lora.backend.chunked_backend import ChunkedSgmvLoRABackend
+from sglang.srt.lora.backend.triton_backend import TritonLoRABackend
 from sglang.srt.lora.lora_config import LoRAConfig
 from sglang.srt.model_loader.loader import DefaultModelLoader
 
@@ -156,7 +159,7 @@ class LoRAAdapter(nn.Module):
             gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
             if up_name not in weights:
                 weights[up_name] = torch.zeros_like(weights[weight_name])
-            assert self.lora_backend.name == "triton", (
+            assert isinstance(self.lora_backend, TritonLoRABackend), (
                 f"LoRA weight initialization currently only supported for 'triton' backend. "
                 f"Received backend: {self.lora_backend.name}. Please verify your backend configuration "
                 f"or consider implementing custom initialization logic for other backends."
sglang/srt/lora/lora_manager.py CHANGED
@@ -69,7 +69,10 @@ class LoRAManager:
         # LoRA backend for running sgemm kernels
         logger.info(f"Using {lora_backend} as backend of LoRA kernels.")
         backend_type = get_backend_from_name(lora_backend)
-        self.lora_backend: BaseLoRABackend = backend_type(lora_backend)
+        self.lora_backend: BaseLoRABackend = backend_type(
+            max_loras_per_batch=max_loras_per_batch,
+            device=self.device,
+        )
 
         # Initialize mutable internal state of the LoRAManager.
         self.init_state(
@@ -82,29 +85,22 @@ class LoRAManager:
         self.max_bs_in_cuda_graph = max_bs_in_cuda_graph
         with torch.device("cuda"):
             self.cuda_graph_batch_info = LoRABatchInfo(
-                bs=self.max_bs_in_cuda_graph,
-                seg_lens=torch.zeros(self.max_bs_in_cuda_graph, dtype=torch.int32),
-                seg_indptr=torch.zeros(
-                    self.max_bs_in_cuda_graph + 1, dtype=torch.int32
-                ),
+                bs=max_bs_in_cuda_graph,
+                use_cuda_graph=True,
+                num_segments=None,
+                seg_lens=torch.zeros(max_bs_in_cuda_graph, dtype=torch.int32),
+                seg_indptr=torch.zeros(max_bs_in_cuda_graph + 1, dtype=torch.int32),
                 max_len=1,
-                weight_indices=torch.zeros(
-                    self.max_bs_in_cuda_graph, dtype=torch.int32
-                ),
+                weight_indices=torch.zeros(max_bs_in_cuda_graph, dtype=torch.int32),
+                permutation=torch.zeros(max_bs_in_cuda_graph, dtype=torch.int32),
                 lora_ranks=torch.zeros(self.max_loras_per_batch, dtype=torch.int32),
                 scalings=torch.zeros(self.max_loras_per_batch, dtype=torch.float),
             )
 
-            # Initialize seg_lens and seg_indptr for CUDA graph as they remain constant
-            # across batches.
-            self.cuda_graph_batch_info.seg_lens[: self.max_bs_in_cuda_graph].fill_(1)
-            torch.cumsum(
-                self.cuda_graph_batch_info.seg_lens[: self.max_bs_in_cuda_graph],
-                dim=0,
-                out=self.cuda_graph_batch_info.seg_indptr[
-                    1 : self.max_bs_in_cuda_graph + 1
-                ],
-            )
+            self.lora_backend.init_cuda_graph_batch_info(
+                cuda_graph_batch_info=self.cuda_graph_batch_info,
+                max_bs_in_cuda_graph=max_bs_in_cuda_graph,
+            )
 
     def create_lora_update_result(
         self, success: bool, error_message: str = ""
@@ -232,7 +228,6 @@ class LoRAManager:
         return required_slots <= mem_pool_vacancy
 
     def prepare_lora_batch(self, forward_batch: ForwardBatch):
-
         # Load active loras into lora memory pool
         cur_uids = set(forward_batch.lora_ids)
 
@@ -247,102 +242,30 @@ class LoRAManager:
         # set up batch info shared by all lora modules
         bs = forward_batch.batch_size
 
-        def transfer_adapter_info(
-            weight_indices_out: torch.Tensor,
-            lora_ranks_out: torch.Tensor,
-            scalings_out: torch.Tensor,
-        ):
-            """
-            Transfer adapter metadata (weight indices, LoRA rank, scalings) from host
-            to device (CUDA) asynchronously.
-            """
-            weight_indices = [0] * len(forward_batch.lora_ids)
-            lora_ranks = [0] * self.max_loras_per_batch
-            scalings = [0] * self.max_loras_per_batch
-            for i, uid in enumerate(forward_batch.lora_ids):
-                weight_indices[i] = self.memory_pool.get_buffer_id(uid)
-                if uid is not None:
-                    lora = self.loras[uid]
-                    lora_ranks[weight_indices[i]] = lora.config.r
-                    scalings[weight_indices[i]] = lora.scaling
-
-            # Use pinned memory to avoid synchronizations during host-to-device transfer
-            weight_indices_tensor = torch.tensor(
-                weight_indices, dtype=torch.int32, pin_memory=True, device="cpu"
-            )
-            lora_ranks_tensor = torch.tensor(
-                lora_ranks, dtype=torch.int32, pin_memory=True, device="cpu"
-            )
-            scalings_tensor = torch.tensor(
-                scalings, dtype=torch.float, pin_memory=True, device="cpu"
-            )
-
-            # Copy to device tensors asynchronously
-            weight_indices_out[:bs].copy_(weight_indices_tensor, non_blocking=True)
-            lora_ranks_out[: self.max_loras_per_batch].copy_(
-                lora_ranks_tensor, non_blocking=True
-            )
-            scalings_out[: self.max_loras_per_batch].copy_(
-                scalings_tensor, non_blocking=True
-            )
-
-        if (
+        use_cuda_graph = (
             hasattr(self, "max_bs_in_cuda_graph")
             and bs <= self.max_bs_in_cuda_graph
             and forward_batch.forward_mode.is_cuda_graph()
-        ):
-            # Do in-place updates when CUDA graph is enabled and the batch forward mode
-            # could use CUDA graph.
-
-            transfer_adapter_info(
-                self.cuda_graph_batch_info.weight_indices,
-                self.cuda_graph_batch_info.lora_ranks,
-                self.cuda_graph_batch_info.scalings,
-            )
-
-            self.cuda_graph_batch_info.bs = bs
-            self.cuda_graph_batch_info.max_len = 1
-            batch_info = self.cuda_graph_batch_info
-        else:
-            weight_indices = torch.empty((bs,), dtype=torch.int32, device=self.device)
-            lora_ranks = torch.zeros(
-                (self.max_loras_per_batch,), dtype=torch.int64, device=self.device
-            )
-            scalings = torch.zeros(
-                (self.max_loras_per_batch,), dtype=torch.float, device=self.device
-            )
-            transfer_adapter_info(
-                weight_indices,
-                lora_ranks,
-                scalings,
-            )
-
-            seg_lens = (
-                forward_batch.extend_seq_lens
-                if forward_batch.forward_mode.is_extend()
-                else torch.ones(bs, device=self.device)
-            )
-
-            max_len = (
-                # Calculate max_len from the CPU copy to avoid D2H transfer.
-                max(forward_batch.extend_seq_lens_cpu)
-                if forward_batch.forward_mode.is_extend()
-                else 1
-            )
+        )
 
-            seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
-            seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
-
-            batch_info = LoRABatchInfo(
-                bs=bs,
-                seg_lens=seg_lens,
-                seg_indptr=seg_indptr,
-                max_len=max_len,
-                weight_indices=weight_indices,
-                lora_ranks=lora_ranks,
-                scalings=scalings,
-            )
-        self.lora_backend.set_batch_info(batch_info)
+        weight_indices = [0] * len(forward_batch.lora_ids)
+        lora_ranks = [0] * self.max_loras_per_batch
+        scalings = [0] * self.max_loras_per_batch
+        for i, uid in enumerate(forward_batch.lora_ids):
+            weight_indices[i] = self.memory_pool.get_buffer_id(uid)
+            if uid is not None:
+                lora = self.loras[uid]
+                lora_ranks[weight_indices[i]] = lora.config.r
+                scalings[weight_indices[i]] = lora.scaling
+        # Do in-place updates when CUDA graph is enabled and the batch forward mode
+        # could use CUDA graph.
+        self.lora_backend.prepare_lora_batch(
+            forward_batch=forward_batch,
+            weight_indices=weight_indices,
+            lora_ranks=lora_ranks,
+            scalings=scalings,
+            batch_info=self.cuda_graph_batch_info if use_cuda_graph else None,
+        )
 
     def update_lora_info(self):
         """
sglang/srt/lora/mem_pool.py CHANGED
@@ -104,12 +104,18 @@ class LoRAMemoryPool:
         return all(_can_support(x) for x in config)
 
     def get_lora_A_shape(
-        self, module_name: str, base_model: torch.nn.Module, max_lora_dim: int
+        self,
+        module_name: str,
+        base_model: torch.nn.Module,
+        max_lora_dim: int,
+        layer_idx: int,
     ) -> Tuple[int]:
         """
         Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
         """
-        input_dim, _ = get_hidden_dim(module_name, self.base_hf_config, base_model)
+        input_dim, _ = get_hidden_dim(
+            module_name, self.base_hf_config, base_model, layer_idx
+        )
         c = get_stacked_multiply(module_name)
         if self.tp_size > 1 and module_name in ROW_PARALLELISM_LINEAR_LORA_NAMES:
             input_dim = divide(input_dim, self.tp_size)
@@ -120,12 +126,18 @@ class LoRAMemoryPool:
         )
 
     def get_lora_B_shape(
-        self, module_name: str, base_model: torch.nn.Module, max_lora_dim: int
+        self,
+        module_name: str,
+        base_model: torch.nn.Module,
+        max_lora_dim: int,
+        layer_idx: int,
     ) -> Tuple[int]:
         """
         Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
         """
-        _, output_dim = get_hidden_dim(module_name, self.base_hf_config, base_model)
+        _, output_dim = get_hidden_dim(
+            module_name, self.base_hf_config, base_model, layer_idx
+        )
         if self.tp_size > 1 and module_name not in ROW_PARALLELISM_LINEAR_LORA_NAMES:
             output_dim = divide(output_dim, self.tp_size)
         return (
@@ -140,19 +152,21 @@ class LoRAMemoryPool:
         def init_buffer(
             buffer: Dict[str, List[torch.Tensor]],
             target_modules: Set[str],
-            get_lora_shape_fn: Callable[[str, torch.nn.Module, int], Tuple[int]],
+            get_lora_shape_fn: Callable[[str, torch.nn.Module, int, int], Tuple[int]],
         ):
             for module_name in target_modules:
-                lora_shape = get_lora_shape_fn(
-                    module_name, base_model, self.max_lora_rank
-                )
                 buffer[module_name] = [
                     torch.empty(
-                        lora_shape,
+                        get_lora_shape_fn(
+                            module_name,
+                            base_model,
+                            self.max_lora_rank,
+                            idx,
+                        ),
                         dtype=self.dtype,
                         device=device,
                     )
-                    for _ in range(self.num_layer)
+                    for idx in range(self.num_layer)
                 ]
 
         init_buffer(
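
Note on the layer_idx parameter threaded through get_lora_A_shape / get_lora_B_shape above: LoRA buffer shapes can now differ per layer, which matters for models whose projection widths are not uniform across the layer stack. The class below is a hypothetical model-side hook, only meant to show the (input_dim, output_dim) contract of get_hidden_dim(module_name, layer_idx); it is not an sglang model, and the module names and layer layout are invented.

    class ToyModel:
        """Hypothetical model whose kv width varies by layer."""

        def __init__(self, hidden_size=1024, head_dim=128, kv_heads_per_layer=(8, 8, 2, 2)):
            self.hidden_size = hidden_size
            self.head_dim = head_dim
            self.kv_heads_per_layer = kv_heads_per_layer

        def get_hidden_dim(self, module_name: str, layer_idx: int):
            # Return (input_dim, output_dim) for the LoRA target module at this layer.
            if module_name == "q_proj":
                return self.hidden_size, self.hidden_size
            if module_name in ("k_proj", "v_proj"):
                kv_dim = self.kv_heads_per_layer[layer_idx] * self.head_dim
                return self.hidden_size, kv_dim
            raise NotImplementedError(module_name)

    m = ToyModel()
    print(m.get_hidden_dim("k_proj", 0))  # (1024, 1024)
    print(m.get_hidden_dim("k_proj", 2))  # (1024, 256)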
sglang/srt/lora/utils.py CHANGED
@@ -10,19 +10,19 @@ from sglang.srt.hf_transformers_utils import AutoConfig
 
 @dataclass
 class LoRABatchInfo:
+    # The forward mode is using CUDA Graph.
+    use_cuda_graph: bool
+
     # Batch size
     bs: int
 
-    # Lengths of each sequence in shape (bs,)
-    seg_lens: torch.Tensor
+    # Number of segments. For triton backend, it is equal to batch size.
+    num_segments: int
 
-    # Indice pointers of each sequence in shape (bs + 1, )
+    # Indice pointers of each segment in shape (num_segments + 1, )
     seg_indptr: torch.Tensor
 
-    # Maximum sequence length of current batch
-    max_len: int
-
-    # The index of lora adapter used by each sequence, in shape (bs,)
+    # The index of lora adapter used by each segment, in shape (num_segments,)
     weight_indices: torch.Tensor
 
     # ranks of each lora adapter, in shape (lora_num,)
@@ -31,6 +31,15 @@ class LoRABatchInfo:
     # scaling of each lora adapter, in shape (lora_num,)
     scalings: torch.Tensor
 
+    # Lengths of each segments in shape (num_segments,)
+    seg_lens: Optional[torch.Tensor]
+
+    # Maximum segment length of current batch
+    max_len: Optional[int]
+
+    # The logical (re)ordering of input rows (tokens), in shape (num_tokens,)
+    permutation: Optional[torch.Tensor]
+
 
 class LoRAType(Enum):
     LORA_A = 0
@@ -48,14 +57,14 @@ def get_layer_id(name: str) -> int:
 
 
 def get_hidden_dim(
-    module_name: str, config: AutoConfig, base_model: torch.nn.Module
+    module_name: str, config: AutoConfig, base_model: torch.nn.Module, layer_idx: int
 ) -> Tuple[int]:
     """
     Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
     """
 
     if hasattr(base_model, "get_hidden_dim"):
-        return base_model.get_hidden_dim(module_name)
+        return base_model.get_hidden_dim(module_name, layer_idx)
     else:
         """
         WARNING: get_hidden_dim() is not defined,
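
Note on the reshaped LoRABatchInfo above: bookkeeping is now expressed per segment rather than per sequence, with seg_lens, max_len, and permutation optional. The tensors below only illustrate the documented shapes for a small batch; the values and the identity permutation are made up, and a backend that regroups tokens by adapter would populate permutation with the actual reordering instead.

    import torch

    # 3 requests with 4, 1 and 2 tokens -> 3 segments over 7 tokens
    seg_lens = torch.tensor([4, 1, 2], dtype=torch.int32)
    num_segments = seg_lens.numel()

    seg_indptr = torch.zeros(num_segments + 1, dtype=torch.int32)
    seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)               # [0, 4, 5, 7]
    weight_indices = torch.tensor([0, 2, 1], dtype=torch.int32)  # adapter slot per segment
    max_len = int(seg_lens.max())                                # 4
    permutation = torch.arange(int(seg_lens.sum()), dtype=torch.int32)  # identity ordering

    print(seg_indptr.tolist(), weight_indices.tolist(), max_len, permutation.tolist())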