sglang 0.5.0rc0__py3-none-any.whl → 0.5.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (170)
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +6 -1
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +8 -7
  6. sglang/srt/disaggregation/decode.py +8 -4
  7. sglang/srt/disaggregation/mooncake/conn.py +43 -25
  8. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  9. sglang/srt/distributed/parallel_state.py +4 -2
  10. sglang/srt/entrypoints/context.py +3 -20
  11. sglang/srt/entrypoints/engine.py +13 -8
  12. sglang/srt/entrypoints/harmony_utils.py +2 -0
  13. sglang/srt/entrypoints/http_server.py +68 -5
  14. sglang/srt/entrypoints/openai/protocol.py +2 -9
  15. sglang/srt/entrypoints/openai/serving_chat.py +60 -265
  16. sglang/srt/entrypoints/openai/serving_completions.py +1 -0
  17. sglang/srt/entrypoints/openai/tool_server.py +4 -3
  18. sglang/srt/function_call/ebnf_composer.py +1 -0
  19. sglang/srt/function_call/function_call_parser.py +2 -0
  20. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  21. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  22. sglang/srt/function_call/kimik2_detector.py +3 -3
  23. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  24. sglang/srt/jinja_template_utils.py +6 -0
  25. sglang/srt/layers/attention/aiter_backend.py +370 -107
  26. sglang/srt/layers/attention/ascend_backend.py +3 -0
  27. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  28. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  29. sglang/srt/layers/attention/flashinfer_backend.py +55 -13
  30. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -0
  31. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  32. sglang/srt/layers/attention/triton_backend.py +24 -27
  33. sglang/srt/layers/attention/trtllm_mha_backend.py +8 -6
  34. sglang/srt/layers/attention/trtllm_mla_backend.py +129 -25
  35. sglang/srt/layers/attention/vision.py +9 -1
  36. sglang/srt/layers/attention/wave_backend.py +627 -0
  37. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  38. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  39. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  40. sglang/srt/layers/communicator.py +11 -13
  41. sglang/srt/layers/dp_attention.py +118 -27
  42. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  43. sglang/srt/layers/linear.py +1 -0
  44. sglang/srt/layers/logits_processor.py +12 -18
  45. sglang/srt/layers/moe/cutlass_moe.py +11 -16
  46. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  47. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  48. sglang/srt/layers/moe/ep_moe/layer.py +60 -2
  49. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -9
  63. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  64. sglang/srt/layers/moe/topk.py +4 -1
  65. sglang/srt/layers/multimodal.py +156 -40
  66. sglang/srt/layers/quantization/__init__.py +10 -35
  67. sglang/srt/layers/quantization/awq.py +15 -16
  68. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +0 -1
  69. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  70. sglang/srt/layers/quantization/fp8_utils.py +22 -10
  71. sglang/srt/layers/quantization/gptq.py +12 -17
  72. sglang/srt/layers/quantization/marlin_utils.py +15 -5
  73. sglang/srt/layers/quantization/modelopt_quant.py +58 -41
  74. sglang/srt/layers/quantization/mxfp4.py +20 -3
  75. sglang/srt/layers/quantization/utils.py +52 -2
  76. sglang/srt/layers/quantization/w4afp8.py +20 -11
  77. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  78. sglang/srt/layers/rotary_embedding.py +281 -2
  79. sglang/srt/layers/sampler.py +5 -2
  80. sglang/srt/lora/backend/base_backend.py +3 -23
  81. sglang/srt/lora/layers.py +66 -116
  82. sglang/srt/lora/lora.py +17 -62
  83. sglang/srt/lora/lora_manager.py +12 -48
  84. sglang/srt/lora/lora_registry.py +20 -9
  85. sglang/srt/lora/mem_pool.py +20 -63
  86. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  87. sglang/srt/lora/utils.py +25 -58
  88. sglang/srt/managers/cache_controller.py +24 -29
  89. sglang/srt/managers/detokenizer_manager.py +1 -1
  90. sglang/srt/managers/io_struct.py +20 -6
  91. sglang/srt/managers/mm_utils.py +1 -2
  92. sglang/srt/managers/multimodal_processor.py +1 -1
  93. sglang/srt/managers/schedule_batch.py +43 -49
  94. sglang/srt/managers/schedule_policy.py +6 -6
  95. sglang/srt/managers/scheduler.py +18 -11
  96. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  97. sglang/srt/managers/tokenizer_manager.py +53 -44
  98. sglang/srt/mem_cache/allocator.py +39 -214
  99. sglang/srt/mem_cache/allocator_ascend.py +158 -0
  100. sglang/srt/mem_cache/chunk_cache.py +1 -1
  101. sglang/srt/mem_cache/hicache_storage.py +1 -1
  102. sglang/srt/mem_cache/hiradix_cache.py +34 -24
  103. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  104. sglang/srt/mem_cache/memory_pool_host.py +33 -35
  105. sglang/srt/mem_cache/radix_cache.py +2 -5
  106. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  107. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  108. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  109. sglang/srt/model_executor/cuda_graph_runner.py +29 -23
  110. sglang/srt/model_executor/forward_batch_info.py +33 -14
  111. sglang/srt/model_executor/model_runner.py +179 -81
  112. sglang/srt/model_loader/loader.py +18 -6
  113. sglang/srt/models/deepseek_nextn.py +2 -1
  114. sglang/srt/models/deepseek_v2.py +79 -38
  115. sglang/srt/models/gemma2.py +0 -34
  116. sglang/srt/models/gemma3n_mm.py +8 -9
  117. sglang/srt/models/glm4.py +6 -0
  118. sglang/srt/models/glm4_moe.py +11 -11
  119. sglang/srt/models/glm4_moe_nextn.py +2 -1
  120. sglang/srt/models/glm4v.py +589 -0
  121. sglang/srt/models/glm4v_moe.py +400 -0
  122. sglang/srt/models/gpt_oss.py +142 -20
  123. sglang/srt/models/granite.py +0 -25
  124. sglang/srt/models/llama.py +10 -27
  125. sglang/srt/models/llama4.py +19 -6
  126. sglang/srt/models/qwen2.py +2 -2
  127. sglang/srt/models/qwen2_5_vl.py +7 -3
  128. sglang/srt/models/qwen2_audio.py +10 -9
  129. sglang/srt/models/qwen2_moe.py +20 -5
  130. sglang/srt/models/qwen3.py +0 -24
  131. sglang/srt/models/qwen3_classification.py +78 -0
  132. sglang/srt/models/qwen3_moe.py +18 -5
  133. sglang/srt/models/registry.py +1 -1
  134. sglang/srt/models/step3_vl.py +6 -2
  135. sglang/srt/models/torch_native_llama.py +0 -24
  136. sglang/srt/multimodal/processors/base_processor.py +23 -13
  137. sglang/srt/multimodal/processors/glm4v.py +132 -0
  138. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  139. sglang/srt/operations.py +17 -2
  140. sglang/srt/reasoning_parser.py +316 -0
  141. sglang/srt/sampling/sampling_batch_info.py +7 -4
  142. sglang/srt/server_args.py +142 -140
  143. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -21
  144. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
  145. sglang/srt/speculative/eagle_worker.py +16 -0
  146. sglang/srt/two_batch_overlap.py +16 -12
  147. sglang/srt/utils.py +3 -3
  148. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  149. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  150. sglang/test/doc_patch.py +59 -0
  151. sglang/test/few_shot_gsm8k.py +1 -1
  152. sglang/test/few_shot_gsm8k_engine.py +1 -1
  153. sglang/test/run_eval.py +4 -1
  154. sglang/test/simple_eval_common.py +6 -0
  155. sglang/test/simple_eval_gpqa.py +2 -0
  156. sglang/test/test_fp4_moe.py +118 -36
  157. sglang/test/test_marlin_moe.py +1 -1
  158. sglang/test/test_marlin_utils.py +1 -1
  159. sglang/utils.py +1 -1
  160. sglang/version.py +1 -1
  161. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/METADATA +27 -31
  162. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/RECORD +166 -142
  163. sglang/lang/backend/__init__.py +0 -0
  164. sglang/srt/function_call/harmony_tool_parser.py +0 -130
  165. sglang/srt/layers/quantization/scalar_type.py +0 -352
  166. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  167. /sglang/{api.py → lang/api.py} +0 -0
  168. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/WHEEL +0 -0
  169. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/licenses/LICENSE +0 -0
  170. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/top_level.txt +0 -0
sglang/srt/lora/lora.py CHANGED
@@ -117,7 +117,6 @@ class LoRAAdapter(nn.Module):
                 q_name = weight_name
                 k_name = weight_name.replace("q_proj", "k_proj")
                 v_name = weight_name.replace("q_proj", "v_proj")
-                kv_name = weight_name.replace("q_proj", "kv_proj")
                 qkv_name = weight_name.replace("q_proj", "qkv_proj")
 
                 # If k_proj doesn't have lora, initialize it to zero
@@ -126,57 +125,27 @@ class LoRAAdapter(nn.Module):
                     if "k_proj" in target_module
                     else torch.zeros_like(weights[v_name])
                 )
-                if "lora_A" in weight_name:
-                    weights[qkv_name] = torch.cat(
-                        (
-                            weights[q_name],
-                            k_proj_weight,
-                            weights[v_name],
-                        ),
-                        0,
-                    )
-                    weights.pop(q_name)
-                    if "k_proj" in target_module:
-                        weights.pop(k_name)
-                    weights.pop(v_name)
-                else:
-                    weights[kv_name] = torch.stack(
-                        [
-                            k_proj_weight,
-                            weights[v_name],
-                        ],
-                        dim=0,
-                    )
-                    if "k_proj" in target_module:
-                        weights.pop(k_name)
-                    weights.pop(v_name)
+                weights[qkv_name] = torch.cat(
+                    (
+                        weights[q_name],
+                        k_proj_weight,
+                        weights[v_name],
+                    ),
+                    0,
+                )
+                weights.pop(q_name)
+                if "k_proj" in target_module:
+                    weights.pop(k_name)
+                weights.pop(v_name)
             elif "qkv_proj" in weight_name:
                 # If qkv_proj is already stacked, we normalize it following the SGL convention.
                 qkv_name = weight_name
                 q_name = weight_name.replace("qkv_proj", "q_proj")
                 k_name = weight_name.replace("qkv_proj", "k_proj")
                 v_name = weight_name.replace("qkv_proj", "v_proj")
-                kv_name = weight_name.replace("qkv_proj", "kv_proj")
                 if "lora_A" in weight_name:
                     weights[qkv_name] = weights[qkv_name].repeat(3, 1)
-                else:
-                    head_size = (
-                        self.base_hf_config.hidden_size
-                        // self.base_hf_config.num_attention_heads
-                    )
-                    weights[q_name], k_proj_weight, v_proj_weight = torch.split(
-                        weights[qkv_name],
-                        [
-                            head_size * self.base_hf_config.num_attention_heads,
-                            head_size * self.base_hf_config.num_key_value_heads,
-                            head_size * self.base_hf_config.num_key_value_heads,
-                        ],
-                        dim=0,
-                    )
-                    weights[kv_name] = torch.stack(
-                        [k_proj_weight, v_proj_weight],
-                        dim=0,
-                    )
+                # else: no-op as LoRA B weight is already stacked.
 
     def normalize_gate_up_proj(
         self, weight_names: List[str], weights: Dict[str, torch.Tensor]
@@ -187,20 +156,14 @@ class LoRAAdapter(nn.Module):
                 gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
                 if up_name not in weights:
                     weights[up_name] = torch.zeros_like(weights[weight_name])
-                    # FIXME: Add gate-only support for flashinfer in future implementations
                     assert self.lora_backend.name == "triton", (
                         f"LoRA weight initialization currently only supported for 'triton' backend. "
                         f"Received backend: {self.lora_backend.name}. Please verify your backend configuration "
                         f"or consider implementing custom initialization logic for other backends."
                     )
-                if "lora_A" in weight_name:
-                    weights[gate_up_name] = torch.cat(
-                        (weights[weight_name], weights[up_name]), 0
-                    )
-                else:
-                    weights[gate_up_name] = torch.stack(
-                        [weights[weight_name], weights[up_name]], dim=0
-                    )
+                weights[gate_up_name] = torch.cat(
+                    (weights[weight_name], weights[up_name]), 0
+                )
                 weights.pop(weight_name)
                 if up_name in weights:
                     weights.pop(up_name)
@@ -209,12 +172,4 @@ class LoRAAdapter(nn.Module):
                 gate_up_name = weight_name
                 if "lora_A" in weight_name:
                     weights[gate_up_name] = weights[gate_up_name].repeat(2, 1)
-                else:
-                    output_dim = weights[gate_up_name].shape[0] // 2
-                    weights[gate_up_name] = torch.stack(
-                        [
-                            weights[gate_up_name][:output_dim, :],
-                            weights[gate_up_name][output_dim:, :],
-                        ],
-                        dim=0,
-                    )
+                # else: no-op as LoRA B weight is already stacked.
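The hunks above drop the old `kv_proj` stacking path: after normalization, LoRA A weights for q/k/v are concatenated into a single `qkv_proj` tensor and LoRA B weights are left in their loaded layout. A minimal sketch of the new concatenation convention, using hypothetical rank and hidden sizes (not taken from the diff):

```python
# Illustrative sketch of the normalize_qkv_proj convention after this change.
# Shapes are hypothetical; only the torch.cat layout mirrors the diff.
import torch

rank, hidden = 16, 4096
lora_a = {
    "q_proj": torch.randn(rank, hidden),
    "k_proj": torch.randn(rank, hidden),
    "v_proj": torch.randn(rank, hidden),
}

# LoRA A for q/k/v is concatenated along dim 0 into one qkv_proj tensor;
# there is no longer a separately stacked kv_proj entry for LoRA B.
qkv_lora_a = torch.cat((lora_a["q_proj"], lora_a["k_proj"], lora_a["v_proj"]), 0)
print(qkv_lora_a.shape)  # torch.Size([48, 4096]) == (3 * rank, hidden)
```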
sglang/srt/lora/lora_manager.py CHANGED
@@ -31,7 +31,6 @@ from sglang.srt.lora.mem_pool import LoRAMemoryPool
 from sglang.srt.lora.utils import (
     LoRABatchInfo,
     LoRAType,
-    get_customized_names_from_hf_names,
     get_layer_id,
     get_normalized_lora_weight_names,
     get_weight_name,
@@ -345,40 +344,19 @@ class LoRAManager:
         )
         self.lora_backend.set_batch_info(batch_info)
 
-        # TODO (lifuhuang): one potential perf optimization that is worth considering is to see if we can call
-        # this method only when loading/unloading LoRA adapters, instead of calling it for every micro-batch.
-        self.update_lora_info()
-
     def update_lora_info(self):
         """
         Update all LoRA modules to associate them with the latest memory buffer.
         """
         for layer_id, layer_modules in enumerate(self.lora_modules):
             for module_name, module in layer_modules.items():
-                if "qkv_proj" in module_name:
-                    module.set_lora_info(
-                        self.memory_pool.get_tensor(
-                            "qkv_proj", layer_id, LoRAType.LORA_A
-                        ),
-                        self.memory_pool.get_tensor(
-                            "q_proj", layer_id, LoRAType.LORA_B
-                        ),
-                        self.memory_pool.get_tensor(
-                            "kv_proj", layer_id, LoRAType.LORA_B
-                        ),
-                    )
-                else:
-                    weight_name = get_weight_name(
-                        module_name, self.memory_pool.lora_weight_names, LoRAType.LORA_A
-                    )
-                    module.set_lora_info(
-                        self.memory_pool.get_tensor(
-                            weight_name, layer_id, LoRAType.LORA_A
-                        ),
-                        self.memory_pool.get_tensor(
-                            weight_name, layer_id, LoRAType.LORA_B
-                        ),
-                    )
+                weight_name = get_weight_name(
+                    module_name, self.memory_pool.lora_weight_names
+                )
+                module.set_lora_info(
+                    self.memory_pool.get_tensor(weight_name, layer_id, LoRAType.LORA_A),
+                    self.memory_pool.get_tensor(weight_name, layer_id, LoRAType.LORA_B),
+                )
 
     def init_state(
         self,
@@ -405,6 +383,7 @@ class LoRAManager:
         self.init_lora_weight_names()
         self.init_lora_modules()
         self.init_memory_pool()
+        self.update_lora_info()
 
     def init_lora_adapters(self, lora_paths: Optional[Dict[str, LoRARef]] = None):
         # Configs of all active LoRA adapters, indexed by LoRA ID.
@@ -461,9 +440,9 @@ class LoRAManager:
         Add new LoRA weight names if needed based on the current `self.configs`.
         """
 
-        # Target lora weight names for lora_a and lora_b modules respectively.
-        lora_A, lora_B = get_normalized_lora_weight_names(self.target_modules)
-        self.lora_weight_names: Tuple[Set[str]] = (set(lora_A), set(lora_B))
+        self.lora_weight_names: Set[str] = get_normalized_lora_weight_names(
+            self.target_modules
+        )
 
     def load_lora_weights(self, lora_ref: LoRARef):
         """
@@ -479,15 +458,6 @@ class LoRAManager:
         lora_adapter.initialize_weights()
         self.loras[lora_ref.lora_id] = lora_adapter
 
-        # Additional checks for flashinfer backend
-        # FIXME remove the restrictions after supporting multi-rank for flashinfer backend
-        if self.lora_backend == "flashinfer":
-            lora_dims = set(x.r for x in self.configs.values())
-            scalings = set(x.scaling for x in self.loras.values())
-            assert (
-                len(lora_dims) == 1 and len(scalings) == 1
-            ), "Flashinfer backend currently only supports single LoRA rank and scaling across all adapters. "
-
     def init_memory_pool(self):
         """(Re)initialize the LoRA memory pool based on the current configurations."""
         self.memory_pool = LoRAMemoryPool(
@@ -512,12 +482,6 @@ class LoRAManager:
             {} for _ in range(self.base_hf_config.num_hidden_layers)
         ]
 
-        # Target module names of customized layers defined in python/sglang/srt/layers
-        # e.g., {"qkv_proj", "o_proj"}
-        customized_target_names = get_customized_names_from_hf_names(
-            self.target_modules, self.base_model
-        )
-
         for module_name, module in self.base_model.named_modules():
             # TODO (lifuhuang): in the future, we should consider generalizing the
             # should_apply_lora function to support mapping by full module name instead
@@ -530,7 +494,7 @@ class LoRAManager:
                 continue
 
             # The module should be converted if it is included in target_names
-            if module_name.split(".")[-1] in customized_target_names:
+            if module_name.split(".")[-1] in self.lora_weight_names:
                 layer_id = get_layer_id(module_name)
                 self.lora_modules[layer_id][module_name] = self.set_lora_module(
                     module_name, module
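With `get_customized_names_from_hf_names` gone, module discovery in `init_lora_modules` matches the last path component of each module name directly against the flat `lora_weight_names` set. A hedged sketch of that matching logic, with hypothetical module names:

```python
# Hypothetical sketch of the simplified matching in init_lora_modules.
# The set below is an example of what get_normalized_lora_weight_names may return.
lora_weight_names = {"qkv_proj", "o_proj", "gate_up_proj"}

module_names = [
    "model.layers.0.self_attn.qkv_proj",
    "model.layers.0.self_attn.o_proj",
    "model.layers.0.mlp.down_proj",
]
targets = [m for m in module_names if m.split(".")[-1] in lora_weight_names]
print(targets)  # down_proj is skipped because it is not in the normalized set
```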
sglang/srt/lora/lora_registry.py CHANGED
@@ -14,7 +14,6 @@
 
 
 import asyncio
-from collections import defaultdict
 from dataclasses import dataclass, field, fields
 from typing import Dict, List, Optional, Union
 from uuid import uuid4
@@ -106,7 +105,6 @@ class LoRARegistry:
                 f"LoRA with name {lora_name} does not exist. Loaded LoRAs: {self._registry.keys()}"
             )
         del self._registry[lora_name]
-        del self._counters[lora_ref.lora_id]
 
         return lora_ref.lora_id
 
@@ -117,6 +115,9 @@ class LoRARegistry:
         """
 
         def _lookup(name: str) -> str:
+            if name is None:
+                return None
+
             lora_ref = self._registry.get(name, None)
             if lora_ref is None:
                 raise ValueError(
@@ -135,7 +136,11 @@ class LoRARegistry:
 
             # Increment the counters only after all IDs are looked up.
             await asyncio.gather(
-                *[self._counters[id].increment(notify_all=False) for id in lora_ids]
+                *[
+                    self._counters[id].increment(notify_all=False)
+                    for id in lora_ids
+                    if id is not None
+                ]
             )
             return lora_ids
         else:
@@ -153,7 +158,11 @@ class LoRARegistry:
             await self._counters[lora_id].decrement()
         elif isinstance(lora_id, list):
             await asyncio.gather(
-                *[self._counters[id].decrement() for id in lora_id]
+                *[
+                    self._counters[id].decrement()
+                    for id in lora_id
+                    if id is not None
+                ]
             )
         else:
             raise TypeError("lora_id must be either a string or a list of strings.")
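The registry hunks above make `acquire` and `release` tolerate requests that carry no LoRA: a `None` name looks up to a `None` ID, and `None` IDs are filtered out before the reference counters are touched. A small sketch of the filtering, with made-up IDs:

```python
# Hedged illustration of the None filtering added to acquire/release.
lora_ids = ["adapter-a", None, "adapter-b"]  # hypothetical IDs; None means "no LoRA"

counted = [lora_id for lora_id in lora_ids if lora_id is not None]
print(counted)  # only real adapter IDs participate in reference counting
```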
@@ -169,11 +178,13 @@ class LoRARegistry:
         assert (
             lora_id not in self._registry
         ), "wait_for_unload should only be called after the LoRA adapter has been unregistered. "
-        counter = self._counters.get(lora_id)
-        if counter:
-            # Wait until no requests are using this LoRA adapter.
-            await counter.wait_for_zero()
-            del self._counters[lora_id]
+        assert (
+            lora_id in self._counters
+        ), "The LoRA ID should still have a counter if it has been registered before."
+
+        # Wait until no requests are using this LoRA adapter.
+        await self._counters[lora_id].wait_for_zero()
+        del self._counters[lora_id]
 
     def _register_adapter(self, lora_ref: LoRARef):
         """
sglang/srt/lora/mem_pool.py CHANGED
@@ -52,7 +52,7 @@ class LoRAMemoryPool:
         tp_size: int,
         tp_rank: int,
         max_lora_rank: int,
-        lora_weight_names: Tuple[Set[str], Set[str]],
+        lora_weight_names: Set[str],
         base_model: torch.nn.Module,
     ):
         self.base_hf_config: AutoConfig = base_hf_config
@@ -62,9 +62,7 @@ class LoRAMemoryPool:
         self.tp_size: int = tp_size
         self.tp_rank: int = tp_rank
         self.max_lora_rank: int = max_lora_rank
-
-        # lora weight names for LoRA A and B respectively.
-        self.lora_weight_names: Tuple[Set[str], Set[str]] = lora_weight_names
+        self.lora_weight_names: Set[str] = lora_weight_names
 
         # Both A_buffer and B_buffer maps lora weight names to its buffer space.
         # A_buffer contains num_layer number of row-major tensors with shape
@@ -97,12 +95,8 @@ class LoRAMemoryPool:
             """
             if config.r > self.max_lora_rank:
                 return False
-            weights_a, weights_b = get_normalized_lora_weight_names(
-                config.target_modules
-            )
-            return weights_a.issubset(self.lora_weight_names[0]) and weights_b.issubset(
-                self.lora_weight_names[1]
-            )
+            weights = get_normalized_lora_weight_names(config.target_modules)
+            return weights.issubset(self.lora_weight_names)
 
         if isinstance(config, LoRAConfig):
             return _can_support(config)
@@ -132,11 +126,9 @@ class LoRAMemoryPool:
            Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
            """
            _, output_dim = get_hidden_dim(module_name, self.base_hf_config, base_model)
-            c = get_stacked_multiply(module_name)
            if self.tp_size > 1 and module_name not in ROW_PARALLELISM_LINEAR_LORA_NAMES:
                output_dim = divide(output_dim, self.tp_size)
            return (
-                c,
                self.max_loras_per_batch,
                output_dim,
                max_lora_dim,
@@ -165,13 +157,13 @@ class LoRAMemoryPool:
 
         init_buffer(
             self.A_buffer,
-            self.lora_weight_names[0],
+            self.lora_weight_names,
             self.get_lora_A_shape,
         )
 
         init_buffer(
             self.B_buffer,
-            self.lora_weight_names[1],
+            self.lora_weight_names,
             self.get_lora_B_shape,
         )
 
@@ -246,7 +238,7 @@ class LoRAMemoryPool:
             return
 
         assert lora_adapter is not None
-        lora_rank = lora_adapter.config.hf_config["r"]
+        lora_rank = lora_adapter.config.r
        for layer_id in range(self.num_layer):
            layer_weights = lora_adapter.layers[layer_id].weights
            temp_A_buffer: Dict[str, Optional[torch.Tensor]] = {
@@ -256,73 +248,38 @@ class LoRAMemoryPool:
                weight_name: None for weight_name in self.B_buffer
            }
            for name, weights in layer_weights.items():
+                lora_weight_name = get_weight_name(name, self.lora_weight_names)
                if "lora_A" in name:
-                    lora_weight_name = get_weight_name(
-                        name, self.lora_weight_names, LoRAType.LORA_A
-                    )
                    temp_A_buffer[lora_weight_name] = weights
                else:
-                    lora_weight_name = get_weight_name(
-                        name, self.lora_weight_names, LoRAType.LORA_B
-                    )
                    temp_B_buffer[lora_weight_name] = weights
 
            if self.tp_size > 1:
                cur_layer_modules = lora_modules[layer_id]
                for module_name, module in cur_layer_modules.items():
-                    weight_name = get_weight_name(
-                        module_name, self.lora_weight_names, LoRAType.LORA_A
-                    )
+                    weight_name = get_weight_name(module_name, self.lora_weight_names)
 
                    if temp_A_buffer[weight_name] is None:
                        # Skip weight slicing if the weight is not present in the adapter
                        continue
 
-                    if "qkv_proj" in module_name:
-                        temp_A_buffer["qkv_proj"] = module.slice_lora_a_weights(
-                            temp_A_buffer["qkv_proj"], self.tp_rank
-                        )
-                        temp_B_buffer["q_proj"], temp_B_buffer["kv_proj"] = (
-                            module.slice_lora_b_weights(
-                                [temp_B_buffer["q_proj"], temp_B_buffer["kv_proj"]],
-                                self.tp_rank,
-                            )
-                        )
-                    else:
-                        # TODO (lifuhuang): Ideally, we should call `get_weight_name` separately for both A and B.
-                        # Currently, we're reusing A's weight name as a workaround, relying on the fact that A and
-                        # B share the same name except for `qkv_proj`. We should clean this up once we deprecate the
-                        # FlashInfer LoRA backend.
-                        temp_A_buffer[weight_name] = module.slice_lora_a_weights(
-                            temp_A_buffer[weight_name], self.tp_rank
-                        )
-                        temp_B_buffer[weight_name] = module.slice_lora_b_weights(
-                            temp_B_buffer[weight_name], self.tp_rank
-                        )
+                    temp_A_buffer[weight_name] = module.slice_lora_a_weights(
+                        temp_A_buffer[weight_name], self.tp_rank
+                    )
+                    temp_B_buffer[weight_name] = module.slice_lora_b_weights(
+                        temp_B_buffer[weight_name], self.tp_rank
+                    )
 
            for name, weights in temp_A_buffer.items():
                c = get_stacked_multiply(name)
-                buffer_view = self.A_buffer[name][layer_id][buffer_id][
-                    : lora_rank * c, :
-                ]
+                target_buffer = self.A_buffer[name][layer_id]
+                buffer_view = target_buffer[buffer_id, : lora_rank * c, :]
                load_lora_weight_tensor(buffer_view, weights)
 
            for name, weights in temp_B_buffer.items():
-                c = get_stacked_multiply(name)
-                if c > 1:
-                    for stacked_id in range(c):
-                        buffer_view = self.B_buffer[name][layer_id][stacked_id][
-                            buffer_id
-                        ][:, :lora_rank]
-                        weight_slice = (
-                            weights[stacked_id] if weights is not None else None
-                        )
-                        load_lora_weight_tensor(buffer_view, weight_slice)
-                else:
-                    buffer_view = self.B_buffer[name][layer_id][0][buffer_id][
-                        :, :lora_rank
-                    ]
-                    load_lora_weight_tensor(buffer_view, weights)
+                target_buffer = self.B_buffer[name][layer_id]
+                buffer_view = target_buffer[buffer_id, :, :lora_rank]
+                load_lora_weight_tensor(buffer_view, weights)
 
     def get_tensor(
         self, weight_name: str, layer_id: int, lora_type: LoRAType
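With the stacked `kv_proj` slot removed, both buffers are addressed uniformly: the A-buffer view is `[buffer_id, : lora_rank * c, :]` and the B-buffer view is `[buffer_id, :, :lora_rank]`. A sketch of the B-buffer slicing with hypothetical dimensions (not values from the diff):

```python
# Hypothetical sketch of the flattened B_buffer indexing after this change.
import torch

max_loras_per_batch, output_dim, max_lora_dim = 4, 6144, 64
b_buffer_layer = torch.zeros(max_loras_per_batch, output_dim, max_lora_dim)

buffer_id, lora_rank = 0, 16
# One contiguous view per adapter slot, rather than per-stacked-projection views.
buffer_view = b_buffer_layer[buffer_id, :, :lora_rank]
buffer_view.copy_(torch.randn(output_dim, lora_rank))
print(buffer_view.shape)  # torch.Size([6144, 16])
```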
sglang/srt/lora/triton_ops/qkv_lora_b.py CHANGED
@@ -119,7 +119,7 @@ def _qkv_lora_b_kernel(
     output_ptr = (output + seg_start * output_stride_0 + n_start * output_stride_1) + (
         s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
     )
-    output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < n_size)
+    output_mask = (s_offset[:, None] < seg_len) & (n_offset[None, :] < n_size)
     partial_sum += tl.load(output_ptr, mask=output_mask)
     tl.store(output_ptr, partial_sum, mask=output_mask)
 
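The one-line kernel fix swaps Python's short-circuiting `and` for the elementwise `&` when combining the row and column bounds masks. The same distinction is easy to see with plain PyTorch booleans (illustrative only, not Triton code):

```python
# Why `&` instead of `and` for combining masks (plain PyTorch illustration).
import torch

rows_ok = torch.tensor([True, True, False])
cols_ok = torch.tensor([True, False, True])

print(rows_ok & cols_ok)  # tensor([ True, False, False]) -> elementwise mask
# `rows_ok and cols_ok` would not combine elementwise; it raises
# "Boolean value of Tensor with more than one element is ambiguous".
```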
sglang/srt/lora/utils.py CHANGED
@@ -47,34 +47,6 @@ def get_layer_id(name: str) -> int:
     return int(match.group(1))
 
 
-def get_customized_names_from_hf_names(
-    hf_module_names: Set[str], base_model: torch.nn.Module
-) -> Set[str]:
-    """
-    This function takes in a set of huggingface style module names:
-    e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
-    and outputs a set of module names of customized sglang layers:
-    e.g., {"qkv_proj", "o_proj"}
-    """
-    if hasattr(base_model, "get_module_name"):
-        return {base_model.get_module_name(name) for name in hf_module_names}
-    else:
-        """
-        Fallback solution of mapping from config module name to module name in model class.
-        Please check if it aligns with your base model.
-        Please implement the function in the model class if it is not.
-        You can reference this function in llama.py.
-        """
-        params_mapping = {
-            "q_proj": "qkv_proj",
-            "k_proj": "qkv_proj",
-            "v_proj": "qkv_proj",
-            "gate_proj": "gate_up_proj",
-            "up_proj": "gate_up_proj",
-        }
-        return {params_mapping.get(name, name) for name in hf_module_names}
-
-
 def get_hidden_dim(
     module_name: str, config: AutoConfig, base_model: torch.nn.Module
 ) -> Tuple[int]:
@@ -92,14 +64,20 @@ def get_hidden_dim(
         Please implement the function in the model class if it is not.
         You can reference this function in llama.py.
         """
-        if module_name in ["q_proj", "o_proj", "qkv_proj"]:
-            return config.hidden_size, config.hidden_size
-        elif module_name in ["kv_proj"]:
-            return config.hidden_size, config.hidden_size // (
-                config.num_attention_heads // config.num_key_value_heads
+        head_dim = getattr(
+            config, "head_dim", config.hidden_size // config.num_attention_heads
+        )
+        if module_name == "qkv_proj":
+            return config.hidden_size, head_dim * (
+                config.num_attention_heads + config.num_key_value_heads * 2
+            )
+        elif module_name == "o_proj":
+            return (
+                head_dim * config.num_attention_heads,
+                config.hidden_size,
             )
         elif module_name == "gate_up_proj":
-            return config.hidden_size, config.intermediate_size
+            return config.hidden_size, config.intermediate_size * 2
         elif module_name == "down_proj":
             return config.intermediate_size, config.hidden_size
         else:
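The fallback `get_hidden_dim` now derives the stacked output widths from `head_dim` instead of assuming `hidden_size` everywhere, which also covers grouped-query models. A worked example with hypothetical, Llama-3-8B-like config values (not taken from the diff):

```python
# Worked example of the new fallback arithmetic (hypothetical config values).
hidden_size = 4096
num_attention_heads = 32
num_key_value_heads = 8
intermediate_size = 14336

head_dim = hidden_size // num_attention_heads  # 128 when config has no head_dim
qkv_out = head_dim * (num_attention_heads + num_key_value_heads * 2)  # 6144
o_proj_in = head_dim * num_attention_heads  # 4096
gate_up_out = intermediate_size * 2  # 28672

print((hidden_size, qkv_out), (o_proj_in, hidden_size), (hidden_size, gate_up_out))
```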
@@ -108,26 +86,22 @@
 
 def get_normalized_lora_weight_names(
     target_modules: Iterable[str],
-) -> Tuple[set[str], set[str]]:
+) -> set[str]:
     """
     Mapping a list of target module name to names of the normalized LoRA weights.
-    Returned tuple contains (name for Lora A, name for Lora B)
     """
     params_mapping = {
-        "q_proj": (["qkv_proj"], ["q_proj"]),
-        "k_proj": (["qkv_proj"], ["kv_proj"]),
-        "v_proj": (["qkv_proj"], ["kv_proj"]),
-        "gate_proj": (["gate_up_proj"], ["gate_up_proj"]),
-        "up_proj": (["gate_up_proj"], ["gate_up_proj"]),
-        "qkv_proj": (["qkv_proj"], ["q_proj", "kv_proj"]),
-        "gate_up_proj": (["gate_up_proj"], ["gate_up_proj"]),
+        "q_proj": "qkv_proj",
+        "k_proj": "qkv_proj",
+        "v_proj": "qkv_proj",
+        "gate_proj": "gate_up_proj",
+        "up_proj": "gate_up_proj",
     }
 
-    result = (set(), set())
+    result = set()
     for name in target_modules:
-        lora_a, lora_b = params_mapping.get(name, ([name], [name]))
-        result[0].update(lora_a)
-        result[1].update(lora_b)
+        weight_name = params_mapping.get(name, name)
+        result.add(weight_name)
     return result
 
 
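After this hunk, `get_normalized_lora_weight_names` returns a single flat set instead of an (A-names, B-names) pair. A quick example of the new mapping, mirroring the `params_mapping` added in the diff:

```python
# Example of the flattened normalization, following the new params_mapping.
params_mapping = {
    "q_proj": "qkv_proj",
    "k_proj": "qkv_proj",
    "v_proj": "qkv_proj",
    "gate_proj": "gate_up_proj",
    "up_proj": "gate_up_proj",
}

target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj"]
normalized = {params_mapping.get(name, name) for name in target_modules}
print(normalized)  # {'qkv_proj', 'o_proj', 'gate_up_proj'}
```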
@@ -137,23 +111,21 @@ def get_stacked_multiply(module_name: str) -> int:
     """
     stacked_rank = {
         "qkv_proj": 3,
-        "kv_proj": 2,
         "gate_up_proj": 2,
     }
     return stacked_rank[module_name] if module_name in stacked_rank else 1
 
 
 def get_weight_name(
-    target_name: str, lora_weight_names: Tuple[Set[str]], lora_type: LoRAType
+    target_name: str, lora_weight_names: Tuple[Set[str]]
 ) -> Optional[str]:
     """
-    target_name is name of a given module,
-    lora_weight_names is a set of lora stacked name pairs (see get_stacked_name method above)
+    Get the weight name in lora_weight_names that can match target_name.
+
     If there is a weight name in lora_weight_names that can match target_name, return this name
     Else raise ValueError.
     """
-    idx = 0 if lora_type == LoRAType.LORA_A else 1
-    for weight_name in lora_weight_names[idx]:
+    for weight_name in lora_weight_names:
         if weight_name in target_name:
             return weight_name
     raise ValueError(
@@ -161,9 +133,4 @@
     )
 
 
-# TODO: [PR #4274] For future use to simplify the mapping between HF module names and customized module names.
-VOCAB_PARALLELISM_EMBEDDING_NAMES = ["embeddings"]
-COLUMN_PARALLELISM_LINEAR_LORA_NAMES = ["gate_proj", "up_proj"]
-MERGED_COLUMN_PARALLELISM_LINEAR_LORA_NAMES = ["gate_up_proj"]
-QKV_PARALLELISM_LINEAR_LORA_NAMES = ["qkv_proj"]
 ROW_PARALLELISM_LINEAR_LORA_NAMES = ["o_proj", "down_proj"]