sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175)
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +119 -17
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +42 -7
  6. sglang/srt/conversation.py +9 -5
  7. sglang/srt/disaggregation/base/conn.py +5 -2
  8. sglang/srt/disaggregation/decode.py +14 -4
  9. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
  10. sglang/srt/disaggregation/mooncake/conn.py +286 -160
  11. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  12. sglang/srt/disaggregation/prefill.py +2 -0
  13. sglang/srt/distributed/parallel_state.py +15 -11
  14. sglang/srt/entrypoints/context.py +227 -0
  15. sglang/srt/entrypoints/engine.py +15 -9
  16. sglang/srt/entrypoints/harmony_utils.py +372 -0
  17. sglang/srt/entrypoints/http_server.py +74 -4
  18. sglang/srt/entrypoints/openai/protocol.py +218 -1
  19. sglang/srt/entrypoints/openai/serving_chat.py +41 -11
  20. sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
  21. sglang/srt/entrypoints/openai/tool_server.py +175 -0
  22. sglang/srt/entrypoints/tool.py +87 -0
  23. sglang/srt/eplb/expert_location.py +5 -1
  24. sglang/srt/function_call/ebnf_composer.py +1 -0
  25. sglang/srt/function_call/function_call_parser.py +2 -0
  26. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  27. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  28. sglang/srt/function_call/kimik2_detector.py +3 -3
  29. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  30. sglang/srt/hf_transformers_utils.py +30 -3
  31. sglang/srt/jinja_template_utils.py +14 -1
  32. sglang/srt/layers/attention/aiter_backend.py +375 -115
  33. sglang/srt/layers/attention/ascend_backend.py +3 -0
  34. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
  35. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  36. sglang/srt/layers/attention/flashinfer_backend.py +52 -13
  37. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  38. sglang/srt/layers/attention/triton_backend.py +85 -14
  39. sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
  40. sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
  41. sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
  42. sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
  43. sglang/srt/layers/attention/vision.py +22 -6
  44. sglang/srt/layers/attention/wave_backend.py +627 -0
  45. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  46. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  47. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  48. sglang/srt/layers/communicator.py +29 -14
  49. sglang/srt/layers/dp_attention.py +12 -0
  50. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  51. sglang/srt/layers/linear.py +3 -7
  52. sglang/srt/layers/moe/cutlass_moe.py +12 -3
  53. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  54. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  55. sglang/srt/layers/moe/ep_moe/layer.py +135 -73
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
  59. sglang/srt/layers/moe/fused_moe_triton/layer.py +412 -33
  60. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
  61. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  62. sglang/srt/layers/moe/topk.py +16 -4
  63. sglang/srt/layers/moe/utils.py +16 -0
  64. sglang/srt/layers/quantization/__init__.py +27 -3
  65. sglang/srt/layers/quantization/fp4.py +557 -0
  66. sglang/srt/layers/quantization/fp8.py +3 -6
  67. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  68. sglang/srt/layers/quantization/fp8_utils.py +51 -10
  69. sglang/srt/layers/quantization/modelopt_quant.py +258 -68
  70. sglang/srt/layers/quantization/mxfp4.py +654 -0
  71. sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
  72. sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
  73. sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
  74. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
  75. sglang/srt/layers/quantization/quark/utils.py +107 -0
  76. sglang/srt/layers/quantization/unquant.py +60 -6
  77. sglang/srt/layers/quantization/w4afp8.py +21 -12
  78. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  79. sglang/srt/layers/rotary_embedding.py +506 -3
  80. sglang/srt/layers/utils.py +9 -0
  81. sglang/srt/layers/vocab_parallel_embedding.py +8 -3
  82. sglang/srt/lora/backend/base_backend.py +3 -23
  83. sglang/srt/lora/layers.py +60 -114
  84. sglang/srt/lora/lora.py +17 -62
  85. sglang/srt/lora/lora_manager.py +82 -62
  86. sglang/srt/lora/lora_registry.py +23 -11
  87. sglang/srt/lora/mem_pool.py +63 -68
  88. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  89. sglang/srt/lora/utils.py +25 -58
  90. sglang/srt/managers/cache_controller.py +75 -58
  91. sglang/srt/managers/detokenizer_manager.py +1 -1
  92. sglang/srt/managers/io_struct.py +20 -8
  93. sglang/srt/managers/mm_utils.py +6 -13
  94. sglang/srt/managers/multimodal_processor.py +1 -1
  95. sglang/srt/managers/schedule_batch.py +61 -25
  96. sglang/srt/managers/schedule_policy.py +6 -6
  97. sglang/srt/managers/scheduler.py +41 -19
  98. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
  99. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  100. sglang/srt/managers/scheduler_recv_skipper.py +37 -0
  101. sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
  102. sglang/srt/managers/template_manager.py +35 -1
  103. sglang/srt/managers/tokenizer_manager.py +47 -30
  104. sglang/srt/managers/tp_worker.py +3 -0
  105. sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
  106. sglang/srt/mem_cache/allocator.py +61 -87
  107. sglang/srt/mem_cache/hicache_storage.py +1 -1
  108. sglang/srt/mem_cache/hiradix_cache.py +80 -22
  109. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  110. sglang/srt/mem_cache/memory_pool_host.py +34 -36
  111. sglang/srt/mem_cache/multimodal_cache.py +33 -13
  112. sglang/srt/mem_cache/radix_cache.py +2 -5
  113. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
  114. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  115. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  116. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  117. sglang/srt/model_executor/cuda_graph_runner.py +29 -9
  118. sglang/srt/model_executor/forward_batch_info.py +61 -19
  119. sglang/srt/model_executor/model_runner.py +148 -37
  120. sglang/srt/model_loader/loader.py +18 -6
  121. sglang/srt/model_loader/weight_utils.py +10 -0
  122. sglang/srt/models/bailing_moe.py +425 -0
  123. sglang/srt/models/deepseek_v2.py +137 -59
  124. sglang/srt/models/ernie4.py +426 -0
  125. sglang/srt/models/ernie4_eagle.py +203 -0
  126. sglang/srt/models/gemma2.py +0 -34
  127. sglang/srt/models/gemma3n_mm.py +38 -0
  128. sglang/srt/models/glm4.py +6 -0
  129. sglang/srt/models/glm4_moe.py +28 -16
  130. sglang/srt/models/glm4v.py +589 -0
  131. sglang/srt/models/glm4v_moe.py +400 -0
  132. sglang/srt/models/gpt_oss.py +1251 -0
  133. sglang/srt/models/granite.py +0 -25
  134. sglang/srt/models/llama.py +0 -25
  135. sglang/srt/models/llama4.py +1 -1
  136. sglang/srt/models/qwen2.py +6 -0
  137. sglang/srt/models/qwen2_5_vl.py +7 -3
  138. sglang/srt/models/qwen2_audio.py +10 -9
  139. sglang/srt/models/qwen2_moe.py +6 -0
  140. sglang/srt/models/qwen3.py +0 -24
  141. sglang/srt/models/qwen3_moe.py +32 -6
  142. sglang/srt/models/registry.py +1 -1
  143. sglang/srt/models/step3_vl.py +9 -0
  144. sglang/srt/models/torch_native_llama.py +0 -24
  145. sglang/srt/models/transformers.py +2 -5
  146. sglang/srt/multimodal/processors/base_processor.py +23 -13
  147. sglang/srt/multimodal/processors/glm4v.py +132 -0
  148. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  149. sglang/srt/multimodal/processors/step3_vl.py +3 -1
  150. sglang/srt/reasoning_parser.py +332 -37
  151. sglang/srt/server_args.py +186 -75
  152. sglang/srt/speculative/eagle_worker.py +16 -0
  153. sglang/srt/two_batch_overlap.py +169 -9
  154. sglang/srt/utils.py +41 -5
  155. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  156. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  157. sglang/test/doc_patch.py +59 -0
  158. sglang/test/few_shot_gsm8k.py +1 -1
  159. sglang/test/few_shot_gsm8k_engine.py +1 -1
  160. sglang/test/run_eval.py +4 -1
  161. sglang/test/runners.py +2 -2
  162. sglang/test/simple_eval_common.py +6 -0
  163. sglang/test/simple_eval_gpqa.py +2 -0
  164. sglang/test/test_fp4_moe.py +118 -36
  165. sglang/test/test_utils.py +1 -1
  166. sglang/utils.py +1 -1
  167. sglang/version.py +1 -1
  168. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +36 -38
  169. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +174 -141
  170. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  171. /sglang/{api.py → lang/api.py} +0 -0
  172. /sglang/{lang/backend → srt/layers/quantization/quark}/__init__.py +0 -0
  173. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
  174. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
  175. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
sglang/srt/lora/layers.py CHANGED
@@ -1,5 +1,3 @@
-from typing import List, Tuple
-
 import torch
 from torch import nn
 
@@ -79,18 +77,13 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
         self.B_buffer = B_buffer
 
     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-        backend_kwargs = {"base_output": base_output}
         lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
         lora_output = self.lora_backend.run_lora_b_sgemm(
-            lora_a_output,
-            self.B_buffer[0],
-            **backend_kwargs,
-        )
-        return (
-            lora_output
-            if self.lora_backend.fuse_output_add
-            else base_output + lora_output
+            x=lora_a_output,
+            weights=self.B_buffer,
+            base_output=base_output,
         )
+        return lora_output
 
     def forward(self, input_: torch.Tensor):
         # duplicate the logic in ColumnParallelLinear
@@ -135,37 +128,16 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
     ):
         self.set_lora = True
         self.A_buffer_gate_up = A_buffer
-        if self.lora_backend.fuse_stacked_lora_b:
-            # B_buffer_gate_up: (num_lora, 2 * output_dim, r)
-            if getattr(self, "B_buffer_gate_up", None) is None:
-                self.B_buffer_gate_up = torch.empty(
-                    (
-                        B_buffer[0].shape[0],
-                        2 * B_buffer[0].shape[1],
-                        B_buffer[0].shape[2],
-                    ),
-                    dtype=B_buffer[0].dtype,
-                    device=B_buffer[0].device,
-                )
-            self.B_buffer_gate_up[:, : B_buffer[0].shape[1], :].copy_(B_buffer[0])
-            self.B_buffer_gate_up[:, B_buffer[0].shape[1] :, :].copy_(B_buffer[1])
-        else:
-            self.B_buffer_gate_up = (B_buffer[0], B_buffer[1])
+        self.B_buffer_gate_up = B_buffer
 
     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-        backend_kwargs = {"base_output": base_output}
-
         lora_output = self.lora_backend.run_gate_up_lora(
-            x,
-            self.A_buffer_gate_up,
-            self.B_buffer_gate_up,
-            **backend_kwargs,
-        )
-        return (
-            lora_output
-            if self.lora_backend.fuse_output_add
-            else base_output + lora_output
+            x=x,
+            gate_up_lora_a=self.A_buffer_gate_up,
+            gate_up_lora_b=self.B_buffer_gate_up,
+            base_output=base_output,
         )
+        return lora_output
 
     def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
         return A
@@ -173,9 +145,16 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
     def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int):
         # Since the outputs for both gate and up are identical, we use a random one.
         shard_size = self.base_layer.output_partition_sizes[0]
+        gate_size = self.base_layer.output_sizes[0]
         start_idx = tp_rank * shard_size
         end_idx = (tp_rank + 1) * shard_size
-        return B[:, start_idx:end_idx, :]
+        return torch.concat(
+            (
+                B[start_idx:end_idx, :],
+                B[gate_size + start_idx : gate_size + end_idx],
+            ),
+            dim=0,
+        )
 
 
 class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
@@ -185,86 +164,46 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         lora_backend: BaseLoRABackend,
     ) -> None:
         super().__init__(base_layer, lora_backend)
+        q_proj_shard_size = self.base_layer.q_proj_shard_size
+        kv_proj_shard_size = self.base_layer.kv_proj_shard_size
+        self.output_offset = torch.tensor(
+            [
+                0,
+                q_proj_shard_size,
+                q_proj_shard_size + kv_proj_shard_size,
+                q_proj_shard_size + 2 * kv_proj_shard_size,
+            ],
+            dtype=torch.int32,
+            device=next(self.base_layer.parameters()).device,
+        )
+
+        # For computing number of launched blocks
+        self.max_qkv_out_dim = max(q_proj_shard_size, kv_proj_shard_size)
 
     def set_lora_info(
         self,
         A_buffer_qkv: torch.Tensor,
-        B_buffer_q: torch.Tensor,
-        B_buffer_kv: torch.Tensor,
+        B_buffer_qkv: torch.Tensor,
     ):
         self.set_lora = True
         self.A_buffer_qkv = A_buffer_qkv
-
-        if self.lora_backend.fuse_stacked_lora_b:
-            assert (
-                B_buffer_q.shape[-1] == B_buffer_kv.shape[-1]
-            ), "The lora rank of q and kv should be the same when enabling fusion of qkv lora_b"
-            output_dim_q, output_dim_kv = B_buffer_q.shape[-2], B_buffer_kv.shape[-2]
-
-            # B_buffer_qkv: (num_lora, output_dim_q + 2 * output_dim_kv, r)
-            if getattr(self, "B_buffer_qkv", None) is None:
-                self.B_buffer_qkv = torch.empty(
-                    (
-                        B_buffer_q[0].shape[0],
-                        output_dim_q + 2 * output_dim_kv,
-                        B_buffer_q[0].shape[2],
-                    ),
-                    dtype=B_buffer_q[0].dtype,
-                    device=B_buffer_q[0].device,
-                )
-            self.B_buffer_qkv[:, :output_dim_q, :].copy_(B_buffer_q[0])
-            self.B_buffer_qkv[:, output_dim_q : output_dim_q + output_dim_kv, :].copy_(
-                B_buffer_kv[0]
-            )
-            self.B_buffer_qkv[:, output_dim_q + output_dim_kv :, :].copy_(
-                B_buffer_kv[1]
-            )
-
-            # Offsets of q/k/v in output dimension
-            if getattr(self, "output_offset", None) is None:
-                self.output_offset = torch.tensor(
-                    [
-                        0,
-                        output_dim_q,
-                        output_dim_q + output_dim_kv,
-                        output_dim_q + 2 * output_dim_kv,
-                    ],
-                    dtype=torch.int32,
-                    device=B_buffer_q.device,
-                )
-            # For computing number of launched blocks
-            self.max_qkv_out_dim = max(output_dim_q, output_dim_kv)
-        else:
-            self.B_buffer_qkv = (
-                B_buffer_q,
-                B_buffer_kv,
-            )
+        self.B_buffer_qkv = B_buffer_qkv
 
     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-        backend_kwargs = {"base_output": base_output}
-        if self.lora_backend.fuse_stacked_lora_b:
-            backend_kwargs["output_offset"] = self.output_offset
-            backend_kwargs["max_qkv_out_dim"] = self.max_qkv_out_dim
-
         lora_output = self.lora_backend.run_qkv_lora(
-            x,
-            self.A_buffer_qkv,
-            self.B_buffer_qkv,
-            **backend_kwargs,
-        )
-        return (
-            lora_output
-            if self.lora_backend.fuse_output_add
-            else base_output + lora_output
+            x=x,
+            qkv_lora_a=self.A_buffer_qkv,
+            qkv_lora_b=self.B_buffer_qkv,
+            base_output=base_output,
+            output_offset=self.output_offset,
+            max_qkv_out_dim=self.max_qkv_out_dim,
         )
+        return lora_output
 
     def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
         return A
 
-    def slice_lora_b_weights(
-        self, B: List[torch.Tensor], tp_rank: int
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        B_q, B_kv = B
+    def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int) -> torch.Tensor:
         base_layer = self.base_layer
         q_proj_shard_size = base_layer.q_proj_shard_size
         kv_proj_shard_size = base_layer.kv_proj_shard_size
@@ -277,7 +216,19 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         kv_start_idx = kv_proj_shard_size * kv_shard_id
         kv_end_idx = kv_start_idx + kv_proj_shard_size
 
-        return B_q[q_start_idx:q_end_idx, :], B_kv[:, kv_start_idx:kv_end_idx, :]
+        q_size, k_size, _ = base_layer.output_sizes
+        B_q_shard = B[q_start_idx:q_end_idx, :]
+        B_k_shard = B[q_size + kv_start_idx : q_size + kv_end_idx, :]
+        B_v_shard = B[q_size + k_size + kv_start_idx : q_size + k_size + kv_end_idx, :]
+
+        return torch.concat(
+            (
+                B_q_shard,
+                B_k_shard,
+                B_v_shard,
+            ),
+            dim=0,
+        )
 
 
 class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
@@ -294,18 +245,13 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
         self.B_buffer = B_buffer
 
     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-        backend_kwargs = {"base_output": base_output}
         lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
         lora_output = self.lora_backend.run_lora_b_sgemm(
-            lora_a_output,
-            self.B_buffer[0],
-            **backend_kwargs,
-        )
-        return (
-            lora_output
-            if self.lora_backend.fuse_output_add
-            else base_output + lora_output
+            x=lora_a_output,
+            weights=self.B_buffer,
+            base_output=base_output,
         )
+        return lora_output
 
     def forward(self, input_: torch.Tensor):
         # duplicate the logic in RowParallelLinear
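
Note on the layers.py change above: LoRA B weights are now passed to the backend kernels as a single stacked tensor per module (with keyword arguments) and sliced per tensor-parallel rank along dim 0, instead of being kept as tuples of per-projection tensors. The snippet below is a minimal standalone sketch of that slicing for the gate/up case; the helper name and the toy shapes are invented for illustration and are not part of the wheel.

    import torch

    def slice_stacked_gate_up_lora_b(
        B: torch.Tensor, gate_size: int, shard_size: int, tp_rank: int
    ) -> torch.Tensor:
        # B is the stacked LoRA-B weight [gate; up] with shape (gate_size + up_size, r).
        # Each TP rank keeps its slice of the gate rows and its slice of the up rows,
        # then re-stacks them along dim 0, mirroring the behavior of
        # MergedColumnParallelLinearWithLoRA.slice_lora_b_weights in the diff above.
        start, end = tp_rank * shard_size, (tp_rank + 1) * shard_size
        return torch.concat(
            (B[start:end, :], B[gate_size + start : gate_size + end, :]), dim=0
        )

    # Toy shapes: gate_size = up_size = 8, rank r = 4, tensor-parallel size 2 (shard_size = 4).
    B = torch.arange(16 * 4, dtype=torch.float32).reshape(16, 4)
    shard = slice_stacked_gate_up_lora_b(B, gate_size=8, shard_size=4, tp_rank=1)
    print(shard.shape)  # torch.Size([8, 4]): 4 gate rows + 4 up rows for rank 1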
sglang/srt/lora/lora.py CHANGED
@@ -117,7 +117,6 @@ class LoRAAdapter(nn.Module):
                 q_name = weight_name
                 k_name = weight_name.replace("q_proj", "k_proj")
                 v_name = weight_name.replace("q_proj", "v_proj")
-                kv_name = weight_name.replace("q_proj", "kv_proj")
                 qkv_name = weight_name.replace("q_proj", "qkv_proj")
 
                 # If k_proj doesn't have lora, initialize it to zero
@@ -126,57 +125,27 @@
                     if "k_proj" in target_module
                     else torch.zeros_like(weights[v_name])
                 )
-                if "lora_A" in weight_name:
-                    weights[qkv_name] = torch.cat(
-                        (
-                            weights[q_name],
-                            k_proj_weight,
-                            weights[v_name],
-                        ),
-                        0,
-                    )
-                    weights.pop(q_name)
-                    if "k_proj" in target_module:
-                        weights.pop(k_name)
-                    weights.pop(v_name)
-                else:
-                    weights[kv_name] = torch.stack(
-                        [
-                            k_proj_weight,
-                            weights[v_name],
-                        ],
-                        dim=0,
-                    )
-                    if "k_proj" in target_module:
-                        weights.pop(k_name)
-                    weights.pop(v_name)
+                weights[qkv_name] = torch.cat(
+                    (
+                        weights[q_name],
+                        k_proj_weight,
+                        weights[v_name],
+                    ),
+                    0,
+                )
+                weights.pop(q_name)
+                if "k_proj" in target_module:
+                    weights.pop(k_name)
+                weights.pop(v_name)
             elif "qkv_proj" in weight_name:
                 # If qkv_proj is already stacked, we normalize it following the SGL convention.
                 qkv_name = weight_name
                 q_name = weight_name.replace("qkv_proj", "q_proj")
                 k_name = weight_name.replace("qkv_proj", "k_proj")
                 v_name = weight_name.replace("qkv_proj", "v_proj")
-                kv_name = weight_name.replace("qkv_proj", "kv_proj")
                 if "lora_A" in weight_name:
                     weights[qkv_name] = weights[qkv_name].repeat(3, 1)
-                else:
-                    head_size = (
-                        self.base_hf_config.hidden_size
-                        // self.base_hf_config.num_attention_heads
-                    )
-                    weights[q_name], k_proj_weight, v_proj_weight = torch.split(
-                        weights[qkv_name],
-                        [
-                            head_size * self.base_hf_config.num_attention_heads,
-                            head_size * self.base_hf_config.num_key_value_heads,
-                            head_size * self.base_hf_config.num_key_value_heads,
-                        ],
-                        dim=0,
-                    )
-                    weights[kv_name] = torch.stack(
-                        [k_proj_weight, v_proj_weight],
-                        dim=0,
-                    )
+                # else: no-op as LoRA B weight is already stacked.
 
     def normalize_gate_up_proj(
         self, weight_names: List[str], weights: Dict[str, torch.Tensor]
@@ -187,20 +156,14 @@ class LoRAAdapter(nn.Module):
                 gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
                 if up_name not in weights:
                     weights[up_name] = torch.zeros_like(weights[weight_name])
-                # FIXME: Add gate-only support for flashinfer in future implementations
                 assert self.lora_backend.name == "triton", (
                     f"LoRA weight initialization currently only supported for 'triton' backend. "
                     f"Received backend: {self.lora_backend.name}. Please verify your backend configuration "
                     f"or consider implementing custom initialization logic for other backends."
                 )
-                if "lora_A" in weight_name:
-                    weights[gate_up_name] = torch.cat(
-                        (weights[weight_name], weights[up_name]), 0
-                    )
-                else:
-                    weights[gate_up_name] = torch.stack(
-                        [weights[weight_name], weights[up_name]], dim=0
-                    )
+                weights[gate_up_name] = torch.cat(
+                    (weights[weight_name], weights[up_name]), 0
+                )
                 weights.pop(weight_name)
                 if up_name in weights:
                     weights.pop(up_name)
@@ -209,12 +172,4 @@
                 gate_up_name = weight_name
                 if "lora_A" in weight_name:
                     weights[gate_up_name] = weights[gate_up_name].repeat(2, 1)
-                else:
-                    output_dim = weights[gate_up_name].shape[0] // 2
-                    weights[gate_up_name] = torch.stack(
-                        [
-                            weights[gate_up_name][:output_dim, :],
-                            weights[gate_up_name][output_dim:, :],
-                        ],
-                        dim=0,
-                    )
+                # else: no-op as LoRA B weight is already stacked.
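
For context on the lora.py change above: adapter weight normalization now concatenates the q/k/v (and gate/up) LoRA B matrices along the output dimension, and the old torch.stack-based "kv_proj" layout is removed. The following is a small illustrative sketch of the resulting layout; the tensor shapes are made up for the example and are not taken from the package.

    import torch

    # Illustrative shapes only: LoRA rank r = 4, q output dim 16, k/v output dim 8 each.
    r = 4
    lora_b_q = torch.randn(16, r)
    lora_b_k = torch.randn(8, r)
    lora_b_v = torch.randn(8, r)

    # Both LoRA A and LoRA B weights for q/k/v now live in a single "qkv_proj"
    # tensor stacked along dim 0; the previous code kept LoRA B as a separate
    # (2, out_dim, r) "kv_proj" tensor built with torch.stack.
    lora_b_qkv = torch.cat((lora_b_q, lora_b_k, lora_b_v), 0)
    print(lora_b_qkv.shape)  # torch.Size([32, 4])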
sglang/srt/lora/lora_manager.py CHANGED
@@ -31,7 +31,6 @@ from sglang.srt.lora.mem_pool import LoRAMemoryPool
 from sglang.srt.lora.utils import (
     LoRABatchInfo,
     LoRAType,
-    get_customized_names_from_hf_names,
     get_layer_id,
     get_normalized_lora_weight_names,
     get_weight_name,
@@ -144,6 +143,7 @@ class LoRAManager:
 
             # keep metadata for displayed messages
             self.lora_refs[lora_ref.lora_id] = lora_ref
+            self.num_pinned_loras += int(lora_ref.pinned)
         except Exception as e:
             return self.create_lora_update_result(
                 success=False,
@@ -157,13 +157,22 @@
         Validate if an adapter can be loaded into the current LoRA memory pool and generate error if it is incompatible.
         """
 
+        # Check if the LoRA adapter shape is compatible with the current LoRA memory pool configuration.
         memory_pool = getattr(self, "memory_pool", None)
         incompatible = memory_pool and not memory_pool.can_support(lora_config)
         if incompatible:
             raise ValueError(
-                f"LoRA adapter {lora_ref.lora_name} with rank {lora_config.r} is incompatible with the current LoRA memory pool configuration. "
-                "Please ensure that the LoRA adapter's rank is within the configured `--max_lora_rank` and that the target modules are "
-                "included in `--enable_lora_modules`."
+                f"LoRA adapter {lora_ref.lora_name} with rank {lora_config.r} is incompatible with the current "
+                "LoRA memory pool configuration. Please ensure that the LoRA adapter's rank is within the configured "
+                "`--max-lora-rank` and that the target modules are included in `--lora-target-modules`."
+            )
+
+        # Ensure pinned LoRA adapters does not exceed maximal limit or cause starvation.
+        if lora_ref.pinned and self.num_pinned_loras >= self.max_loras_per_batch - 1:
+            raise ValueError(
+                f"Failed to load LoRA adapter {lora_ref.lora_name} as a pinned adapter. It is not allowed to pin all slots "
+                "in the LoRA memory pool to avoid starvation for unpinned adapters and base models. Please increase your "
+                "`--max-loras-per-batch` or load it as unpinned LoRA adapters."
             )
 
     def unload_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateResult:
@@ -172,15 +181,17 @@
         delete the corresponding LoRA modules.
         """
 
-        adapter = self.configs.get(lora_ref.lora_id, None)
+        adapter = self.configs.get(lora_ref.lora_id)
+        lora_ref = self.lora_refs.get(lora_ref.lora_id)
         assert (
-            adapter is not None
+            adapter is not None and lora_ref is not None
         ), f"LoRA adapter with ID {lora_ref.lora_id} is not loaded. This should have been verified before request is sent to the backend."
 
         try:
             del self.configs[lora_ref.lora_id]
             del self.loras[lora_ref.lora_id]
             del self.lora_refs[lora_ref.lora_id]
+            self.num_pinned_loras -= int(lora_ref.pinned)
         except Exception as e:
             return self.create_lora_update_result(
                 success=False,
@@ -189,15 +200,49 @@
 
         return self.create_lora_update_result(success=True)
 
+    def validate_lora_batch(self, lora_ids: set[str]) -> bool:
+        """
+        Validate if the LoRA IDs in the batch can be loaded into the current LoRA memory pool.
+        """
+        if len(lora_ids) > self.max_loras_per_batch:
+            return False
+
+        # skip pinned LoRA check if no pinned LoRA adapters are loaded.
+        if self.num_pinned_loras == 0:
+            return True
+
+        # counting the number of pinned LoRA adapters in the batch.
+        pinned_loras_in_batch = 0
+        for lora_id in lora_ids:
+            if lora_id is not None:
+                lora_ref = self.lora_refs.get(lora_id)
+                assert (
+                    lora_ref is not None
+                ), f"LoRA ID {lora_id} not found in lora_refs."
+                pinned_loras_in_batch += int(lora_ref.pinned)
+
+        assert pinned_loras_in_batch <= self.num_pinned_loras, (
+            f"Number of pinned LoRA adapters in the batch ({pinned_loras_in_batch}) exceeds the total number of pinned adapters "
+            f"({self.num_pinned_loras}). This indicates a bug in the LoRA loading logic."
+        )
+
+        required_slots = len(lora_ids) - pinned_loras_in_batch
+        mem_pool_vacancy = self.memory_pool.max_loras_per_batch - self.num_pinned_loras
+
+        return required_slots <= mem_pool_vacancy
+
     def prepare_lora_batch(self, forward_batch: ForwardBatch):
+
         # Load active loras into lora memory pool
-        # TODO (lifuhuang): The naming of `forward_batch.lora_paths` is confusing. It actually contains a set of unique
-        # LoRA IDs, not LoRA paths. While unfortunately we cannot change the name in API for backward compatibility, we
-        # should consider (1) renaming the incorrect usage within the system, and (2) deprecating the parameter name in
-        # the current API schema and introducing a better request schema in the future (e.g., use `model_name`).
-        cur_uids = set(forward_batch.lora_paths)
+        cur_uids = set(forward_batch.lora_ids)
+
         assert len(cur_uids) <= self.max_loras_per_batch
-        self.memory_pool.prepare_lora_batch(cur_uids, self.loras, self.lora_modules)
+        self.memory_pool.prepare_lora_batch(
+            cur_uids=cur_uids,
+            lora_adapters=self.loras,
+            lora_modules=self.lora_modules,
+            lora_refs=self.lora_refs.copy(),  # copy snapshot of current lora_refs to avoid mutation during the batch preparation.
+        )
 
         # set up batch info shared by all lora modules
         bs = forward_batch.batch_size
@@ -211,10 +256,10 @@
         Transfer adapter metadata (weight indices, LoRA rank, scalings) from host
         to device (CUDA) asynchronously.
         """
-        weight_indices = [0] * len(forward_batch.lora_paths)
+        weight_indices = [0] * len(forward_batch.lora_ids)
         lora_ranks = [0] * self.max_loras_per_batch
         scalings = [0] * self.max_loras_per_batch
-        for i, uid in enumerate(forward_batch.lora_paths):
+        for i, uid in enumerate(forward_batch.lora_ids):
             weight_indices[i] = self.memory_pool.get_buffer_id(uid)
             if uid is not None:
                 lora = self.loras[uid]
@@ -299,40 +344,19 @@
         )
         self.lora_backend.set_batch_info(batch_info)
 
-        # TODO (lifuhuang): one potential perf optimization that is worth considering is to see if we can call
-        # this method only when loading/unloading LoRA adapters, instead of calling it for every micro-batch.
-        self.update_lora_info()
-
     def update_lora_info(self):
         """
         Update all LoRA modules to associate them with the latest memory buffer.
         """
         for layer_id, layer_modules in enumerate(self.lora_modules):
             for module_name, module in layer_modules.items():
-                if "qkv_proj" in module_name:
-                    module.set_lora_info(
-                        self.memory_pool.get_tensor(
-                            "qkv_proj", layer_id, LoRAType.LORA_A
-                        ),
-                        self.memory_pool.get_tensor(
-                            "q_proj", layer_id, LoRAType.LORA_B
-                        ),
-                        self.memory_pool.get_tensor(
-                            "kv_proj", layer_id, LoRAType.LORA_B
-                        ),
-                    )
-                else:
-                    weight_name = get_weight_name(
-                        module_name, self.memory_pool.lora_weight_names, LoRAType.LORA_A
-                    )
-                    module.set_lora_info(
-                        self.memory_pool.get_tensor(
-                            weight_name, layer_id, LoRAType.LORA_A
-                        ),
-                        self.memory_pool.get_tensor(
-                            weight_name, layer_id, LoRAType.LORA_B
-                        ),
-                    )
+                weight_name = get_weight_name(
+                    module_name, self.memory_pool.lora_weight_names
+                )
+                module.set_lora_info(
+                    self.memory_pool.get_tensor(weight_name, layer_id, LoRAType.LORA_A),
+                    self.memory_pool.get_tensor(weight_name, layer_id, LoRAType.LORA_B),
+                )
 
     def init_state(
         self,
@@ -359,6 +383,7 @@
         self.init_lora_weight_names()
         self.init_lora_modules()
         self.init_memory_pool()
+        self.update_lora_info()
 
     def init_lora_adapters(self, lora_paths: Optional[Dict[str, LoRARef]] = None):
         # Configs of all active LoRA adapters, indexed by LoRA ID.
@@ -370,6 +395,9 @@
         # Mapping from LoRA ID to LoRARef object.
         self.lora_refs: Dict[str, LoRARef] = {}
 
+        # Count of pinned LoRA adapters.
+        self.num_pinned_loras: int = 0
+
         if lora_paths:
             for lora_ref in lora_paths.values():
                 result = self.load_lora_adapter(lora_ref)
@@ -390,13 +418,20 @@
         else:
             self.target_modules = set()
             for config in self.configs.values():
+                if not isinstance(config.target_modules, list):
+                    raise ValueError(
+                        f"SGLang currently only supports inferring LoRA target modules when a list of "
+                        "suffixes is provided in `target_modules` field of PEFT config. Please explicitly "
+                        "specify `--lora-target-modules` during server startup. You can specify `all` to "
+                        "enable all support modules types. "
+                    )
                 self.target_modules.update(config.target_modules)
 
         if max_lora_rank is not None:
             self.max_lora_rank = max_lora_rank
         else:
             self.max_lora_rank = max(
-                [x.hf_config["r"] for x in self.configs.values()],
+                [x.r for x in self.configs.values()],
                 default=0,
            )
@@ -405,9 +440,9 @@
        Add new LoRA weight names if needed based on the current `self.configs`.
        """
 
-        # Target lora weight names for lora_a and lora_b modules respectively.
-        lora_A, lora_B = get_normalized_lora_weight_names(self.target_modules)
-        self.lora_weight_names: Tuple[Set[str]] = (set(lora_A), set(lora_B))
+        self.lora_weight_names: Set[str] = get_normalized_lora_weight_names(
+            self.target_modules
+        )
 
     def load_lora_weights(self, lora_ref: LoRARef):
        """
@@ -423,15 +458,6 @@
         lora_adapter.initialize_weights()
         self.loras[lora_ref.lora_id] = lora_adapter
 
-        # Additional checks for flashinfer backend
-        # FIXME remove the restrictions after supporting multi-rank for flashinfer backend
-        if self.lora_backend == "flashinfer":
-            lora_dims = set(x.r for x in self.configs.values())
-            scalings = set(x.scaling for x in self.loras.values())
-            assert (
-                len(lora_dims) == 1 and len(scalings) == 1
-            ), "Flashinfer backend currently only supports single LoRA rank and scaling across all adapters. "
-
     def init_memory_pool(self):
         """(Re)initialize the LoRA memory pool based on the current configurations."""
         self.memory_pool = LoRAMemoryPool(
@@ -456,12 +482,6 @@
             {} for _ in range(self.base_hf_config.num_hidden_layers)
         ]
 
-        # Target module names of customized layers defined in python/sglang/srt/layers
-        # e.g., {"qkv_proj", "o_proj"}
-        customized_target_names = get_customized_names_from_hf_names(
-            self.target_modules, self.base_model
-        )
-
         for module_name, module in self.base_model.named_modules():
             # TODO (lifuhuang): in the future, we should consider generalizing the
             # should_apply_lora function to support mapping by full module name instead
@@ -474,7 +494,7 @@
                 continue
 
             # The module should be converted if it is included in target_names
-            if module_name.split(".")[-1] in customized_target_names:
+            if module_name.split(".")[-1] in self.lora_weight_names:
                 layer_id = get_layer_id(module_name)
                 self.lora_modules[layer_id][module_name] = self.set_lora_module(
                     module_name, module
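
For context on the pinned-adapter accounting introduced in lora_manager.py above: validate_lora_batch only admits a batch when its unpinned adapters fit in the memory-pool slots left over after all pinned adapters are counted. The sketch below (hypothetical helper has_capacity, not part of the wheel) restates that check with plain integers.

    def has_capacity(
        num_loras_in_batch: int,
        pinned_in_batch: int,
        max_loras_per_batch: int,
        num_pinned_loras: int,
    ) -> bool:
        # Pinned adapters permanently occupy memory-pool slots, so only the
        # unpinned adapters in the batch compete for the remaining vacancy.
        required_slots = num_loras_in_batch - pinned_in_batch
        vacancy = max_loras_per_batch - num_pinned_loras
        return num_loras_in_batch <= max_loras_per_batch and required_slots <= vacancy

    # Pool of 4 slots with 2 pinned adapters already loaded:
    print(has_capacity(4, 2, 4, 2))  # True: 2 unpinned adapters fit in the 2 free slots
    print(has_capacity(4, 1, 4, 2))  # False: 3 unpinned adapters need 3 slots, only 2 free
    print(has_capacity(5, 2, 4, 2))  # False: more adapters than max_loras_per_batch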